Skip to content

Commit c2a1a96

Browse files
committed
BMTI: combine standard and reg into a single function
1 parent 2b7cd1e commit c2a1a96

File tree

2 files changed

+12
-70
lines changed

2 files changed

+12
-70
lines changed

dadapy/density_advanced.py

+9-35
Original file line numberDiff line numberDiff line change
@@ -273,47 +273,17 @@ def compute_density_BMTI(
273273
delta_F_inv_cov="uncorr",
274274
comp_log_den_err=False,
275275
mem_efficient=False,
276-
):
277-
"""Compute the log-density for each point using BMTI
278-
279-
Args:
280-
delta_F_inv_cov (str): see compute_density_BMTI_reg docs.
281-
comp_log_den_err (bool): see compute_density_BMTI_reg docs.
282-
mem_efficient (bool): see compute_density_BMTI_reg docs.
283-
284-
"""
285-
286-
# call compute_density_BMTI_reg with alpha=1 and log_den and log_den_err as arrays of ones
287-
self.compute_density_BMTI_reg(
288-
alpha=1.0,
289-
log_den=np.ones(self.N),
290-
log_den_err=np.ones(self.N),
291-
delta_F_inv_cov=delta_F_inv_cov,
292-
comp_log_den_err=comp_log_den_err,
293-
mem_efficient=mem_efficient,
294-
)
295-
296-
# ----------------------------------------------------------------------------------------------
297-
298-
def compute_density_BMTI_reg(
299-
self,
300-
alpha=0.1,
276+
alpha=1,
301277
log_den=None,
302278
log_den_err=None,
303-
delta_F_inv_cov="uncorr",
304-
comp_log_den_err=False,
305-
mem_efficient=False,
306279
):
307-
"""Compute the log-density for each point using BMTI plus kstarNN estimator as a regulariser.
280+
"""Compute the log-density for each point using BMTI.
308281
309-
The regulariser log-density and its errors can be passed as arguments: log_den and log_den_err. If any of these
310-
two is not specified, use kstarNN estimator as a regulariser.
282+
If alpha<1, the algorithm also includes a regularisatin. The regulariser log-density and its errors can be
283+
passed as arguments: log_den and log_den_err. If any of these two is not specified, use kstarNN estimator
284+
as a regulariser.
311285
312286
Args:
313-
alpha (float): can take values from 0.0 to 1.0. Indicates the portion of BMTI in the sum of the likelihoods
314-
alpha*L_BMTI + (1-alpha)*L_kstarNN. Setting alpha=1.0 corresponds to not reguarising BMTI.
315-
log_den (np.ndarray(float)): size N. The array of the log-densities of the regulariser.
316-
log_den_err (np.ndarray(float)): size N. The array of the log-density errors of the regulariser.
317287
delta_F_inv_cov (str): specify the method used to invert the cross-covariance matrix C of the log-density
318288
deviations cov[deltaF_ij,deltaF_kl]. Currently implemented methods:
319289
"uncorr" (default): all the deltaFs are assumed uncorrelated, i.e. C is assumed to be diagonal with
@@ -326,6 +296,10 @@ def compute_density_BMTI_reg(
326296
comp_log_den_err (bool): if True, compute the error on the BMTI estimates. Can be highly time consuming
327297
mem_efficient (bool): if True, use a sparse matrice to solve BMTI linear system (slower). If False, use a
328298
dense NxN matrix; this is faster, but can require a great amount of memory if the system is large.
299+
alpha (float): can take values from 0.0 to 1.0. Indicates the portion of BMTI in the sum of the likelihoods
300+
alpha*L_BMTI + (1-alpha)*L_kstarNN. Setting alpha=1.0 corresponds to not reguarising BMTI.
301+
log_den (np.ndarray(float)): size N. The array of the log-densities of the regulariser.
302+
log_den_err (np.ndarray(float)): size N. The array of the log-density errors of the regulariser.
329303
330304
"""
331305

tests/test_density_advanced/test_density_advanced.py

+3-35
Original file line numberDiff line numberDiff line change
@@ -75,39 +75,7 @@ def test_compute_deltaFs():
7575
assert np.allclose(da.Fij_var_array, expected_Fij_var_array)
7676

7777

78-
# define the expected density
7978
expected_density_BMTI = np.array(
80-
[
81-
0.012505854084320398,
82-
-1.4989243919120265,
83-
-0.8325855985351576,
84-
-1.8954811470419732,
85-
-0.08608518234399808,
86-
-1.377358160570784,
87-
-2.2853275320451556,
88-
-0.08077062180341209,
89-
0.03151493142829422,
90-
-2.295060446120319,
91-
-0.485534023025263,
92-
-1.5291208381769597,
93-
-2.0291222925304333,
94-
-2.507439558393103,
95-
0.05236125958005627,
96-
-0.6844157822716908,
97-
-0.205568978673708,
98-
-1.3777138853748458,
99-
-1.2926910126536086,
100-
-1.1630749466695476,
101-
-1.9641366761139865,
102-
-1.421685853814561,
103-
-0.4840608241935639,
104-
-0.9553572813490178,
105-
-0.8380943495955488,
106-
]
107-
)
108-
109-
110-
expected_density_BMTI_reg = np.array(
11179
[
11280
-1.698780556695925,
11381
-3.189031310462691,
@@ -138,7 +106,7 @@ def test_compute_deltaFs():
138106
)
139107

140108

141-
def test_density_BMTI_reg():
109+
def test_density_BMTI():
142110
"""Test the density_BMTI method."""
143111
filename = os.path.join(os.path.split(__file__)[0], "../2gaussians_in_2d.npy")
144112

@@ -147,6 +115,6 @@ def test_density_BMTI_reg():
147115
da = DensityAdvanced(coordinates=X, maxk=10, verbose=True)
148116
da.compute_distances()
149117
da.set_id(2)
150-
da.compute_density_BMTI_reg(alpha=0.99)
118+
da.compute_density_BMTI(alpha=0.99)
151119

152-
assert np.allclose(da.log_den, expected_density_BMTI_reg)
120+
assert np.allclose(da.log_den, expected_density_BMTI)

0 commit comments

Comments
 (0)