From 099678535f4e0f5ba72cc725ea3696a2ef45c855 Mon Sep 17 00:00:00 2001
From: xiangyan93
Date: Wed, 13 Dec 2023 10:10:49 -0500
Subject: [PATCH] Update to 1.0.0: change dataset.y to be a 2-dimensional
 numpy array for single-target data sets.

---
 mgktools/__init__.py                    | 2 +-
 mgktools/data/data.py                   | 5 +----
 mgktools/evaluators/cross_validation.py | 6 ++++++
 test/data/test_data_mixture.py          | 6 +++---
 test/data/test_data_pure.py             | 6 +++---
 test/interpret/test_interpret.py        | 2 +-
 6 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/mgktools/__init__.py b/mgktools/__init__.py
index f22e5b7..24de858 100644
--- a/mgktools/__init__.py
+++ b/mgktools/__init__.py
@@ -2,4 +2,4 @@
 # -*- coding: utf-8 -*-
 
 
-__version__ = '0.2.1'
+__version__ = '1.0.0'
diff --git a/mgktools/data/data.py b/mgktools/data/data.py
index d9206f0..5c01dbf 100644
--- a/mgktools/data/data.py
+++ b/mgktools/data/data.py
@@ -405,10 +405,7 @@ def X(self) -> np.ndarray:
     @property
     def y(self):
         y = concatenate([d.targets for d in self.data], axis=0)
-        if y is not None and y.shape[1] == 1:
-            return y.ravel()
-        else:
-            return y
+        return y
 
     @property
     def repr(self) -> np.ndarray: # 2d array str.
diff --git a/mgktools/evaluators/cross_validation.py b/mgktools/evaluators/cross_validation.py
index 497c35e..71dacba 100644
--- a/mgktools/evaluators/cross_validation.py
+++ b/mgktools/evaluators/cross_validation.py
@@ -159,9 +159,13 @@ def evaluate_train_test(self, dataset_train: Dataset,
 
         X_train = dataset_train.X
         y_train = dataset_train.y
+        if y_train.shape[1] == 1:
+            y_train = y_train.ravel()
         repr_train = dataset_train.repr.ravel()
         X_test = dataset_test.X
         y_test = dataset_test.y
+        if y_test.shape[1] == 1:
+            y_test = y_test.ravel()
        repr_test = dataset_test.repr.ravel()
         # Find the most similar sample in training sets.
         if self.n_similar is None:
@@ -187,6 +191,8 @@ def evaluate_train_test(self, dataset_train: Dataset,
 
     def _evaluate_loocv(self):
         X, y, repr = self.dataset.X, self.dataset.y, self.dataset.repr.ravel()
+        if y.shape[1] == 1:
+            y = y.ravel()
         if self.n_similar is not None:
             y_similar = self.get_similar_info(X, X, repr, self.n_similar)
         else:
diff --git a/test/data/test_data_mixture.py b/test/data/test_data_mixture.py
index a0053a3..ee92581 100644
--- a/test/data/test_data_mixture.py
+++ b/test/data/test_data_mixture.py
@@ -17,7 +17,7 @@
 
 
 @pytest.mark.parametrize('testset', [
-    (['targets_1'], (4,)),
+    (['targets_1'], (4, 1)),
     (['targets_1', 'targets_2'], (4, 2)),
 ])
 def test_only_graph(testset):
@@ -46,7 +46,7 @@ def test_only_fingerprints(testset):
         features_generator=[features_generator],
         features_combination=features_combination)
     assert dataset.X.shape == (4, n_features)
-    assert dataset.y.shape == (4,)
+    assert dataset.y.shape == (4, 1)
 
 
 @pytest.mark.parametrize('testset', [
@@ -66,4 +66,4 @@ def test_graph_fingerprints(testset):
         features_combination=features_combination)
     dataset.graph_kernel_type = 'graph'
     assert dataset.X.shape == (4, 1 + n_features)
-    assert dataset.y.shape == (4,)
+    assert dataset.y.shape == (4, 1)
diff --git a/test/data/test_data_pure.py b/test/data/test_data_pure.py
index d4dca85..336e8ed 100644
--- a/test/data/test_data_pure.py
+++ b/test/data/test_data_pure.py
@@ -12,7 +12,7 @@
 
 
 @pytest.mark.parametrize('testset', [
-    (['targets_1'], (4,)),
+    (['targets_1'], (4, 1)),
     (['targets_1', 'targets_2'], (4, 2)),
 ])
 def test_only_graph(testset):
@@ -38,7 +38,7 @@ def test_only_fingerprints(testset):
         target_columns=['targets_1'],
         features_generator=[features_generator])
     assert dataset.X.shape == (4, n_features)
-    assert dataset.y.shape == (4,)
+    assert dataset.y.shape == (4, 1)
 
 
 @pytest.mark.parametrize('testset', [
@@ -55,4 +55,4 @@ def test_graph_fingerprints(testset):
         features_generator=[features_generator])
     dataset.graph_kernel_type = 'graph'
     assert dataset.X.shape == (4, 1 + n_features)
-    assert dataset.y.shape == (4,)
+    assert dataset.y.shape == (4, 1)
diff --git a/test/interpret/test_interpret.py b/test/interpret/test_interpret.py
index 5d8fd12..9842e1b 100644
--- a/test/interpret/test_interpret.py
+++ b/test/interpret/test_interpret.py
@@ -27,7 +27,7 @@ def test_interpret_training_mols(testset):
         mgk_hyperparameters_file=mgk_hyperparameters_file,
         n_jobs=6)
     for i, df in enumerate(df_interpret):
-        assert df['contribution_value'].sum() == pytest.approx(y_pred[i], 1e-5)
+        assert df['contribution_value'].sum() == pytest.approx(y_pred[i], 1e-4)
 
 
 @pytest.mark.parametrize('testset', [
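
Note (illustration, not part of the patch): as of 1.0.0, Dataset.y always has shape (n_samples, n_targets), so a single-target data set now yields (n, 1) instead of (n,). The sketch below shows how downstream code can reproduce the ravel() handling this patch adds to the cross-validation evaluators; the helper name targets_as_vector and the example array are hypothetical.

import numpy as np

def targets_as_vector(y: np.ndarray) -> np.ndarray:
    # Mirror the evaluators' handling of the new 2-D dataset.y:
    # flatten single-target arrays, leave multi-target arrays 2-D.
    if y.ndim == 2 and y.shape[1] == 1:
        return y.ravel()                       # (n_samples, 1) -> (n_samples,)
    return y

# Stand-in for a single-target dataset.y, shaped as the updated tests expect.
y = np.array([[0.1], [0.2], [0.3], [0.4]])     # shape (4, 1)
assert targets_as_vector(y).shape == (4,)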