Skip to content

Commit

Permalink
Update to 1.1.0:
Browse files Browse the repository at this point in the history
add 3 metrics for regression tasks: spearman's r, pearson's r, and kendall's tau.
  • Loading branch information
Xiangyan93 committed Dec 14, 2023
1 parent 0996785 commit 47ecfa2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mgktools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# -*- coding: utf-8 -*-


__version__ = '1.0.0'
__version__ = '1.1.0'
9 changes: 8 additions & 1 deletion mgktools/evaluators/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from typing import Dict, Iterator, List, Optional, Union, Literal, Tuple
import numpy as np
import scipy
from sklearn.metrics import (
mean_squared_error,
mean_absolute_error,
Expand All @@ -14,7 +15,7 @@
matthews_corrcoef
)
Metric = Literal['roc-auc', 'accuracy', 'precision', 'recall', 'f1_score', 'mcc',
'rmse', 'mae', 'mse', 'r2', 'max']
'rmse', 'mae', 'mse', 'r2', 'max', 'spearman', 'kendall', 'pearson']


def p2v(y: List[float], y_pred: List[float]):
Expand Down Expand Up @@ -51,5 +52,11 @@ def eval_metric_func(y: List[float], y_pred: List[float], metric: Metric) -> flo
return np.sqrt(eval_metric_func(y, y_pred, 'mse'))
elif metric == 'max':
return np.max(abs(y - y_pred))
elif metric == 'spearman':
return scipy.stats.spearmanr(y, y_pred)[0]
elif metric == 'kendall':
return scipy.stats.kendalltau(y, y_pred)[0]
elif metric == 'pearson':
return scipy.stats.pearsonr(y, y_pred)[0]
else:
raise RuntimeError(f'Unsupported metrics {metric}')
4 changes: 2 additions & 2 deletions test/cross_validation/test_cv_pure.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_only_graph_classification(mgk_file, model, split_type):
dataset=dataset,
model=model,
task_type='binary',
metrics=['roc-auc', 'mcc'],
metrics=['roc-auc', 'accuracy', 'precision', 'recall', 'f1_score', 'mcc'],
split_type=split_type,
split_sizes=[0.75, 0.25],
num_folds=2,
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_only_graph_scalable_gps(mgk_file, modelsets, split_type):
dataset=dataset,
model=model,
task_type='regression',
metrics=['rmse', 'mae', 'r2'],
metrics=['rmse', 'mae', 'mse', 'r2', 'max', 'spearman', 'kendall', 'pearson'],
split_type=split_type,
split_sizes=[0.75, 0.25],
num_folds=2,
Expand Down

0 comments on commit 47ecfa2

Please sign in to comment.