Skip to content

Commit

Permalink
Update to 1.0.0:
Browse files Browse the repository at this point in the history
change dataset.y to be a 2-dimensional NumPy array, with shape (n_samples, 1), for single-target datasets.
  • Loading branch information
Xiangyan93 committed Dec 13, 2023
1 parent a9b9fd1 commit 0996785
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mgktools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# -*- coding: utf-8 -*-


__version__ = '0.2.1'
__version__ = '1.0.0'
5 changes: 1 addition & 4 deletions mgktools/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,7 @@ def X(self) -> np.ndarray:
@property
def y(self):
y = concatenate([d.targets for d in self.data], axis=0)
if y is not None and y.shape[1] == 1:
return y.ravel()
else:
return y
return y

@property
def repr(self) -> np.ndarray: # 2d array str.
Expand Down
6 changes: 6 additions & 0 deletions mgktools/evaluators/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,13 @@ def evaluate_train_test(self, dataset_train: Dataset,

X_train = dataset_train.X
y_train = dataset_train.y
if y_train.shape[1] == 1:
y_train = y_train.ravel()
repr_train = dataset_train.repr.ravel()
X_test = dataset_test.X
y_test = dataset_test.y
if y_test.shape[1] == 1:
y_test = y_test.ravel()
repr_test = dataset_test.repr.ravel()
# Find the most similar sample in training sets.
if self.n_similar is None:
Expand All @@ -187,6 +191,8 @@ def evaluate_train_test(self, dataset_train: Dataset,

def _evaluate_loocv(self):
X, y, repr = self.dataset.X, self.dataset.y, self.dataset.repr.ravel()
if y.shape[1] == 1:
y = y.ravel()
if self.n_similar is not None:
y_similar = self.get_similar_info(X, X, repr, self.n_similar)
else:
Expand Down
6 changes: 3 additions & 3 deletions test/data/test_data_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@pytest.mark.parametrize('testset', [
(['targets_1'], (4,)),
(['targets_1'], (4, 1)),
(['targets_1', 'targets_2'], (4, 2)),
])
def test_only_graph(testset):
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_only_fingerprints(testset):
features_generator=[features_generator],
features_combination=features_combination)
assert dataset.X.shape == (4, n_features)
assert dataset.y.shape == (4,)
assert dataset.y.shape == (4, 1)


@pytest.mark.parametrize('testset', [
Expand All @@ -66,4 +66,4 @@ def test_graph_fingerprints(testset):
features_combination=features_combination)
dataset.graph_kernel_type = 'graph'
assert dataset.X.shape == (4, 1 + n_features)
assert dataset.y.shape == (4,)
assert dataset.y.shape == (4, 1)
6 changes: 3 additions & 3 deletions test/data/test_data_pure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@pytest.mark.parametrize('testset', [
(['targets_1'], (4,)),
(['targets_1'], (4, 1)),
(['targets_1', 'targets_2'], (4, 2)),
])
def test_only_graph(testset):
Expand All @@ -38,7 +38,7 @@ def test_only_fingerprints(testset):
target_columns=['targets_1'],
features_generator=[features_generator])
assert dataset.X.shape == (4, n_features)
assert dataset.y.shape == (4,)
assert dataset.y.shape == (4, 1)


@pytest.mark.parametrize('testset', [
Expand All @@ -55,4 +55,4 @@ def test_graph_fingerprints(testset):
features_generator=[features_generator])
dataset.graph_kernel_type = 'graph'
assert dataset.X.shape == (4, 1 + n_features)
assert dataset.y.shape == (4,)
assert dataset.y.shape == (4, 1)
2 changes: 1 addition & 1 deletion test/interpret/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_interpret_training_mols(testset):
mgk_hyperparameters_file=mgk_hyperparameters_file,
n_jobs=6)
for i, df in enumerate(df_interpret):
assert df['contribution_value'].sum() == pytest.approx(y_pred[i], 1e-5)
assert df['contribution_value'].sum() == pytest.approx(y_pred[i], 1e-4)


@pytest.mark.parametrize('testset', [
Expand Down

0 comments on commit 0996785

Please sign in to comment.