Skip to content

Commit

Permalink
Merge pull request #899 from keisuke-umezawa/feature/use-tslib-plotim…
Browse files Browse the repository at this point in the history
…portance

Use tslib PlotImportance in optuna-dashboard
  • Loading branch information
porink0424 authored Jul 19, 2024
2 parents f699d99 + 3e353e6 commit 249906e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 121 deletions.
129 changes: 20 additions & 109 deletions optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import { Box, Card, CardContent, Typography, useTheme } from "@mui/material"
import { Box, Card, CardContent } from "@mui/material"
import * as plotly from "plotly.js-dist-min"
import React, { FC, useEffect } from "react"

import { ParamImportance, StudyDetail } from "ts/types/optuna"
import { PlotImportance } from "@optuna/react"
import { StudyDetail } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { useParamImportance } from "../hooks/useParamImportance"
import { usePlot } from "../hooks/usePlot"
import {
useBackendRender,
usePlotlyColorTheme,
useStudyDirections,
} from "../state"
import { useBackendRender } from "../state"

const plotDomId = "graph-hyperparameter-importances"

Expand All @@ -19,6 +16,13 @@ export const GraphHyperparameterImportance: FC<{
study: StudyDetail | null
graphHeight: string
}> = ({ studyId, study = null, graphHeight }) => {
const numCompletedTrials =
study?.trials.filter((t) => t.state === "Complete").length || 0
const { importances } = useParamImportance({
numCompletedTrials,
studyId,
})

if (useBackendRender()) {
return (
<GraphHyperparameterImportanceBackend
Expand All @@ -29,11 +33,15 @@ export const GraphHyperparameterImportance: FC<{
)
} else {
return (
<GraphHyperparameterImportanceFrontend
studyId={studyId}
study={study}
graphHeight={graphHeight}
/>
<Card>
<CardContent>
<PlotImportance
study={study}
importance={importances}
graphHeight={graphHeight}
/>
</CardContent>
</Card>
)
}
}
Expand Down Expand Up @@ -64,100 +72,3 @@ const GraphHyperparameterImportanceBackend: FC<{

return <Box component="div" id={plotDomId} sx={{ height: graphHeight }} />
}

const GraphHyperparameterImportanceFrontend: FC<{
studyId: number
study: StudyDetail | null
graphHeight: string
}> = ({ studyId, study = null, graphHeight }) => {
const theme = useTheme()
const colorTheme = usePlotlyColorTheme(theme.palette.mode)

const numCompletedTrials =
study?.trials.filter((t) => t.state === "Complete").length || 0
const { importances } = useParamImportance({
numCompletedTrials,
studyId,
})
const nObjectives = useStudyDirections(studyId)?.length
const objectiveNames: string[] =
study?.objective_names ||
study?.directions.map((d, i) => `Objective ${i}`) ||
[]

useEffect(() => {
if (importances !== undefined && nObjectives === importances.length) {
plotParamImportance(importances, objectiveNames, colorTheme)
}
}, [nObjectives, importances, colorTheme])

return (
<Card>
<CardContent>
<Typography
variant="h6"
sx={{ margin: "1em 0", fontWeight: theme.typography.fontWeightBold }}
>
Hyperparameter Importance
</Typography>
<Box component="div" id={plotDomId} sx={{ height: graphHeight }} />
</CardContent>
</Card>
)
}

const plotParamImportance = (
importances: ParamImportance[][],
objectiveNames: string[],
colorTheme: Partial<Plotly.Template>
) => {
const layout: Partial<plotly.Layout> = {
xaxis: {
title: "Hyperparameter Importance",
},
yaxis: {
title: "Hyperparameter",
automargin: true,
},
margin: {
l: 50,
t: 0,
r: 50,
b: 50,
},
barmode: "group",
bargap: 0.15,
bargroupgap: 0.1,
uirevision: "true",
template: colorTheme,
legend: {
x: 1.0,
y: 0.95,
},
}

if (document.getElementById(plotDomId) === null) {
return
}
const traces: Partial<plotly.PlotData>[] = importances.map(
(importance, i) => {
const reversed = [...importance].reverse()
const importance_values = reversed.map((p) => p.importance)
const param_names = reversed.map((p) => p.name)
const param_hover_templates = reversed.map(
(p) => `${p.name} (${p.distribution}): ${p.importance} <extra></extra>`
)
return {
type: "bar",
orientation: "h",
name: objectiveNames[i],
x: importance_values,
y: param_names,
text: importance_values.map((v) => String(v.toFixed(2))),
textposition: "outside",
hovertemplate: param_hover_templates,
}
}
)
plotly.react(plotDomId, traces, layout)
}
4 changes: 1 addition & 3 deletions standalone_app/src/components/StudyDetail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ export const StudyDetail: FC<{
<Grid item xs={6}>
<Card sx={{ margin: theme.spacing(2) }}>
<CardContent>
{!!study && (
<PlotImportance study={study} importance={importance} />
)}
<PlotImportance study={study} importance={importance} />
</CardContent>
</Card>
</Grid>
Expand Down
23 changes: 14 additions & 9 deletions tslib/react/src/components/PlotImportance.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ import { plotlyDarkTemplate } from "./PlotlyDarkMode"
const plotDomId = "graph-hyperparameter-importances"

export const PlotImportance: FC<{
study: Optuna.Study
importance: Optuna.ParamImportance[][]
}> = ({ study, importance }) => {
study: Optuna.Study | null
importance?: Optuna.ParamImportance[][]
graphHeight?: string
}> = ({ study = null, importance, graphHeight = "450px" }) => {
const theme = useTheme()
const objectiveNames: string[] = study.directions.map(
(_d, i) => `Objective ${i}`
)
const objectiveNames: string[] = study
? study.directions.map((_d, i) => `Objective ${i}`)
: []

useEffect(() => {
if (importance.length > 0) {
if (study !== null && importance !== undefined && importance.length > 0) {
plotParamImportancesBeta(importance, objectiveNames, theme.palette.mode)
}
}, [objectiveNames, importance, theme.palette.mode])
}, [study, objectiveNames, importance, theme.palette.mode])

return (
<>
Expand All @@ -29,7 +30,7 @@ export const PlotImportance: FC<{
>
Hyperparameter Importance
</Typography>
<Box id={plotDomId} sx={{ height: "450px" }} />
<Box id={plotDomId} sx={{ height: graphHeight }} />
</>
)
}
Expand Down Expand Up @@ -58,6 +59,10 @@ const plotParamImportancesBeta = (
bargroupgap: 0.1,
uirevision: "true",
template: mode === "dark" ? plotlyDarkTemplate : {},
legend: {
x: 1.0,
y: 0.95,
},
}

if (document.getElementById(plotDomId) === null) {
Expand Down

0 comments on commit 249906e

Please sign in to comment.