diff --git a/compute_metrics_reloaded.py b/compute_metrics_reloaded.py index 3d62024..05494ae 100644 --- a/compute_metrics_reloaded.py +++ b/compute_metrics_reloaded.py @@ -32,7 +32,7 @@ The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based). The metrics are computed for each unique label (class) in the reference (ground truth) image. -Authors: Jan Valosek +Authors: Jan Valosek, Naga Karthik """ @@ -41,6 +41,8 @@ import numpy as np import nibabel as nib import pandas as pd +from multiprocessing import Pool, cpu_count +from functools import partial from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM @@ -81,6 +83,8 @@ def get_parser(): 'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.') parser.add_argument('-output', type=str, default='metrics.csv', required=False, help='Path to the output CSV file to save the metrics. Default: metrics.csv') + parser.add_argument('-jobs', type=int, default=cpu_count()//8, required=False, + help='Number of CPU cores to use in parallel. Default: cpu_count()//8.') return parser @@ -130,9 +134,7 @@ def compute_metrics_single_subject(prediction, reference, metrics): :param metrics: list of metrics to compute """ # load nifti images - print(f'Processing...') - print(f'\tPrediction: {os.path.basename(prediction)}') - print(f'\tReference: {os.path.basename(reference)}') + print(f'\nProcessing:\n\tPrediction: {os.path.basename(prediction)}\n\tReference: {os.path.basename(reference)}') prediction_data = load_nifti_image(prediction) reference_data = load_nifti_image(reference) @@ -159,7 +161,6 @@ def compute_metrics_single_subject(prediction, reference, metrics): # by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions for label in unique_labels: # create binary masks for the current label - print(f'\tLabel {label}') prediction_data_label = np.array(prediction_data == label, dtype=float) reference_data_label = np.array(reference_data == label, dtype=float) @@ -171,12 +172,9 @@ def compute_metrics_single_subject(prediction, reference, metrics): # add the metrics to the output dictionary metrics_dict[label] = dict_seg - if label == max(unique_labels): - break # break to loop to avoid processing the background label ("else" block) # Special case when both the reference and prediction images are empty else: label = 1 - print(f'\tLabel {label} -- both the reference and prediction are empty') bpm = BPM(prediction_data, reference_data, measures=metrics) dict_seg = bpm.to_dict_meas() @@ -216,8 +214,14 @@ def build_output_dataframe(output_list): return df -def main(): +def process_subject(prediction_file, reference_file, metrics): + """ + Wrapper function to process a single subject. + """ + return compute_metrics_single_subject(prediction_file, reference_file, metrics) + +def main(): # parse command line arguments parser = get_parser() args = parser.parse_args() @@ -227,19 +231,22 @@ def main(): # Print the metrics to be computed print(f'Computing metrics: {args.metrics}') + print(f'Using {args.jobs} CPU cores in parallel ...') # Args.prediction and args.reference are paths to folders with multiple nii.gz files (i.e., MULTIPLE subjects) if os.path.isdir(args.prediction) and os.path.isdir(args.reference): # Get all files in the directories prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference) - # Loop over the subjects - for i in range(len(prediction_files)): - # Compute metrics for each subject - metrics_dict = compute_metrics_single_subject(prediction_files[i], reference_files[i], args.metrics) - # Append the output dictionary (representing a single reference-prediction pair per subject) to the - # output_list - output_list.append(metrics_dict) - # Args.prediction and args.reference are paths nii.gz files from a SINGLE subject + + # Use multiprocessing to parallelize the computation + with Pool(args.jobs) as pool: + # Create a partial function to pass the metrics argument to the process_subject function + func = partial(process_subject, metrics=args.metrics) + # Compute metrics for each subject in parallel + results = pool.starmap(func, zip(prediction_files, reference_files)) + + # Collect the results + output_list.extend(results) else: metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics) # Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list diff --git a/test/test_metrics/test_pairwise_measures_neuropoly.py b/test/test_metrics/test_pairwise_measures_neuropoly.py index 70fa86d..60da2be 100644 --- a/test/test_metrics/test_pairwise_measures_neuropoly.py +++ b/test/test_metrics/test_pairwise_measures_neuropoly.py @@ -287,11 +287,11 @@ def test_non_empty_ref_and_pred_multi_class(self): Multi-class (i.e., voxels with values 1 and 2, e.g., region-based nnUNet training) """ - expected_metrics = {1.0: {'dsc': 0.25, - 'fbeta': 0.2500000055879354, - 'nsd': 0.5, - 'vol_diff': 2.0, - 'rel_vol_error': 200.0, + expected_metrics = {1.0: {'dsc': 0.6521739130434783, + 'fbeta': 0.5769230751596257, + 'nsd': 0.23232323232323232, + 'vol_diff': 2.6, + 'rel_vol_error': 260.0, 'EmptyRef': False, 'EmptyPred': False, 'lesion_ppv': 1.0,