Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace StudySummary to FrozenStudy in serializing #809

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
from ._preferential_history import report_history
from ._preferential_history import restore_history
from ._rdb_migration import register_rdb_migration_route
from ._serializer import serialize_frozen_study
from ._serializer import serialize_study_detail
from ._serializer import serialize_study_summary
from ._storage import create_new_study
from ._storage import get_study_summaries
from ._storage import get_study_summary
from ._storage import get_studies
from ._storage import get_study
from ._storage import get_trials
from ._storage_url import get_storage
from .artifact._backend import delete_all_artifacts
Expand Down Expand Up @@ -101,9 +101,10 @@ def api_meta() -> dict[str, Any]:

@app.get("/api/studies")
@json_api_view
def list_study_summaries() -> dict[str, Any]:
summaries = get_study_summaries(storage)
serialized = [serialize_study_summary(summary) for summary in summaries]
def list_studies() -> dict[str, Any]:
studies = get_studies(storage)
serialized = [serialize_frozen_study(s) for s in studies]
# TODO(umezawa): Rename `study_summaries` to `studies`.
return {
"study_summaries": serialized,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I change the key name of study_summaries and study_summary to something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a todo comment? For now, we cannot introduce the breaking changes in JSON Web APIs since our Jupyter Lab extension uses it.

}
Expand Down Expand Up @@ -131,12 +132,12 @@ def create_study() -> dict[str, Any]:
response.status = 400 # Bad request
return {"reason": f"'{study_name}' already exists"}

summary = get_study_summary(storage, study_id)
if summary is None:
study = get_study(storage, study_id)
if study is None:
response.status = 500 # Internal server error
return {"reason": "Failed to create study"}
response.status = 201 # Created
return {"study_summary": serialize_study_summary(summary)}
return {"study_summary": serialize_frozen_study(study)}

@app.post("/api/studies/<study_id:int>/rename")
@json_api_view
Expand Down Expand Up @@ -167,14 +168,14 @@ def rename_study(study_id: int) -> dict[str, Any]:
response.status = 500
storage.delete_study(dst_study._study_id)
return {"reason": str(e)}
new_study_summary = get_study_summary(storage, dst_study._study_id)
if new_study_summary is None:
new_study = get_study(storage, dst_study._study_id)
if new_study is None:
response.status = 500
return {"reason": "Failed to load the new study"}

storage.delete_study(src_study._study_id)
response.status = 201
return serialize_study_summary(new_study_summary)
return serialize_frozen_study(new_study)

@app.delete("/api/studies/<study_id:int>")
@json_api_view
Expand All @@ -201,24 +202,24 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
return {"reason": "`after` should be larger or equal 0."}
except KeyError:
after = 0
summary = get_study_summary(storage, study_id)
if summary is None:
study = get_study(storage, study_id)
if study is None:
response.status = 404 # Not found
return {"reason": f"study_id={study_id} is not found"}
trials = get_trials(storage, study_id)

system_attrs = getattr(summary, "system_attrs", {})
system_attrs = getattr(study, "system_attrs", {})
is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False)
# TODO(c-bata): Cache best_trials
if is_preferential:
best_trials = get_best_preferential_trials(study_id, storage)
elif len(summary.directions) == 1:
elif len(study.directions) == 1:
if len([t for t in trials if t.state == TrialState.COMPLETE]) == 0:
best_trials = []
else:
best_trials = [storage.get_best_trial(study_id)]
else:
best_trials = get_pareto_front_trials(trials=trials, directions=summary.directions)
best_trials = get_pareto_front_trials(trials=trials, directions=study.directions)
(
# TODO: intersection_search_space and union_search_space look more clear since now we
# have union_user_attrs.
Expand All @@ -232,7 +233,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
skipped_trial_ids = get_skipped_trial_ids(system_attrs)
skipped_trial_numbers = [t.number for t in trials if t._trial_id in skipped_trial_ids]
return serialize_study_detail(
summary,
study,
best_trials,
trials[after:],
intersection,
Expand Down
35 changes: 14 additions & 21 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.study import StudySummary
from optuna.study._frozen import FrozenStudy
from optuna.trial import FrozenTrial

from . import _note as note
Expand Down Expand Up @@ -116,25 +116,20 @@ def serialize_attrs(attrs: dict[str, Any]) -> list[Attribute]:
return serialized


def serialize_study_summary(summary: StudySummary) -> dict[str, Any]:
def serialize_frozen_study(study: FrozenStudy) -> dict[str, Any]:
serialized = {
"study_id": summary._study_id,
"study_name": summary.study_name,
"directions": [d.name.lower() for d in summary.directions],
"user_attrs": serialize_attrs(summary.user_attrs),
"is_preferential": getattr(summary, "_system_attrs", {}).get(
_SYSTEM_ATTR_PREFERENTIAL_STUDY, False
),
"study_id": study._study_id,
"study_name": study.study_name,
"directions": [d.name.lower() for d in study.directions],
"user_attrs": serialize_attrs(study.user_attrs),
"is_preferential": study.system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False),
}

if summary.datetime_start is not None:
serialized["datetime_start"] = summary.datetime_start.isoformat()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add datetime_start, I need to add this attribute in FrozenStudy.

Copy link
Member

@c-bata c-bata Feb 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say we should drop datetime_start for performance reasons. Could you remove a following line as well?


return serialized


def serialize_study_detail(
summary: StudySummary,
study: FrozenStudy,
best_trials: list[FrozenTrial],
trials: list[FrozenTrial],
intersection: list[tuple[str, BaseDistribution]],
Expand All @@ -145,20 +140,18 @@ def serialize_study_detail(
skipped_trial_numbers: list[int],
) -> dict[str, Any]:
serialized: dict[str, Any] = {
"name": summary.study_name,
"directions": [d.name.lower() for d in summary.directions],
"user_attrs": serialize_attrs(summary.user_attrs),
"name": study.study_name,
"directions": [d.name.lower() for d in study.directions],
"user_attrs": serialize_attrs(study.user_attrs),
}
system_attrs = getattr(summary, "system_attrs", {})
system_attrs = study.system_attrs
serialized["artifacts"] = list_study_artifacts(system_attrs)
if summary.datetime_start is not None:
serialized["datetime_start"] = summary.datetime_start.isoformat()

serialized["trials"] = [
serialize_frozen_trial(summary._study_id, trial, system_attrs) for trial in trials
serialize_frozen_trial(study._study_id, trial, system_attrs) for trial in trials
]
serialized["best_trials"] = [
serialize_frozen_trial(summary._study_id, trial, system_attrs) for trial in best_trials
serialize_frozen_trial(study._study_id, trial, system_attrs) for trial in best_trials
]
serialized["intersection_search_space"] = serialize_search_space(intersection)
serialized["union_search_space"] = serialize_search_space(union)
Expand Down
36 changes: 8 additions & 28 deletions optuna_dashboard/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,14 @@
from datetime import datetime
from datetime import timedelta
import threading
import typing

from optuna.storages import BaseStorage
from optuna.storages import RDBStorage
from optuna.study import StudyDirection
from optuna.study import StudySummary
from optuna.study._frozen import FrozenStudy
from optuna.trial import FrozenTrial


if typing.TYPE_CHECKING:
from optuna.study._frozen import FrozenStudy


# In-memory trials cache
trials_cache_lock = threading.Lock()
trials_cache: dict[int, list[FrozenTrial]] = {}
Expand Down Expand Up @@ -49,19 +44,19 @@ def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]:
return trials


def get_study_summaries(storage: BaseStorage) -> list[StudySummary]:
def get_studies(storage: BaseStorage) -> list[FrozenStudy]:
frozen_studies = storage.get_all_studies()
if isinstance(storage, RDBStorage):
frozen_studies = sorted(frozen_studies, key=lambda s: s._study_id)
return [_frozen_study_to_study_summary(s) for s in frozen_studies]
return frozen_studies


def get_study_summary(storage: BaseStorage, study_id: int) -> StudySummary | None:
summaries = get_study_summaries(storage)
for summary in summaries:
if summary._study_id != study_id:
def get_study(storage: BaseStorage, study_id: int) -> FrozenStudy | None:
studies = get_studies(storage)
for s in studies:
if s._study_id != study_id:
continue
return summary
return s
return None


Expand All @@ -70,18 +65,3 @@ def create_new_study(
) -> int:
study_id = storage.create_new_study(directions, study_name=study_name)
return study_id


def _frozen_study_to_study_summary(frozen_study: "FrozenStudy") -> StudySummary:
is_single = len(frozen_study.directions) == 1
return StudySummary(
study_name=frozen_study.study_name,
study_id=frozen_study._study_id,
direction=frozen_study.direction if is_single else None,
directions=frozen_study.directions if not is_single else None,
user_attrs=frozen_study.user_attrs,
system_attrs=frozen_study.system_attrs,
best_trial=None,
n_trials=-1, # This field isn't used by Dashboard.
datetime_start=None,
)
42 changes: 18 additions & 24 deletions python_tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import optuna
from optuna_dashboard._serializer import serialize_attrs
from optuna_dashboard._serializer import serialize_frozen_study
from optuna_dashboard._serializer import serialize_study_detail
from optuna_dashboard._serializer import serialize_study_summary
from optuna_dashboard._storage import get_study_summaries
from optuna_dashboard._storage import get_studies
from optuna_dashboard.preferential import create_study
from packaging import version
import pytest
Expand Down Expand Up @@ -60,47 +60,41 @@ def test_serialize_numpy_floating() -> None:
def test_get_study_detail_is_preferential() -> None:
storage = optuna.storages.InMemoryStorage()
study = create_study(n_generate=4, storage=storage)
study_summaries = get_study_summaries(storage)
assert len(study_summaries) == 1
studies = get_studies(storage)
assert len(studies) == 1

study_summary = study_summaries[0]
study_detail = serialize_study_detail(
study_summary, [], study.trials, [], [], [], False, {}, []
)
study_detail = serialize_study_detail(studies[0], [], study.trials, [], [], [], False, {}, [])
assert study_detail["is_preferential"]


def test_get_study_detail_is_not_preferential() -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study_summaries = get_study_summaries(storage)
assert len(study_summaries) == 1
studies = get_studies(storage)
assert len(studies) == 1

study_summary = study_summaries[0]
study_detail = serialize_study_detail(
study_summary, [], study.trials, [], [], [], False, {}, []
)
study_detail = serialize_study_detail(studies[0], [], study.trials, [], [], [], False, {}, [])
assert not study_detail["is_preferential"]


@pytest.mark.skipif(sys.version_info < (3, 8), reason="BoTorch dropped Python3.7 support")
@pytest.mark.skipif(
version.parse(optuna.__version__) < version.parse("3.2.0"), reason="Needs optuna.search_space"
)
def test_get_study_summary_is_preferential() -> None:
def test_get_study_is_preferential() -> None:
storage = optuna.storages.InMemoryStorage()
create_study(n_generate=4, storage=storage)
study_summaries = get_study_summaries(storage)
assert len(study_summaries) == 1
studies = get_studies(storage)
assert len(studies) == 1

study_summary = serialize_study_summary(study_summaries[0])
assert study_summary["is_preferential"]
serialized = serialize_frozen_study(studies[0])
assert serialized["is_preferential"]


def test_get_study_summary_is_not_preferential() -> None:
def test_get_study_is_not_preferential() -> None:
storage = optuna.storages.InMemoryStorage()
optuna.create_study(storage=storage)
study_summaries = get_study_summaries(storage)
assert len(study_summaries) == 1
study_summary = serialize_study_summary(study_summaries[0])
assert not study_summary["is_preferential"]
studies = get_studies(storage)
assert len(studies) == 1
serialized = serialize_frozen_study(studies[0])
assert not serialized["is_preferential"]
Loading