diff --git a/marl_eval/plotting_tools/plot_utils.py b/marl_eval/plotting_tools/plot_utils.py
index 02b34948..de22748d 100644
--- a/marl_eval/plotting_tools/plot_utils.py
+++ b/marl_eval/plotting_tools/plot_utils.py
@@ -36,6 +36,7 @@ def plot_single_task_curve(
     ticklabelsize: str = "xx-large",
     legend_map: Optional[Dict] = None,
     run_times: Optional[Dict] = None,
+    fix_normed_axis: bool = False,
     **kwargs: Any,
 ) -> Figure:
     """Plots an aggregate metric with CIs as a function of environment frames.
@@ -60,6 +61,7 @@ def plot_single_task_curve(
         If None, then this mapping is created based on `algorithms`.
       run_times: Dictionary that maps each algorithm to the number of seconds
         it took to run. If None, then environment steps will be displayed.
+      fix_normed_axis: If the metric is normalised, fix the y-axis from 0 to 1.
       **kwargs: Arbitrary keyword arguments.
 
     Returns:
@@ -78,6 +80,8 @@ def plot_single_task_curve(
     marker = kwargs.pop("marker", "o")
     linewidth = kwargs.pop("linewidth", 2)
 
+    highest_upper_val = 1
+
     for algorithm in algorithms:
         x_axis_len = len(aggregated_data[algorithm]["mean"])
 
@@ -94,6 +98,9 @@ def plot_single_task_curve(
             metric_values + confidence_interval,
         )
 
+        if fix_normed_axis is True and highest_upper_val < np.max(upper):
+            highest_upper_val = np.max(upper)
+
         if legend_map is not None:
             algorithm_name = legend_map[algorithm]
         else:
@@ -111,6 +118,9 @@ def plot_single_task_curve(
             x_axis_values, y1=lower, y2=upper, color=colors[algorithm], alpha=0.2
         )
 
+    if fix_normed_axis is True:
+        ax.set_ylim(0, highest_upper_val)
+
     return _annotate_and_decorate_axis(
         ax,
         xlabel=xlabel,
diff --git a/marl_eval/plotting_tools/plotting.py b/marl_eval/plotting_tools/plotting.py
index 49b33e5f..98e45a8a 100644
--- a/marl_eval/plotting_tools/plotting.py
+++ b/marl_eval/plotting_tools/plotting.py
@@ -420,6 +420,8 @@ def plot_single_task(
         metric_name, task_name, environment_name, metrics_to_normalize
     )
 
+    fix_normed_axis = False
+
     task_mean_ci_data = get_and_aggregate_data_single_task(
         processed_data=processed_data,
         environment_name=environment_name,
@@ -430,6 +432,7 @@ def plot_single_task(
 
     if metric_name in metrics_to_normalize:
         ylabel = "Normalized " + " ".join(metric_name.split("_"))
+        fix_normed_axis = True
     else:
         ylabel = " ".join(metric_name.split("_")).capitalize()
 
@@ -460,6 +463,7 @@ def plot_single_task(
         legend_map=legend_map,
         run_times=run_times,
         marker="",
+        fix_normed_axis=fix_normed_axis,
     )
 
     return fig
diff --git a/marl_eval/utils/data_processing_utils.py b/marl_eval/utils/data_processing_utils.py
index a87f472d..1841d69a 100644
--- a/marl_eval/utils/data_processing_utils.py
+++ b/marl_eval/utils/data_processing_utils.py
@@ -99,7 +99,7 @@ def get_and_aggregate_data_single_task(
     Args:
         processed_data: Dictionary containing processed data.
         metric_name: Name of metric to aggregate.
-        metrics_to_normalize: List of metrics to normalize.
+        metrics_to_normalize: List of metrics to normalise.
         task_name: Name of task to aggregate.
         environment_name: Name of environment to aggregate.
     """
@@ -150,6 +150,8 @@ def get_and_aggregate_data_single_task(
 def data_process_pipeline(  # noqa: C901
     raw_data: Dict[str, Dict[str, Any]],
     metrics_to_normalize: List[str],
+    custom_min: Dict[str, Dict[str, float]] = {},
+    custom_max: Dict[str, Dict[str, float]] = {},
 ) -> Dict[str, Dict[str, Any]]:
     """Function for processing raw input experiment data.
 
@@ -159,6 +161,10 @@ def data_process_pipeline(  # noqa: C901
         metrics_to_normalize: A list of metric names for metrics that should
             be min/max normalised.
             These metric names should match the names as given in the raw dataset.
+        custom_min (optional): Dictionary containing custom global minimum values
+            for normalisation, mapping task names to {metric name: float} dicts.
+        custom_max (optional): Dictionary containing custom global maximum values
+            for normalisation, mapping task names to {metric name: float} dicts.
 
     Returns:
         processed_data: Dictionary containing processed experiment data where relevant
@@ -246,7 +252,7 @@ def _compare_values(
                             f"mean_{metric}"
                         ] = mean
                         if metric in metrics_to_normalize:
-                            # Normalization
+                            # Normalisation
                             metric_array = np.array(metrics[metric])
                             metric_global_min = metric_min_max_info[metric][
                                 "global_min"
@@ -254,6 +260,17 @@ def _compare_values(
                             metric_global_max = metric_min_max_info[metric][
                                 "global_max"
                             ]
+                            # Use the custom min or max if given
+                            if (
+                                task in custom_min.keys()
+                                and metric in custom_min[task].keys()
+                            ):
+                                metric_global_min = custom_min[task][metric]
+                            if (
+                                task in custom_max.keys()
+                                and metric in custom_max[task].keys()
+                            ):
+                                metric_global_max = custom_max[task][metric]
                             normed_metric_array = (
                                 metric_array - metric_global_min
                             ) / (
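Usage sketch (not part of the diff above): a minimal example of how the new options fit together, assuming the raw-data JSON layout that data_process_pipeline already expects. The environment, task and metric names ("rware", "tiny_2ag", "episode_return") are hypothetical placeholders, and the plot_single_task keyword arguments are inferred from the hunks above rather than from the full signature, which may differ.

import json

from marl_eval.plotting_tools.plotting import plot_single_task
from marl_eval.utils.data_processing_utils import data_process_pipeline

# Raw experiment data in the JSON layout expected by data_process_pipeline.
with open("raw_experiment_data.json", "r") as f:
    raw_data = json.load(f)

# Override the global min/max used for normalisation on a per-task,
# per-metric basis instead of relying on the values inferred from the data.
processed_data = data_process_pipeline(
    raw_data=raw_data,
    metrics_to_normalize=["episode_return"],
    custom_min={"tiny_2ag": {"episode_return": 0.0}},
    custom_max={"tiny_2ag": {"episode_return": 20.0}},
)

# Because "episode_return" is in metrics_to_normalize, plot_single_task sets
# fix_normed_axis=True and plot_single_task_curve anchors the y-axis at 0,
# extending the upper limit past 1 only if a confidence interval exceeds it.
fig = plot_single_task(
    processed_data=processed_data,
    environment_name="rware",
    task_name="tiny_2ag",
    metric_name="episode_return",
    metrics_to_normalize=["episode_return"],
)

The custom bounds are consulted only for metrics listed in metrics_to_normalize and only for task/metric pairs present in custom_min or custom_max; anything else falls back to the global minima and maxima recorded in metric_min_max_info.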