Merge pull request #11 from ivadomed/nk/faster-metrics-computation
Speed up metrics computation by parallelizing across subjects
naga-karthik authored Jun 11, 2024
2 parents 3fcec73 + 06f66b0 commit 76dbb55
Showing 2 changed files with 29 additions and 22 deletions.
41 changes: 24 additions & 17 deletions compute_metrics_reloaded.py
@@ -32,7 +32,7 @@
 The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based).
 The metrics are computed for each unique label (class) in the reference (ground truth) image.
 
-Authors: Jan Valosek
+Authors: Jan Valosek, Naga Karthik
 """

@@ -41,6 +41,8 @@
 import numpy as np
 import nibabel as nib
 import pandas as pd
+from multiprocessing import Pool, cpu_count
+from functools import partial
 
 from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM

@@ -81,6 +83,8 @@ def get_parser():
                         'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.')
     parser.add_argument('-output', type=str, default='metrics.csv', required=False,
                         help='Path to the output CSV file to save the metrics. Default: metrics.csv')
+    parser.add_argument('-jobs', type=int, default=cpu_count()//8, required=False,
+                        help='Number of CPU cores to use in parallel. Default: cpu_count()//8.')
 
     return parser
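
A caveat worth flagging with the new default (an editorial observation, not part of the commit): cpu_count() // 8 floors to 0 on machines with fewer than 8 cores, and multiprocessing.Pool raises ValueError when asked for 0 processes. A minimal defensive sketch, clamping to at least one worker:

from multiprocessing import Pool, cpu_count

if __name__ == '__main__':
    # cpu_count() // 8 is 0 on a 4-core machine, and Pool(processes=0)
    # raises ValueError, so clamp the default to at least one worker
    jobs = max(1, cpu_count() // 8)
    with Pool(jobs) as pool:
        print(pool.map(abs, [-1, -2, -3]))  # prints [1, 2, 3]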

@@ -130,9 +134,7 @@ def compute_metrics_single_subject(prediction, reference, metrics):
     :param metrics: list of metrics to compute
     """
     # load nifti images
-    print(f'Processing...')
-    print(f'\tPrediction: {os.path.basename(prediction)}')
-    print(f'\tReference: {os.path.basename(reference)}')
+    print(f'\nProcessing:\n\tPrediction: {os.path.basename(prediction)}\n\tReference: {os.path.basename(reference)}')
     prediction_data = load_nifti_image(prediction)
     reference_data = load_nifti_image(reference)

@@ -159,7 +161,6 @@
     # by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions
     for label in unique_labels:
         # create binary masks for the current label
-        print(f'\tLabel {label}')
         prediction_data_label = np.array(prediction_data == label, dtype=float)
         reference_data_label = np.array(reference_data == label, dtype=float)

@@ -171,12 +172,9 @@
         # add the metrics to the output dictionary
         metrics_dict[label] = dict_seg
 
-        if label == max(unique_labels):
-            break  # break to loop to avoid processing the background label ("else" block)
     # Special case when both the reference and prediction images are empty
     else:
         label = 1
-        print(f'\tLabel {label} -- both the reference and prediction are empty')
         bpm = BPM(prediction_data, reference_data, measures=metrics)
         dict_seg = bpm.to_dict_meas()
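
To make the per-label loop concrete, here is a toy sketch (made-up values, not from the repository) of how a multi-class mask is binarized one label at a time before BinaryPairwiseMeasures is applied:

import numpy as np

# hypothetical multi-class mask: 0 = background, 1 = spinal cord, 2 = lesion
reference_data = np.array([[0, 1, 1],
                           [1, 2, 2],
                           [0, 0, 2]], dtype=float)

# skip the background label (0), then binarize one class at a time,
# exactly as the loop above does for prediction and reference
for label in np.unique(reference_data)[1:]:
    reference_data_label = np.array(reference_data == label, dtype=float)
    print(label, int(reference_data_label.sum()))  # 1.0 -> 3 voxels, 2.0 -> 3 voxels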

@@ -216,8 +214,14 @@ def build_output_dataframe(output_list):
     return df
 
 
-def main():
+def process_subject(prediction_file, reference_file, metrics):
+    """
+    Wrapper function to process a single subject.
+    """
+    return compute_metrics_single_subject(prediction_file, reference_file, metrics)
+
+
+def main():
     # parse command line arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -227,19 +231,22 @@ def main():
 
     # Print the metrics to be computed
     print(f'Computing metrics: {args.metrics}')
+    print(f'Using {args.jobs} CPU cores in parallel ...')
 
     # Args.prediction and args.reference are paths to folders with multiple nii.gz files (i.e., MULTIPLE subjects)
     if os.path.isdir(args.prediction) and os.path.isdir(args.reference):
         # Get all files in the directories
         prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference)
-        # Loop over the subjects
-        for i in range(len(prediction_files)):
-            # Compute metrics for each subject
-            metrics_dict = compute_metrics_single_subject(prediction_files[i], reference_files[i], args.metrics)
-            # Append the output dictionary (representing a single reference-prediction pair per subject) to the
-            # output_list
-            output_list.append(metrics_dict)
+
+        # Use multiprocessing to parallelize the computation
+        with Pool(args.jobs) as pool:
+            # Create a partial function to pass the metrics argument to the process_subject function
+            func = partial(process_subject, metrics=args.metrics)
+            # Compute metrics for each subject in parallel
+            results = pool.starmap(func, zip(prediction_files, reference_files))
+
+            # Collect the results
+            output_list.extend(results)
     # Args.prediction and args.reference are paths to nii.gz files from a SINGLE subject
     else:
         metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
         # Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
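
The heart of the change is the partial + Pool.starmap pattern: partial freezes the metrics argument shared by all subjects, so starmap only has to unpack each (prediction, reference) pair into positional arguments. A self-contained sketch of the pattern, with hypothetical file names and a stub in place of compute_metrics_single_subject:

from functools import partial
from multiprocessing import Pool

def process_subject(prediction_file, reference_file, metrics):
    # stub standing in for compute_metrics_single_subject()
    return {'prediction': prediction_file, 'reference': reference_file, 'metrics': metrics}

if __name__ == '__main__':
    prediction_files = ['sub-01_pred.nii.gz', 'sub-02_pred.nii.gz']  # hypothetical paths
    reference_files = ['sub-01_gt.nii.gz', 'sub-02_gt.nii.gz']

    # freeze the keyword argument shared by all subjects
    func = partial(process_subject, metrics=['dsc', 'nsd'])
    with Pool(2) as pool:
        # starmap unpacks each (prediction, reference) tuple into positional args
        results = pool.starmap(func, zip(prediction_files, reference_files))
    print(results)

Keeping process_subject at module level is what makes this work: Pool pickles the mapped callable by reference, and a partial of a module-level function pickles cleanly.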
10 changes: 5 additions & 5 deletions test/test_metrics/test_pairwise_measures_neuropoly.py
@@ -287,11 +287,11 @@ def test_non_empty_ref_and_pred_multi_class(self):
         Multi-class (i.e., voxels with values 1 and 2, e.g., region-based nnUNet training)
         """
 
-        expected_metrics = {1.0: {'dsc': 0.25,
-                                  'fbeta': 0.2500000055879354,
-                                  'nsd': 0.5,
-                                  'vol_diff': 2.0,
-                                  'rel_vol_error': 200.0,
+        expected_metrics = {1.0: {'dsc': 0.6521739130434783,
+                                  'fbeta': 0.5769230751596257,
+                                  'nsd': 0.23232323232323232,
+                                  'vol_diff': 2.6,
+                                  'rel_vol_error': 260.0,
                                   'EmptyRef': False,
                                   'EmptyPred': False,
                                   'lesion_ppv': 1.0,
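
For readers sanity-checking the updated expectations, a toy example (made-up masks, not the repository's test fixtures) of how dsc and rel_vol_error relate to overlap and volume, assuming the conventional definitions (MetricsReloaded's exact formulas may differ in sign convention):

import numpy as np

# made-up 1D masks: 4 reference voxels, 6 predicted voxels, 3 of them overlapping
reference = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=float)
prediction = np.array([0, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=float)

tp = (prediction * reference).sum()                  # 3.0 overlapping voxels
dsc = 2 * tp / (prediction.sum() + reference.sum())  # 2*3 / (6+4) = 0.6
rel_vol_error = 100 * (prediction.sum() - reference.sum()) / reference.sum()  # 50.0
print(dsc, rel_vol_error)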
