From 28dde7fd428e420c117a3533bb3dc1ad755926ae Mon Sep 17 00:00:00 2001 From: keisuke-umezawa Date: Sun, 7 Jul 2024 15:40:11 +0900 Subject: [PATCH 1/2] Use tslib PlotImportance in optuna-dashboard --- .../GraphHyperparameterImportances.tsx | 129 +++--------------- standalone_app/src/components/StudyDetail.tsx | 4 +- tslib/react/src/components/PlotImportance.tsx | 19 +-- 3 files changed, 31 insertions(+), 121 deletions(-) diff --git a/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx b/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx index 33f6a4ea0..6aadbba12 100644 --- a/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx +++ b/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx @@ -1,16 +1,13 @@ -import { Box, Card, CardContent, Typography, useTheme } from "@mui/material" +import { Box, Card, CardContent } from "@mui/material" import * as plotly from "plotly.js-dist-min" import React, { FC, useEffect } from "react" -import { ParamImportance, StudyDetail } from "ts/types/optuna" +import { PlotImportance } from "@optuna/react" +import { StudyDetail } from "ts/types/optuna" import { PlotType } from "../apiClient" import { useParamImportance } from "../hooks/useParamImportance" import { usePlot } from "../hooks/usePlot" -import { - useBackendRender, - usePlotlyColorTheme, - useStudyDirections, -} from "../state" +import { useBackendRender } from "../state" const plotDomId = "graph-hyperparameter-importances" @@ -19,6 +16,13 @@ export const GraphHyperparameterImportance: FC<{ study: StudyDetail | null graphHeight: string }> = ({ studyId, study = null, graphHeight }) => { + const numCompletedTrials = + study?.trials.filter((t) => t.state === "Complete").length || 0 + const { importances } = useParamImportance({ + numCompletedTrials, + studyId, + }) + if (useBackendRender()) { return ( + + + + + ) } } @@ -64,100 +72,3 @@ const GraphHyperparameterImportanceBackend: FC<{ return } - -const GraphHyperparameterImportanceFrontend: FC<{ - studyId: number - study: StudyDetail | null - graphHeight: string -}> = ({ studyId, study = null, graphHeight }) => { - const theme = useTheme() - const colorTheme = usePlotlyColorTheme(theme.palette.mode) - - const numCompletedTrials = - study?.trials.filter((t) => t.state === "Complete").length || 0 - const { importances } = useParamImportance({ - numCompletedTrials, - studyId, - }) - const nObjectives = useStudyDirections(studyId)?.length - const objectiveNames: string[] = - study?.objective_names || - study?.directions.map((d, i) => `Objective ${i}`) || - [] - - useEffect(() => { - if (importances !== undefined && nObjectives === importances.length) { - plotParamImportance(importances, objectiveNames, colorTheme) - } - }, [nObjectives, importances, colorTheme]) - - return ( - - - - Hyperparameter Importance - - - - - ) -} - -const plotParamImportance = ( - importances: ParamImportance[][], - objectiveNames: string[], - colorTheme: Partial -) => { - const layout: Partial = { - xaxis: { - title: "Hyperparameter Importance", - }, - yaxis: { - title: "Hyperparameter", - automargin: true, - }, - margin: { - l: 50, - t: 0, - r: 50, - b: 50, - }, - barmode: "group", - bargap: 0.15, - bargroupgap: 0.1, - uirevision: "true", - template: colorTheme, - legend: { - x: 1.0, - y: 0.95, - }, - } - - if (document.getElementById(plotDomId) === null) { - return - } - const traces: Partial[] = importances.map( - (importance, i) => { - const reversed = [...importance].reverse() - const importance_values = reversed.map((p) => p.importance) - const param_names = reversed.map((p) => p.name) - const param_hover_templates = reversed.map( - (p) => `${p.name} (${p.distribution}): ${p.importance} ` - ) - return { - type: "bar", - orientation: "h", - name: objectiveNames[i], - x: importance_values, - y: param_names, - text: importance_values.map((v) => String(v.toFixed(2))), - textposition: "outside", - hovertemplate: param_hover_templates, - } - } - ) - plotly.react(plotDomId, traces, layout) -} diff --git a/standalone_app/src/components/StudyDetail.tsx b/standalone_app/src/components/StudyDetail.tsx index 7d137f5b0..e7647a372 100644 --- a/standalone_app/src/components/StudyDetail.tsx +++ b/standalone_app/src/components/StudyDetail.tsx @@ -180,9 +180,7 @@ export const StudyDetail: FC<{ - {!!study && ( - - )} + diff --git a/tslib/react/src/components/PlotImportance.tsx b/tslib/react/src/components/PlotImportance.tsx index f923e78d6..9636c8164 100644 --- a/tslib/react/src/components/PlotImportance.tsx +++ b/tslib/react/src/components/PlotImportance.tsx @@ -7,19 +7,20 @@ import { plotlyDarkTemplate } from "./PlotlyDarkMode" const plotDomId = "graph-hyperparameter-importances" export const PlotImportance: FC<{ - study: Optuna.Study - importance: Optuna.ParamImportance[][] -}> = ({ study, importance }) => { + study: Optuna.Study | null + importance?: Optuna.ParamImportance[][] + graphHeight?: string +}> = ({ study = null, importance, graphHeight = "450px" }) => { const theme = useTheme() - const objectiveNames: string[] = study.directions.map( - (_d, i) => `Objective ${i}` - ) + const objectiveNames: string[] = study + ? study.directions.map((_d, i) => `Objective ${i}`) + : [] useEffect(() => { - if (importance.length > 0) { + if (study !== null && importance !== undefined && importance.length > 0) { plotParamImportancesBeta(importance, objectiveNames, theme.palette.mode) } - }, [objectiveNames, importance, theme.palette.mode]) + }, [study, objectiveNames, importance, theme.palette.mode]) return ( <> @@ -29,7 +30,7 @@ export const PlotImportance: FC<{ > Hyperparameter Importance - + ) } From 3e353e6ef3059f450301bacd605d91ff423e303a Mon Sep 17 00:00:00 2001 From: keisuke-umezawa Date: Thu, 18 Jul 2024 16:14:04 +0900 Subject: [PATCH 2/2] Follow review comments --- tslib/react/src/components/PlotImportance.tsx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tslib/react/src/components/PlotImportance.tsx b/tslib/react/src/components/PlotImportance.tsx index 9636c8164..96807d2c7 100644 --- a/tslib/react/src/components/PlotImportance.tsx +++ b/tslib/react/src/components/PlotImportance.tsx @@ -59,6 +59,10 @@ const plotParamImportancesBeta = ( bargroupgap: 0.1, uirevision: "true", template: mode === "dark" ? plotlyDarkTemplate : {}, + legend: { + x: 1.0, + y: 0.95, + }, } if (document.getElementById(plotDomId) === null) {