Skip to content

Commit

Permalink
feat: add grey coloring; bugfixing
Browse files Browse the repository at this point in the history
  • Loading branch information
Gooogr committed Jul 23, 2024
1 parent 8f71c92 commit 47f2cab
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 22 deletions.
44 changes: 22 additions & 22 deletions rectools/visuals/metrics_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
WIDGET_HEIGHT = 500
TOP_CHART_MARGIN = 20
DEFAULT_LEGEND_TITLE = "model name"
NAN_COLOR = "grey"


class MetricsApp:
Expand Down Expand Up @@ -217,8 +218,13 @@ def _make_chart_data_avg(self) -> pd.DataFrame:
return metrics_data.merge(meta_data, on=Columns.Model, how="left").reset_index(drop=True)

@staticmethod
def _trim_metadata(raw_string: str, splitter: str = ", ") -> str:
return raw_string.split(splitter, 1)[-1]
@lru_cache
def _trim_metadata(raw_string: str, splitter: str = ", ") -> tp.Tuple[str, str]:
splitted_row = raw_string.split(splitter, 1)
if len(splitted_row) > 1:
meta_value, model_name = splitted_row
return meta_value, model_name
return "", raw_string

def _create_chart_figure(
self,
Expand All @@ -235,6 +241,7 @@ def _create_chart_figure(
scatter_kwargs.update(self.scatter_kwargs)

data = data.sort_values(by=color, ascending=True)
data[color] = data[color].astype(str) # to treat colors values as categorical

fig = px.scatter(
data,
Expand All @@ -244,19 +251,7 @@ def _create_chart_figure(
symbol=Columns.Model,
**scatter_kwargs,
)

# Add meta-info in legend
for trace in fig.data:
if color != Columns.Model:
meta_value = data.set_index(Columns.Model).at[trace.name, color]
trace.name = f"{meta_value}, {trace.name}"

layout_params = {
"margin": {"t": TOP_CHART_MARGIN},
"legend_title": legend_title,
"showlegend": self.show_legend,
}
fig.update_layout(layout_params)
fig.update_layout(margin={"t": TOP_CHART_MARGIN}, legend_title=legend_title, showlegend=self.show_legend)
fig.update_coloraxes(showscale=False)
return fig

Expand All @@ -275,16 +270,18 @@ def _update_figure_widget(

# Save dots symbols from the previous widget state
# Remove metainfo from trace name. Thus we guarantee to map with traces from previous state
trace_name2symbol = {self._trim_metadata(trace.name): trace.marker.symbol for trace in self.fig.data}
trace_name2symbol = {self._trim_metadata(trace.name)[1]: trace.marker.symbol for trace in self.fig.data}
legend_title = f"{meta_feature.value}, {DEFAULT_LEGEND_TITLE}" if use_meta.value else DEFAULT_LEGEND_TITLE
self.fig = self._create_chart_figure(chart_data, metric_x.value, metric_y.value, color_clmn, legend_title)

for trace in self.fig.data:
trace_name = self._trim_metadata(trace.name)
trace_name = self._trim_metadata(trace.name)[1]
trace.marker.symbol = trace_name2symbol[trace_name]

with fig_widget.batch_update():
for idx, trace in enumerate(self.fig.data):
if self._trim_metadata(trace.name)[0] == "nan":
trace.marker.color = NAN_COLOR
fig_widget.data[idx].x = trace.x
fig_widget.data[idx].y = trace.y
fig_widget.data[idx].marker.color = trace.marker.color
Expand All @@ -295,7 +292,7 @@ def _update_figure_widget(
fig_widget.data[idx].hoverinfo = trace.hoverinfo
fig_widget.data[idx].hovertemplate = trace.hovertemplate

fig_widget.layout.update(self.fig.layout)
fig_widget.layout = self.fig.layout
self.fig.layout.margin = None # keep separate chart non-truncated

def _update_fold_visibility(self, use_avg: widgets.Checkbox, fold_i: widgets.Dropdown) -> None:
Expand Down Expand Up @@ -329,10 +326,13 @@ def display(self) -> None:
)

# Initialize go.FigureWidget initial chart state
chart_data = self._create_chart_data(use_avg, fold_i)
legend_title = f"{meta_feature.value}, {DEFAULT_LEGEND_TITLE}" if use_meta.value else DEFAULT_LEGEND_TITLE
self.fig = self._create_chart_figure(chart_data, metric_x.value, metric_y.value, Columns.Model, legend_title)
fig_widget = go.FigureWidget(data=self.fig.data, layout=self.fig.layout)
if not self.fig.data:
chart_data = self._create_chart_data(use_avg, fold_i)
legend_title = f"{meta_feature.value}, {DEFAULT_LEGEND_TITLE}" if use_meta.value else DEFAULT_LEGEND_TITLE
self.fig = self._create_chart_figure(
chart_data, metric_x.value, metric_y.value, Columns.Model, legend_title
)
fig_widget = go.FigureWidget(data=self.fig.data, layout=self.fig.layout)

def update(event: tp.Callable[..., tp.Any]) -> None: # pragma: no cover
self._update_figure_widget(fig_widget, metric_x, metric_y, use_avg, fold_i, meta_feature, use_meta)
Expand Down
22 changes: 22 additions & 0 deletions tests/visuals/test_metrics_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,25 @@ def test_make_chart_data_avg_no_metadata(self) -> None:
}
)
pd.testing.assert_frame_equal(chart_data, expected_data)

# -----------------------------------------Test helper methods------------------------------------------ #

def test_trim_metadata_with_meta(self) -> None:
app = MetricsApp.construct(
models_metrics=DF_METRICS,
models_metadata=None,
auto_display=False,
)
test_string = "10, random"
expected_result = ("10", "random")
assert app._trim_metadata(test_string) == expected_result

def test_trim_metadata_without_meta(self) -> None:
app = MetricsApp.construct(
models_metrics=DF_METRICS,
models_metadata=None,
auto_display=False,
)
test_string = "random"
expected_result = ("", "random")
assert app._trim_metadata(test_string) == expected_result

0 comments on commit 47f2cab

Please sign in to comment.