From 47ecfa28b151b7bb4b7e5f0fb3034e3764cb429e Mon Sep 17 00:00:00 2001 From: xiangyan93 Date: Thu, 14 Dec 2023 15:04:05 -0500 Subject: [PATCH] Update to 1.1.0: add 3 metrics for regression tasks: spearman's r, pearson's r, and kendall's tau. --- mgktools/__init__.py | 2 +- mgktools/evaluators/metric.py | 9 ++++++++- test/cross_validation/test_cv_pure.py | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mgktools/__init__.py b/mgktools/__init__.py index 24de858..7c0640f 100644 --- a/mgktools/__init__.py +++ b/mgktools/__init__.py @@ -2,4 +2,4 @@ # -*- coding: utf-8 -*- -__version__ = '1.0.0' +__version__ = '1.1.0' diff --git a/mgktools/evaluators/metric.py b/mgktools/evaluators/metric.py index 1e1a785..3cd4976 100644 --- a/mgktools/evaluators/metric.py +++ b/mgktools/evaluators/metric.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from typing import Dict, Iterator, List, Optional, Union, Literal, Tuple import numpy as np +import scipy from sklearn.metrics import ( mean_squared_error, mean_absolute_error, @@ -14,7 +15,7 @@ matthews_corrcoef ) Metric = Literal['roc-auc', 'accuracy', 'precision', 'recall', 'f1_score', 'mcc', - 'rmse', 'mae', 'mse', 'r2', 'max'] + 'rmse', 'mae', 'mse', 'r2', 'max', 'spearman', 'kendall', 'pearson'] def p2v(y: List[float], y_pred: List[float]): @@ -51,5 +52,11 @@ def eval_metric_func(y: List[float], y_pred: List[float], metric: Metric) -> flo return np.sqrt(eval_metric_func(y, y_pred, 'mse')) elif metric == 'max': return np.max(abs(y - y_pred)) + elif metric == 'spearman': + return scipy.stats.spearmanr(y, y_pred)[0] + elif metric == 'kendall': + return scipy.stats.kendalltau(y, y_pred)[0] + elif metric == 'pearson': + return scipy.stats.pearsonr(y, y_pred)[0] else: raise RuntimeError(f'Unsupported metrics {metric}') diff --git a/test/cross_validation/test_cv_pure.py b/test/cross_validation/test_cv_pure.py index ffa1015..bb114ac 100644 --- a/test/cross_validation/test_cv_pure.py +++ b/test/cross_validation/test_cv_pure.py @@ -37,7 +37,7 @@ def test_only_graph_classification(mgk_file, model, split_type): dataset=dataset, model=model, task_type='binary', - metrics=['roc-auc', 'mcc'], + metrics=['roc-auc', 'accuracy', 'precision', 'recall', 'f1_score', 'mcc'], split_type=split_type, split_sizes=[0.75, 0.25], num_folds=2, @@ -75,7 +75,7 @@ def test_only_graph_scalable_gps(mgk_file, modelsets, split_type): dataset=dataset, model=model, task_type='regression', - metrics=['rmse', 'mae', 'r2'], + metrics=['rmse', 'mae', 'mse', 'r2', 'max', 'spearman', 'kendall', 'pearson'], split_type=split_type, split_sizes=[0.75, 0.25], num_folds=2,