Skip to content

Commit 407edef

Browse files
committed
BMTI: fix CI error on test
1 parent 65571e7 commit 407edef

File tree

2 files changed

+82
-25
lines changed

2 files changed

+82
-25
lines changed

tests/test_density_advanced/test_density_advanced.py

+74-17
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
"""Module for testing the DensityAdvanced class."""
1717

18+
import os
19+
1820
import numpy as np
1921

2022
from dadapy import DensityAdvanced
@@ -73,23 +75,78 @@ def test_compute_deltaFs():
7375
assert np.allclose(da.Fij_var_array, expected_Fij_var_array)
7476

7577

76-
def test_density_BMTI():
78+
# define the expected density
79+
expected_density_BMTI = np.array(
80+
[
81+
0.012505854084320398,
82+
-1.4989243919120265,
83+
-0.8325855985351576,
84+
-1.8954811470419732,
85+
-0.08608518234399808,
86+
-1.377358160570784,
87+
-2.2853275320451556,
88+
-0.08077062180341209,
89+
0.03151493142829422,
90+
-2.295060446120319,
91+
-0.485534023025263,
92+
-1.5291208381769597,
93+
-2.0291222925304333,
94+
-2.507439558393103,
95+
0.05236125958005627,
96+
-0.6844157822716908,
97+
-0.205568978673708,
98+
-1.3777138853748458,
99+
-1.2926910126536086,
100+
-1.1630749466695476,
101+
-1.9641366761139865,
102+
-1.421685853814561,
103+
-0.4840608241935639,
104+
-0.9553572813490178,
105+
-0.8380943495955488,
106+
]
107+
)
108+
109+
110+
expected_density_BMTI_reg = np.array(
111+
[
112+
-1.698780556695925,
113+
-3.189031310462691,
114+
-2.523230763956809,
115+
-3.583852700030218,
116+
-1.7960043733709044,
117+
-3.0643213049156897,
118+
-3.977860544167944,
119+
-1.7921150971140554,
120+
-1.6797609470985766,
121+
-3.988603310090793,
122+
-2.188728113912356,
123+
-3.224052459516409,
124+
-3.721582556367274,
125+
-4.199926673166814,
126+
-1.658836709713782,
127+
-2.3769581305103724,
128+
-1.9130593343411615,
129+
-3.0670805680617694,
130+
-2.9879430640544475,
131+
-2.8543162418664916,
132+
-3.6531870983140053,
133+
-3.113710498689125,
134+
-2.192481141955024,
135+
-2.6462753213763226,
136+
-2.5283426346886047,
137+
]
138+
)
139+
140+
141+
def test_density_BMTI_reg():
77142
"""Test the density_BMTI method."""
78-
# define the expected density
79-
expected_density = np.array(
80-
[
81-
-0.06290097151904936,
82-
-0.023556982206034104,
83-
-0.011088481060296614,
84-
-0.004762206359420831,
85-
-0.02968801181576885,
86-
-0.05593299286955579,
87-
]
88-
)
143+
filename = os.path.join(os.path.split(__file__)[0], "../2gaussians_in_2d.npy")
89144

90-
da = DensityAdvanced(coordinates=data, maxk=3, verbose=True)
145+
X = np.load(filename)[:25]
146+
147+
da = DensityAdvanced(coordinates=X, maxk=10, verbose=True)
91148
da.compute_distances()
92-
da.set_id(1)
93-
da.set_kstar(4)
94-
da.compute_density_BMTI()
95-
assert np.allclose(da.log_den, expected_density)
149+
da.set_id(2)
150+
da.compute_density_BMTI_reg(alpha=0.99)
151+
152+
assert np.allclose(da.log_den, expected_density_BMTI_reg)

tests/test_neigh_graph.py/test_neigh_graph.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from dadapy import NeighGraph
2121

2222
# define a basic dataset with 6 points
23-
data = np.array([[0, 0], [0.15, 0], [0.2, 0], [4, 0], [4.1, 0], [4.2, 0]])
23+
data = np.array([[0, 0], [0.15, 0], [0.2, 0], [4, 0], [4.09, 0], [4.2, 0]])
2424

2525
# list of neighbor pairs
2626
expected_nint_list = [[0, 1], [1, 2], [2, 1], [3, 4], [4, 3], [5, 4]]
@@ -31,24 +31,24 @@
3131
# number of neighbour pairs
3232
expected_nspar = 6
3333

34-
expected_neigh_dists = np.array([0.15, 0.05, 0.05, 0.1, 0.1, 0.1])
34+
expected_neigh_dists = np.array([0.15, 0.05, 0.05, 0.09, 0.09, 0.11])
3535

3636
expected_distance_graph = [
3737
[0.0, 0.15, 0.0, 0.0, 0.0, 0.0],
3838
[0.0, 0.0, 0.05, 0.0, 0.0, 0.0],
3939
[0.0, 0.05, 0.0, 0.0, 0.0, 0.0],
40-
[0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
41-
[0.0, 0.0, 0.0, 0.1, 0.0, 0.0],
42-
[0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
40+
[0.0, 0.0, 0.0, 0.0, 0.09, 0.0],
41+
[0.0, 0.0, 0.0, 0.09, 0.0, 0.0],
42+
[0.0, 0.0, 0.0, 0.0, 0.11, 0.0],
4343
]
4444

4545
neigh_vector_diffs = [
4646
[0.15, 0.0],
4747
[0.05, 0.0],
4848
[-0.05, 0.0],
49-
[0.1, 0.0],
50-
[-0.1, 0.0],
51-
[-0.1, 0.0],
49+
[0.09, 0.0],
50+
[-0.09, 0.0],
51+
[-0.11, 0.0],
5252
]
5353

5454

0 commit comments

Comments
 (0)