diff --git a/optuna_dashboard/_storage_url.py b/optuna_dashboard/_storage_url.py index 351f9f546..e272236ef 100644 --- a/optuna_dashboard/_storage_url.py +++ b/optuna_dashboard/_storage_url.py @@ -1,6 +1,7 @@ from __future__ import annotations import os.path +from pathlib import Path import re from typing import TYPE_CHECKING @@ -58,12 +59,25 @@ def get_storage( return guess_storage_from_url(storage) +def _has_sqlite_header(storage_url: str) -> bool: + storage_path = Path(storage_url) + SQLITE_HEADER = ( + b"SQLite format 3\x00" # see https://github.com/optuna/optuna-dashboard/pull/800 + ) + with storage_path.open(mode="rb") as f: + header = f.read(len(SQLITE_HEADER)) + return header == SQLITE_HEADER + + def guess_storage_from_url(storage_url: str) -> BaseStorage: if storage_url.startswith("redis"): return get_journal_redis_storage(storage_url) if os.path.isfile(storage_url): - return get_journal_file_storage(storage_url) + if _has_sqlite_header(storage_url): + return get_rdb_storage("sqlite:///" + storage_url) + else: + return get_journal_file_storage(storage_url) if rfc1738_pattern.match(storage_url) is not None: return get_rdb_storage(storage_url)