Skip to content

Commit 9d9512d

Browse files
committed
bmti: add faster cg-based solver, remove old solvers
1 parent b26f625 commit 9d9512d

File tree

1 file changed

+8
-19
lines changed

1 file changed

+8
-19
lines changed

dadapy/density_advanced.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ def compute_density_BMTI(
362362
self,
363363
delta_F_inv_cov="uncorr",
364364
comp_log_den_err=False,
365-
mem_efficient=False,
366365
alpha=1,
367366
log_den=None,
368367
log_den_err=None,
@@ -384,8 +383,6 @@ def compute_density_BMTI(
384383
finding the approximate diagonal inverse which multiplied by C gives the least-squares closest
385384
matrix to the identity in the Frobenius norm
386385
comp_log_den_err (bool): if True, compute the error on the BMTI estimates. Can be highly time consuming
387-
mem_efficient (bool): if True, use a sparse matrice to solve BMTI linear system (slower). If False, use a
388-
dense NxN matrix; this is faster, but can require a great amount of memory if the system is large.
389386
alpha (float): can take values from 0.0 to 1.0. Indicates the portion of BMTI in the sum of the likelihoods
390387
alpha*L_BMTI + (1-alpha)*L_kstarNN. Setting alpha=1.0 corresponds to not reguarising BMTI.
391388
log_den (np.ndarray(float)): size N. The array of the log-densities of the regulariser.
@@ -415,13 +412,6 @@ def compute_density_BMTI(
415412
self.log_den = log_den
416413
self.log_den_err = log_den_err
417414

418-
# add a warnings.warning if self.N > 10000 and mem_efficient is False
419-
if self.N > 15000 and mem_efficient is False:
420-
warnings.warn(
421-
"The number of points is large and the memory efficient option is not selected. \
422-
If you run into memory issues, consider using the slower memory efficient option."
423-
)
424-
425415
if self.verb:
426416
print("BMTI density estimation started")
427417
sec = time.time()
@@ -432,14 +422,14 @@ def compute_density_BMTI(
432422
sec2 = time.time()
433423

434424
if self.verb:
435-
print("{0:0.2f} seconds to fill get linear system ready".format(sec2 - sec))
425+
print("{0:0.2f} seconds to get the linear system ready".format(sec2 - sec))
436426

437427
# solve linear system
438-
log_den = self._solve_BMTI_reg_linar_system(A, deltaFcum, mem_efficient)
428+
log_den = self._solve_BMTI_reg_linar_system(A, deltaFcum)
439429
self.log_den = log_den
440430

441431
if self.verb:
442-
print("{0:0.2f} seconds to solve linear system".format(time.time() - sec2))
432+
print("{0:0.2f} seconds to solve the linear system".format(time.time() - sec2))
443433
sec2 = time.time()
444434

445435
# compute error
@@ -529,10 +519,9 @@ def _get_BMTI_reg_linear_system(self, delta_F_inv_cov, alpha):
529519

530520
return A, deltaFcum
531521

532-
def _solve_BMTI_reg_linar_system(self, A, deltaFcum, mem_efficient):
533-
if mem_efficient is False:
534-
log_den = np.linalg.solve(A.todense(), deltaFcum)
535-
else:
536-
log_den = sparse.linalg.spsolve(A.tocsr(), deltaFcum)
537-
522+
def _solve_BMTI_reg_linar_system(self, A, deltaFcum):
523+
524+
# log_den = np.linalg.solve(A.todense(), deltaFcum)
525+
log_den = sparse.linalg.cg(A, deltaFcum)[0]
526+
538527
return log_den

0 commit comments

Comments
 (0)