Skip to content

Commit 5858e68

Browse files
committed
bmti: add faster cg-based solver, remove old solvers
1 parent b26f625 commit 5858e68

File tree

2 files changed

+8
-20
lines changed

2 files changed

+8
-20
lines changed

dadapy/density_advanced.py

+7-19
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ def compute_density_BMTI(
362362
self,
363363
delta_F_inv_cov="uncorr",
364364
comp_log_den_err=False,
365-
mem_efficient=False,
366365
alpha=1,
367366
log_den=None,
368367
log_den_err=None,
@@ -384,8 +383,6 @@ def compute_density_BMTI(
384383
finding the approximate diagonal inverse which multiplied by C gives the least-squares closest
385384
matrix to the identity in the Frobenius norm
386385
comp_log_den_err (bool): if True, compute the error on the BMTI estimates. Can be highly time consuming
387-
mem_efficient (bool): if True, use a sparse matrice to solve BMTI linear system (slower). If False, use a
388-
dense NxN matrix; this is faster, but can require a great amount of memory if the system is large.
389386
alpha (float): can take values from 0.0 to 1.0. Indicates the portion of BMTI in the sum of the likelihoods
390387
alpha*L_BMTI + (1-alpha)*L_kstarNN. Setting alpha=1.0 corresponds to not reguarising BMTI.
391388
log_den (np.ndarray(float)): size N. The array of the log-densities of the regulariser.
@@ -415,13 +412,6 @@ def compute_density_BMTI(
415412
self.log_den = log_den
416413
self.log_den_err = log_den_err
417414

418-
# add a warnings.warning if self.N > 10000 and mem_efficient is False
419-
if self.N > 15000 and mem_efficient is False:
420-
warnings.warn(
421-
"The number of points is large and the memory efficient option is not selected. \
422-
If you run into memory issues, consider using the slower memory efficient option."
423-
)
424-
425415
if self.verb:
426416
print("BMTI density estimation started")
427417
sec = time.time()
@@ -432,14 +422,16 @@ def compute_density_BMTI(
432422
sec2 = time.time()
433423

434424
if self.verb:
435-
print("{0:0.2f} seconds to fill get linear system ready".format(sec2 - sec))
425+
print("{0:0.2f} seconds to get the linear system ready".format(sec2 - sec))
436426

437427
# solve linear system
438-
log_den = self._solve_BMTI_reg_linar_system(A, deltaFcum, mem_efficient)
428+
log_den = self._solve_BMTI_reg_linar_system(A, deltaFcum)
439429
self.log_den = log_den
440430

441431
if self.verb:
442-
print("{0:0.2f} seconds to solve linear system".format(time.time() - sec2))
432+
print(
433+
"{0:0.2f} seconds to solve the linear system".format(time.time() - sec2)
434+
)
443435
sec2 = time.time()
444436

445437
# compute error
@@ -529,10 +521,6 @@ def _get_BMTI_reg_linear_system(self, delta_F_inv_cov, alpha):
529521

530522
return A, deltaFcum
531523

532-
def _solve_BMTI_reg_linar_system(self, A, deltaFcum, mem_efficient):
533-
if mem_efficient is False:
534-
log_den = np.linalg.solve(A.todense(), deltaFcum)
535-
else:
536-
log_den = sparse.linalg.spsolve(A.tocsr(), deltaFcum)
537-
524+
def _solve_BMTI_reg_linar_system(self, A, deltaFcum):
525+
log_den = sparse.linalg.cg(A, deltaFcum, atol=0.0, maxiter=None)[0]
538526
return log_den

tests/test_density_advanced/test_density_advanced.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,4 @@ def test_density_BMTI():
133133
da.set_id(2)
134134
da.compute_density_BMTI(alpha=0.99)
135135

136-
assert np.allclose(da.log_den, expected_density_BMTI)
136+
assert np.allclose(da.log_den, expected_density_BMTI, rtol=1e-05, atol=1e-01)

0 commit comments

Comments
 (0)