Feat (dashboard): class distribution plots (#122)
* refactor: separate summary dashboards

* remove unused imports

* refactor: merge duplicate code segments

* fix: markdown

* fix: fix wrong variable

* fix: add title

* fix: improve performance

* fix: styling

* refactor: col name

* feat: collect all label type statistics

* feat: create generic label distribution chart

* feat: add class distributions to the dashboard

* feat: mark undersampled label types

* fix: styling

* feat: use pandera for df

* fix: remove print

* fix: make distributions expandable

* fix: style

* refactor: create AnnotationStatistics dataclass

* Fix revisions

* fix: style
Gorkem-Encord authored Jan 27, 2023
1 parent 76952d9 commit e4955d9
Showing 4 changed files with 154 additions and 17 deletions.
42 changes: 34 additions & 8 deletions src/encord_active/app/common/components/metric_summary.py
@@ -11,6 +11,7 @@
from encord_active.app.common.state import get_state
from encord_active.lib.charts.data_quality_summary import (
create_image_size_distribution_chart,
create_labels_distribution_chart,
create_outlier_distribution_chart,
)
from encord_active.lib.common.image_utils import show_image_and_draw_polygons
@@ -91,14 +92,14 @@ def render_data_quality_dashboard(severe_outlier_color: str, moderate_outlier_co
)

st.write("")
outliers_plotting_col, issues_col = st.columns([6, 3])
plots_col, issues_col = st.columns([6, 3])

if get_state().metrics_data_summary.total_unique_severe_outliers > 0:
fig = create_outlier_distribution_chart(all_metrics_outliers, severe_outlier_color, moderate_outlier_color)
outliers_plotting_col.plotly_chart(fig, use_container_width=True)
plots_col.plotly_chart(fig, use_container_width=True)

fig = create_image_size_distribution_chart(get_state().image_sizes)
outliers_plotting_col.plotly_chart(fig, use_container_width=True)
plots_col.plotly_chart(fig, use_container_width=True)

metrics_with_severe_outliers = all_metrics_outliers[all_metrics_outliers["total_severe_outliers"] > 0]
render_issues_pane(metrics_with_severe_outliers, issues_col)
@@ -107,7 +108,7 @@ def render_label_quality_dashboard(severe_outlier_color: str, moderate_outlier_co
def render_label_quality_dashboard(severe_outlier_color: str, moderate_outlier_color: str, background_color: str):

if get_state().annotation_sizes is None:
get_state().annotation_sizes = get_all_annotation_numbers(get_state().project_paths.project_dir)
get_state().annotation_sizes = get_all_annotation_numbers(get_state().project_paths)

metrics = load_available_metrics(get_state().project_paths.metrics, MetricScope.LABEL_QUALITY)
if get_state().metrics_label_summary is None:
@@ -123,14 +124,16 @@ def render_label_quality_dashboard(severe_outlier_color: str, moderate_outlier_c
) = st.columns(4)

total_object_annotations_col.markdown(
summary_item("Object annotations", get_state().annotation_sizes[1], background_color=background_color),
summary_item(
"Object annotations", get_state().annotation_sizes.total_object_labels, background_color=background_color
),
unsafe_allow_html=True,
)

total_classification_annotations_col.markdown(
summary_item(
"Classification annotations",
get_state().annotation_sizes[0],
get_state().annotation_sizes.total_classification_labels,
background_color=background_color,
),
unsafe_allow_html=True,
@@ -155,11 +158,34 @@ def render_label_quality_dashboard(severe_outlier_color: str, moderate_outlier_c
)

st.write("")
outliers_plotting_col, issues_col = st.columns([6, 3])
plots_col, issues_col = st.columns([6, 3])

if get_state().metrics_label_summary.total_unique_severe_outliers > 0:
fig = create_outlier_distribution_chart(all_metrics_outliers, severe_outlier_color, moderate_outlier_color)
outliers_plotting_col.plotly_chart(fig, use_container_width=True)
plots_col.plotly_chart(fig, use_container_width=True)

# label distribution plots
with plots_col.expander("Labels distribution", expanded=True):
if (
get_state().annotation_sizes.total_object_labels > 0
or get_state().annotation_sizes.total_classification_labels > 0
):
st.info("If a class's size is lower than half of the median value, it is indicated as 'undersampled'.")

if get_state().annotation_sizes.total_object_labels > 0:
fig = create_labels_distribution_chart(
get_state().annotation_sizes.objects, "Objects distributions", "Object"
)
st.plotly_chart(fig, use_container_width=True)

for (
classification_question_name,
classification_question_answers,
) in get_state().annotation_sizes.classifications.items():
fig = create_labels_distribution_chart(
classification_question_answers, classification_question_name, "Class"
)
st.plotly_chart(fig, use_container_width=True)

metrics_with_severe_outliers = all_metrics_outliers[all_metrics_outliers["total_severe_outliers"] > 0]
render_issues_pane(metrics_with_severe_outliers, issues_col)
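For context, the label-quality dashboard above reads an AnnotationStatistics value (the dataclass added in summary_utils.py later in this diff) from get_state().annotation_sizes. A minimal sketch of the shape it expects, using made-up class names and counts:

    from encord_active.lib.dataset.summary_utils import AnnotationStatistics

    # Illustrative values only; in the app they come from get_all_annotation_numbers().
    annotation_sizes = AnnotationStatistics(
        objects={"car": 120, "pedestrian": 45, "bicycle": 9},        # per-object-class counts
        classifications={"Weather": {"sunny": 80, "rainy": 12}},     # question -> answer counts
        total_object_labels=174,
        total_classification_labels=92,
    )

    # The expander draws one chart for the objects dict and one per classification question:
    for question_name, answer_counts in annotation_sizes.classifications.items():
        print(question_name, answer_counts)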
3 changes: 2 additions & 1 deletion src/encord_active/app/common/state.py
@@ -8,6 +8,7 @@
from pandera.typing import DataFrame

from encord_active.lib.dataset.outliers import MetricsSeverity
from encord_active.lib.dataset.summary_utils import AnnotationStatistics
from encord_active.lib.db.merged_metrics import MergedMetrics
from encord_active.lib.db.tags import Tag, Tags
from encord_active.lib.metrics.utils import MetricData
@@ -59,7 +60,7 @@ class State:
predictions = PredictionsState()
similarities_count = 8
image_sizes: Optional[np.ndarray] = None
annotation_sizes: Optional[tuple[int, int]] = None
annotation_sizes: Optional[AnnotationStatistics] = None
metrics_data_summary: Optional[MetricsSeverity] = None
metrics_label_summary: Optional[MetricsSeverity] = None

55 changes: 54 additions & 1 deletion src/encord_active/lib/charts/data_quality_summary.py
@@ -1,7 +1,15 @@
import numpy as np
import pandas as pd
import pandera as pa
import plotly.express as px
import plotly.graph_objects as go
from pandera.typing import DataFrame, Series


class LabelStatisticsSchema(pa.SchemaModel):
name: Series[str] = pa.Field()
count: Series[int] = pa.Field()
status: Series[bool] = pa.Field()


def create_outlier_distribution_chart(
@@ -25,7 +33,52 @@ def create_outlier_distribution_chart(
)
)
fig.update_layout(
title_text="Outliers", height=400, legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
margin=dict(l=0, r=0, b=0),
title_text="Outliers",
height=400,
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
)

return fig


def create_labels_distribution_chart(
labels: dict, title: str, x_title: str = "Class", y_title: str = "Count"
) -> go.Figure:
labels_df = pd.DataFrame.from_dict(labels, orient="index").reset_index()
labels_df.rename(columns={"index": "name", 0: "count"}, inplace=True)
labels_df.insert(0, "status", False)

labels_df = DataFrame[LabelStatisticsSchema](labels_df)
labels_df.sort_values(by=LabelStatisticsSchema.count, ascending=False, inplace=True)

Q2 = labels_df[LabelStatisticsSchema.count].quantile(0.5)

labels_df.loc[labels_df[LabelStatisticsSchema.count] <= (Q2 * 0.5), LabelStatisticsSchema.status] = True

fig = go.Figure(
data=[
go.Bar(
x=labels_df.loc[labels_df[LabelStatisticsSchema.status] == False][LabelStatisticsSchema.name],
y=labels_df[labels_df[LabelStatisticsSchema.status] == False][LabelStatisticsSchema.count],
name="representative",
marker_color="#3380FF",
),
go.Bar(
x=labels_df.loc[labels_df[LabelStatisticsSchema.status] == True][LabelStatisticsSchema.name],
y=labels_df[labels_df[LabelStatisticsSchema.status] == True][LabelStatisticsSchema.count],
name="undersampled",
marker_color="tomato",
),
]
)

fig.update_layout(
margin=dict(l=0, r=0, b=0),
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
title=title,
xaxis_title=x_title,
yaxis_title=y_title,
)

return fig
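A minimal usage sketch of the new create_labels_distribution_chart helper; the class names and counts below are invented. A class lands in the "undersampled" trace when its count is at most half of the median count across classes:

    import plotly.graph_objects as go

    from encord_active.lib.charts.data_quality_summary import create_labels_distribution_chart

    # Median count is 45, so "bicycle" (9 <= 22.5) is flagged as undersampled.
    object_counts = {"car": 120, "pedestrian": 45, "bicycle": 9}

    fig = create_labels_distribution_chart(object_counts, "Objects distributions", "Object")
    assert isinstance(fig, go.Figure)
    fig.show()  # in the app the figure is passed to st.plotly_chart(fig, use_container_width=True)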
71 changes: 64 additions & 7 deletions src/encord_active/lib/dataset/summary_utils.py
@@ -1,9 +1,21 @@
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Tuple

import numpy as np
from encord.constants.enums import DataType
from encord.objects.ontology_structure import OntologyStructure

from encord_active.lib.project import ProjectFileStructure


@dataclass
class AnnotationStatistics:
objects: dict = field(default_factory=dict)
classifications: dict = field(default_factory=dict)
total_object_labels: int = 0
total_classification_labels: int = 0


from encord_active.lib.dataset.outliers import (
MetricOutlierInfo,
@@ -41,22 +53,67 @@ def get_median_value_of_2d_array(array: np.ndarray) -> np.ndarray:
return array[item_index[0][0], :]


def get_all_annotation_numbers(project_folder: Path) -> Tuple[int, int]:
def get_all_annotation_numbers(project_paths: ProjectFileStructure) -> AnnotationStatistics:
"""
returns (number of classification label, number of object label)
Returns label statistics for both objects and classifications. Only immediate labels
are counted; nested labels are ignored.
"""

labels: AnnotationStatistics = AnnotationStatistics()
classification_label_counter = 0
object_label_counter = 0

for label_row in (project_folder / "data").iterdir():
project_ontology = json.loads((project_paths.ontology).read_text(encoding="utf-8"))
ontology = OntologyStructure.from_dict(project_ontology)

for object_item in ontology.objects:
labels.objects[object_item.name] = 0
for classification_item in ontology.classifications:
labels.classifications[classification_item.attributes[0].name] = {}

# For radio and checkbox types
if hasattr(classification_item.attributes[0], "options"):
for option in classification_item.attributes[0].options:
labels.classifications[classification_item.attributes[0].name][option.label] = 0

for label_row in (project_paths.data).iterdir():
if (label_row / "label_row.json").exists():
label_row_meta = json.loads((label_row / "label_row.json").read_text(encoding="utf-8"))
if label_row_meta["data_type"] in [DataType.IMAGE.value, DataType.IMG_GROUP.value]:
for data_unit in label_row_meta["data_units"].values():
object_label_counter += len(data_unit["labels"]["objects"])
classification_label_counter += len(data_unit["labels"]["classifications"])

return classification_label_counter, object_label_counter
object_label_counter += len(data_unit["labels"].get("objects", []))
classification_label_counter += len(data_unit["labels"].get("classifications", []))

for object_ in data_unit["labels"].get("objects", []):
if object_["name"] not in labels.objects:
print(f'Object name "{object_["name"]}" does not exist in the project ontology')
labels.objects[object_["name"]] = 0  # guard: avoid a KeyError below for names missing from the ontology
labels.objects[object_["name"]] += 1

for classification in data_unit["labels"].get("classifications", []):
classificationHash = classification["classificationHash"]
classification_answer_item = label_row_meta["classification_answers"][classificationHash][
"classifications"
][0]
classification_question_name = classification_answer_item["name"]
if classification_question_name in labels.classifications:

if isinstance(classification_answer_item["answers"], list):
for answer_item in classification_answer_item["answers"]:
if answer_item["name"] in labels.classifications[classification_question_name]:
labels.classifications[classification_question_name][answer_item["name"]] += 1
elif isinstance(classification_answer_item["answers"], str):
labels.classifications[classification_question_name].setdefault(
classification_answer_item["answers"], 0
)
labels.classifications[classification_question_name][
classification_answer_item["answers"]
] += 1

labels.total_object_labels = object_label_counter
labels.total_classification_labels = classification_label_counter

return labels


def get_metric_summary(metrics: list[MetricData]) -> MetricsSeverity:
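A minimal sketch of calling the updated helper outside the app; the project directory below is hypothetical, and the ProjectFileStructure constructor is assumed to take that directory:

    from pathlib import Path

    from encord_active.lib.dataset.summary_utils import get_all_annotation_numbers
    from encord_active.lib.project import ProjectFileStructure

    project_paths = ProjectFileStructure(Path("/path/to/encord-active-project"))  # hypothetical path

    stats = get_all_annotation_numbers(project_paths)
    print(stats.total_object_labels, stats.total_classification_labels)
    print(stats.objects)          # e.g. {"car": 120, "pedestrian": 45, ...}
    print(stats.classifications)  # e.g. {"Weather": {"sunny": 80, "rainy": 12}, ...}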
