Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code fix/download csv #723

Merged
merged 34 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
679646b
add csv download
Sep 30, 2023
c6f495d
add csv download button
Oct 5, 2023
44ad948
remove branks in L484
RuTiO2le Oct 5, 2023
d0be784
Merge pull request #1 from optuna/main
eukaryo Dec 6, 2023
547d985
Update _app.py
eukaryo Dec 6, 2023
65d7c8e
add download button
eukaryo Dec 6, 2023
46ec277
fix lint
eukaryo Dec 6, 2023
6f681c4
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
3b31422
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
a5502ce
Update _app.py
eukaryo Dec 6, 2023
c73abac
Add files via upload
eukaryo Dec 6, 2023
26d29f3
fix mypy
eukaryo Dec 6, 2023
895fc0b
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
908788d
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
2a03cd5
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
8eac8fc
add tests
eukaryo Dec 6, 2023
2fde74b
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
3d6e69a
revised _app.py
eukaryo Dec 6, 2023
e5e813b
fix lint
eukaryo Dec 6, 2023
f57bab2
revised _app.py
eukaryo Dec 6, 2023
c9915ad
Update optuna_dashboard/_app.py
eukaryo Dec 6, 2023
7d834d3
Update StudyDetail.tsx
eukaryo Dec 6, 2023
64dc46c
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
462bc99
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
678b15e
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
18d68ea
Update test_csv_download.py
eukaryo Dec 7, 2023
4a3849b
Update StudyDetail.tsx
eukaryo Dec 7, 2023
6d44225
Update StudyDetail.tsx
eukaryo Dec 7, 2023
e1f43e1
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
a4fb8fd
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
1e5d582
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
d7ff6f4
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
a8c451a
Update python_tests/test_csv_download.py
eukaryo Dec 7, 2023
f63b5e9
fix lint
eukaryo Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

import csv
import functools
import io
from itertools import chain
import logging
import os
import re
import typing
from typing import Any
from typing import Optional
Expand Down Expand Up @@ -152,7 +156,7 @@
storage=storage, study_name=dst_study_name, directions=src_study.directions
)
dst_study.add_trials(src_study.get_trials(deepcopy=False))
note.copy_notes(storage, src_study, dst_study)

Check warning on line 159 in optuna_dashboard/_app.py

View check run for this annotation

Codecov / codecov/patch

optuna_dashboard/_app.py#L159

Added line #L159 was not covered by tests
except DuplicatedStudyError:
response.status = 400 # Bad request
return {"reason": f"study_name={dst_study_name} is duplicaated"}
Expand Down Expand Up @@ -449,6 +453,48 @@
response.status = 204 # No content
return {}

@app.get("/csv/<study_id:int>")
def download_csv(study_id: int) -> BottleViewReturn:
# Create a CSV file
try:
study_name = storage.get_study_name_from_id(study_id)
study = optuna.load_study(storage=storage, study_name=study_name)
except KeyError:
response.status = 404 # Not found
return {"reason": f"study_id={study_id} is not found"}
trials = study.trials
param_names = sorted(set(chain.from_iterable([t.params.keys() for t in trials])))
user_attr_names = sorted(set(chain.from_iterable([t.user_attrs.keys() for t in trials])))
param_names_header = [f"Param {x}" for x in param_names]
user_attr_names_header = [f"UserAttribute {x}" for x in user_attr_names]
n_objs = len(study.directions)
if study.metric_names is not None:
value_header = study.metric_names

Check warning on line 472 in optuna_dashboard/_app.py

View check run for this annotation

Codecov / codecov/patch

optuna_dashboard/_app.py#L472

Added line #L472 was not covered by tests
else:
value_header = ["Value"] if n_objs == 1 else [f"Objective {x}" for x in range(n_objs)]
column_names = (
["Number", "State"] + value_header + param_names_header + user_attr_names_header
)

buf = io.StringIO("")
writer = csv.writer(buf)
writer.writerow(column_names)
for frozen_trial in trials:
row = [frozen_trial.number, frozen_trial.state.name]
row.extend(frozen_trial.values if frozen_trial.values is not None else [None] * n_objs)
row.extend([frozen_trial.params.get(name, None) for name in param_names])
row.extend([frozen_trial.user_attrs.get(name, None) for name in user_attr_names])
writer.writerow(row)

# Set response headers
output_name = "-".join(re.sub(r'[\\/:*?"<>|]+', "", study_name).split(" "))
response.headers["Content-Type"] = "text/csv; chatset=cp932"
response.headers["Content-Disposition"] = f"attachment; filename={output_name}.csv"

# Response body
buf.seek(0)
return buf.read()

@app.get("/favicon.ico")
def favicon() -> BottleViewReturn:
use_gzip = "gzip" in request.headers["Accept-Encoding"]
Expand Down
39 changes: 34 additions & 5 deletions optuna_dashboard/ts/components/StudyDetail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
import Grid2 from "@mui/material/Unstable_Grid2"
import ChevronRightIcon from "@mui/icons-material/ChevronRight"
import HomeIcon from "@mui/icons-material/Home"
import DownloadIcon from "@mui/icons-material/Download"

import { StudyNote } from "./Note"
import { actionCreator } from "../action"
Expand Down Expand Up @@ -149,11 +150,39 @@ export const StudyDetail: FC<{
content = <TrialList studyDetail={studyDetail} />
} else if (page === "trialTable") {
content = (
<Card sx={{ margin: theme.spacing(2) }}>
<CardContent>
<TrialTable studyDetail={studyDetail} initialRowsPerPage={50} />
</CardContent>
</Card>
<Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}>
<Card
sx={{
margin: theme.spacing(2),
width: "auto",
height: "auto",
display: "flex",
justifyContent: "left",
alignItems: "left",
}}
>
<CardContent>
<IconButton
aria-label="download csv"
size="small"
color="inherit"
download
sx={{ margin: "auto 0" }}
href={`/csv/${studyDetail?.id}`}
>
<DownloadIcon />
<Typography variant="button" sx={{ margin: theme.spacing(2) }}>
Download CSV File
</Typography>
</IconButton>
</CardContent>
</Card>
<Card sx={{ margin: theme.spacing(2) }}>
<CardContent>
<TrialTable studyDetail={studyDetail} initialRowsPerPage={50} />
</CardContent>
</Card>
</Box>
)
} else if (page === "note" && studyDetail !== null) {
content = (
Expand Down
109 changes: 109 additions & 0 deletions python_tests/test_csv_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

from typing import Any
eukaryo marked this conversation as resolved.
Show resolved Hide resolved

import optuna
from optuna.trial import TrialState
from optuna_dashboard._app import create_app
import pytest

from .wsgi_client import send_request


def _validate_output(
storage: optuna.storages.BaseStorage,
correct_status: int,
study_id: int,
expect_no_result: bool = False,
extra_col_names: list[str] | None = None,
) -> None:
app = create_app(storage)
status, _, body = send_request(
app,
f"/csv/{study_id}",
"GET",
content_type="application/json",
)
assert status == correct_status
decoded_csv = str(body.decode("utf-8"))
if expect_no_result:
assert "is not found" in decoded_csv
else:
col_names = ["Number", "State"] + ([] if extra_col_names is None else extra_col_names)
assert all(col_name in decoded_csv for col_name in col_names)


def test_download_csv_no_trial() -> None:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x**2 + y

storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study.optimize(objective, n_trials=0)
_validate_output(storage, 200, 0)


def test_download_csv_all_waiting() -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study.add_trial(optuna.trial.create_trial(state=TrialState.WAITING))
_validate_output(storage, 200, 0)


def test_download_csv_all_running() -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
study.add_trial(optuna.trial.create_trial(state=TrialState.RUNNING))
_validate_output(storage, 200, 0)


@pytest.mark.parametrize("study_id", [0, 1])
def test_download_csv_fail(study_id: int) -> None:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x**2 + y

storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=10)
expect_no_result = study_id != 0
cols = ["Param x", "Param y", "Value"]
_validate_output(storage, 404 if expect_no_result else 200, study_id, expect_no_result, cols)


@pytest.mark.parametrize("is_multi_obj", [True, False])
def test_download_csv_multi_obj(is_multi_obj: bool) -> None:
def objective(trial: optuna.Trial) -> Any:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
if is_multi_obj:
return x**2, y
return x**2 + y

storage = optuna.storages.InMemoryStorage()
directions = ["minimize", "minimize"] if is_multi_obj else ["minimize"]
study = optuna.create_study(storage=storage, directions=directions)
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=10)
cols = ["Param x", "Param y"]
cols += ["Objective 0", "Objective 1"] if is_multi_obj else ["Value"]
_validate_output(storage, 200, 0, extra_col_names=cols)


def test_download_csv_user_attr() -> None:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
trial.set_user_attr("abs_y", abs(y))
return x**2 + y

storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=10)
cols = ["Param x", "Param y", "Value", "UserAttribute abs_y"]
_validate_output(storage, 200, 0, extra_col_names=cols)
Loading