
feat: add custom min/max to task-level plots #58


Open
wants to merge 13 commits into main
10 changes: 10 additions & 0 deletions marl_eval/plotting_tools/plot_utils.py
@@ -36,6 +36,7 @@ def plot_single_task_curve(
ticklabelsize: str = "xx-large",
legend_map: Optional[Dict] = None,
run_times: Optional[Dict] = None,
fix_normed_axis: bool = False,
**kwargs: Any,
) -> Figure:
"""Plots an aggregate metric with CIs as a function of environment frames.
@@ -60,6 +61,7 @@ def plot_single_task_curve(
If None, then this mapping is created based on `algorithms`.
run_times: Dictionary that maps each algorithm to the number of seconds it
took to run. If None, then environment steps will be displayed.
fix_normed_axis: If True, fix the y-axis lower bound at 0 and the upper
bound at 1, raising the upper bound only when a confidence interval
exceeds 1. Intended for normalised metrics.
**kwargs: Arbitrary keyword arguments.

Returns:
@@ -78,6 +80,8 @@
marker = kwargs.pop("marker", "o")
linewidth = kwargs.pop("linewidth", 2)

highest_upper_val = 1

for algorithm in algorithms:
x_axis_len = len(aggregated_data[algorithm]["mean"])

@@ -94,6 +98,9 @@
metric_values + confidence_interval,
)

if fix_normed_axis is True and highest_upper_val < np.max(upper):
highest_upper_val = np.max(upper)

if legend_map is not None:
algorithm_name = legend_map[algorithm]
else:
@@ -111,6 +118,9 @@
x_axis_values, y1=lower, y2=upper, color=colors[algorithm], alpha=0.2
)

if fix_normed_axis is True:
ax.set_ylim(0, highest_upper_val)

return _annotate_and_decorate_axis(
ax,
xlabel=xlabel,
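
For context, a minimal standalone sketch of the axis-fixing behaviour introduced above (hypothetical data and variable names, not part of this diff): the largest upper confidence bound is tracked across algorithms, and the y-axis is pinned to run from 0 to at least 1.

# Standalone illustration of the fix_normed_axis logic (hypothetical data).
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
fix_normed_axis = True
highest_upper_val = 1  # normalised metrics are expected to stay within [0, 1]

dummy_results = {"algo_a": np.linspace(0.1, 0.9, 50), "algo_b": np.linspace(0.2, 1.1, 50)}
for algorithm, metric_values in dummy_results.items():
    confidence_interval = 0.05
    lower = metric_values - confidence_interval
    upper = metric_values + confidence_interval
    if fix_normed_axis and highest_upper_val < np.max(upper):
        # Only widen the axis when a confidence bound exceeds 1.
        highest_upper_val = np.max(upper)
    x_axis_values = np.arange(len(metric_values))
    ax.plot(x_axis_values, metric_values, label=algorithm)
    ax.fill_between(x_axis_values, y1=lower, y2=upper, alpha=0.2)

if fix_normed_axis:
    ax.set_ylim(0, highest_upper_val)
ax.legend()
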
4 changes: 4 additions & 0 deletions marl_eval/plotting_tools/plotting.py
@@ -420,6 +420,8 @@ def plot_single_task(
metric_name, task_name, environment_name, metrics_to_normalize
)

fix_normed_axis = False

task_mean_ci_data = get_and_aggregate_data_single_task(
processed_data=processed_data,
environment_name=environment_name,
@@ -430,6 +432,7 @@

if metric_name in metrics_to_normalize:
ylabel = "Normalized " + " ".join(metric_name.split("_"))
fix_normed_axis = True
else:
ylabel = " ".join(metric_name.split("_")).capitalize()

@@ -460,6 +463,7 @@
legend_map=legend_map,
run_times=run_times,
marker="",
fix_normed_axis=fix_normed_axis,
)

return fig
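
With this change the flag is not exposed to callers of plot_single_task; it is switched on automatically whenever the plotted metric appears in metrics_to_normalize. A hedged usage sketch, assuming plot_single_task keeps its existing keyword arguments (the environment, task, and metric names below are illustrative):

# Illustrative call: "win_rate" is in metrics_to_normalize, so the y-axis of
# the resulting task-level plot is fixed (fix_normed_axis=True internally).
fig = plot_single_task(
    processed_data=processed_data,
    environment_name="env_name",
    task_name="task_name",
    metric_name="win_rate",
    metrics_to_normalize=["win_rate"],
)
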
21 changes: 19 additions & 2 deletions marl_eval/utils/data_processing_utils.py
@@ -99,7 +99,7 @@ def get_and_aggregate_data_single_task(
Args:
processed_data: Dictionary containing processed data.
metric_name: Name of metric to aggregate.
metrics_to_normalize: List of metrics to normalize.
metrics_to_normalize: List of metrics to normalise.
task_name: Name of task to aggregate.
environment_name: Name of environment to aggregate.
"""
@@ -150,6 +150,8 @@
def data_process_pipeline( # noqa: C901
raw_data: Dict[str, Dict[str, Any]],
metrics_to_normalize: List[str],
custom_min: Dict[str, Dict[str, float]] = {},
custom_max: Dict[str, Dict[str, float]] = {},
) -> Dict[str, Dict[str, Any]]:
"""Function for processing raw input experiment data.

@@ -159,6 +161,10 @@ def data_process_pipeline( # noqa: C901
metrics_to_normalize: A list of metric names for metrics that should
be min/max normalised. These metric names should match the names as
given in the raw dataset.
custom_min (optional): Dictionary containing custom global minimum values
for normalisation. Keys are task names and values are dictionaries
mapping metric names to floats.
custom_max (optional): Dictionary containing custom global maximum values
for normalisation. Keys are task names and values are dictionaries
mapping metric names to floats.

Returns:
processed_data: Dictionary containing processed experiment data where relevant
@@ -246,14 +252,25 @@ def _compare_values(
f"mean_{metric}"
] = mean
if metric in metrics_to_normalize:
# Normalization
# Normalisation
metric_array = np.array(metrics[metric])
metric_global_min = metric_min_max_info[metric][
"global_min"
]
metric_global_max = metric_min_max_info[metric][
"global_max"
]
# Use the custom min or max if given
if (
task in custom_min.keys()
and metric in custom_min[task].keys()
):
metric_global_min = custom_min[task][metric]
if (
task in custom_max.keys()
and metric in custom_max[task].keys()
):
metric_global_max = custom_max[task][metric]
normed_metric_array = (
metric_array - metric_global_min
) / (
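
A hedged sketch of how the new custom bounds would be supplied to data_process_pipeline; the task and metric names are illustrative, and any task/metric pair not listed keeps the global min/max computed from the raw data:

# Illustrative call (hypothetical task/metric names): override the bounds used
# to normalise "return" on one task, leaving every other task and metric on the
# global min/max derived from the data itself.
processed_data = data_process_pipeline(
    raw_data=raw_data,
    metrics_to_normalize=["return"],
    custom_min={"task_name": {"return": 0.0}},
    custom_max={"task_name": {"return": 20.0}},
)
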