From 2d93278e2861f223454b083c0242f333994ce8d1 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 5 Jun 2023 16:46:30 +0200 Subject: [PATCH 01/12] add tqdm --- src/moscot/base/output.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 9cf53bbaa..c35582eda 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -6,6 +6,7 @@ import numpy as np import scipy.sparse as sp from scipy.sparse.linalg import LinearOperator +from tqdm import tqdm from moscot._docs._docs import d from moscot._logging import logger @@ -244,7 +245,7 @@ def sparsify( k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack) tmaps_sparse: List[sp.csr_matrix] = [] - for batch in range(0, k, batch_size): + for batch in tqdm(range(0, k, batch_size)): x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) res = np.array(func(x, scale_by_marginals=False)) res[res < thr] = 0.0 From 4b9311b865df551661d23df0788e4ea7cc1224a4 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 6 Jun 2023 17:49:11 +0200 Subject: [PATCH 02/12] update notebooks --- docs/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/notebooks b/docs/notebooks index 1ef4a258a..449c9fc11 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 1ef4a258a88b8ad56d0831738610a5891f1ba2f5 +Subproject commit 449c9fc11dd62911f787cabf1a413bbc8297706f From c0ac29366a6770e240ceaba4d5dee29077291d75 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 7 Jun 2023 16:09:08 +0200 Subject: [PATCH 03/12] initial push --- src/moscot/base/output.py | 57 +++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index c35582eda..bbf45dd23 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod from copy import copy from functools import partial -from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple + +from joblib import Parallel, delayed +from tqdm import tqdm import numpy as np import scipy.sparse as sp from scipy.sparse.linalg import LinearOperator -from tqdm import tqdm from moscot._docs._docs import d from moscot._logging import logger @@ -178,6 +180,7 @@ def sparsify( batch_size: int = 1024, n_samples: Optional[int] = None, seed: Optional[int] = None, + n_jobs: int = -1, ) -> "MatrixSolverOutput": """Sparsify the :attr:`transport_matrix`. @@ -213,12 +216,28 @@ def sparsify( ``batch_size``. seed Random seed needed for sampling if ``mode = 'percentile'``. + n_jobs + TODO Returns ------- Solve output with a sparsified transport matrix. """ n, m = self.shape + + def _min_row(batch: int) -> float: + x = np.eye(m, min(batch_size, m - batch), -(min(batch, m))) + res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors + return float(res.max(axis=1).min()) + + def _min_row_with_thr( + batch: int, threshold: float, k: int, func: Callable[[np.ndarray, bool], np.ndarray] + ) -> Dict[int, sp.csr_matrix]: + x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) + res = np.array(func(x, scale_by_marginals=False)) + res[res < threshold] = 0.0 + return sp.csr_matrix(res.T if n < m else res) + if mode == "threshold": if value is None: raise ValueError("If `mode = 'threshold'`, `threshold` cannot be `None`.") @@ -235,23 +254,33 @@ def sparsify( res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors thr = np.percentile(res, value) elif mode == "min_row": - thr = np.inf - for batch in range(0, m, batch_size): - x = np.eye(m, min(batch_size, m - batch), -(min(batch, m))) - res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors - thr = min(thr, float(res.max(axis=1).min())) + # thr = np.inf + # for batch in tqdm(range(0, m, batch_size)): + # x = np.eye(m, min(batch_size, m - batch), -(min(batch, m))) + # res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors + # thr = min(thr, float(res.max(axis=1).min())) + + results = Parallel(n_jobs=n_jobs, verbose=3)( + delayed(_min_row)(batch) for batch in tqdm(range(0, m, batch_size)) + ) + thr = np.min(results) + else: raise NotImplementedError(mode) k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack) - tmaps_sparse: List[sp.csr_matrix] = [] - for batch in tqdm(range(0, k, batch_size)): - x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) - res = np.array(func(x, scale_by_marginals=False)) - res[res < thr] = 0.0 - tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res)) + # tmaps_sparse: List[sp.csr_matrix] = [] + # for batch in tqdm(range(0, k, batch_size)): + # x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) + # res = np.array(func(x, scale_by_marginals=False)) + # res[res < thr] = 0.0 + # tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res)) + + results = Parallel(n_jobs=n_jobs, verbose=3)( + delayed(_min_row_with_thr)(batch, thr, k, func) for batch in tqdm(range(0, k, batch_size)) + ) return MatrixSolverOutput( - transport_matrix=fn_stack(tmaps_sparse), cost=self.cost, converged=self.converged, is_linear=self.is_linear + transport_matrix=fn_stack(results), cost=self.cost, converged=self.converged, is_linear=self.is_linear ) @property From 3aa62449ac60ec29cec99d95438d277e4e524ad6 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 10:29:27 +0200 Subject: [PATCH 04/12] Add subset method --- src/moscot/backends/ott/output.py | 59 +++++++++++++++++++++++-------- src/moscot/base/output.py | 13 +++++-- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 83166e240..211a1a8db 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -153,20 +153,6 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike: return self._output.apply(x, axis=1 - forward) return self._output.apply(x.T, axis=1 - forward).T # convert to batch first - @property - def shape(self) -> Tuple[int, int]: # noqa: D102 - if isinstance(self._output, OTTSinkhornOutput): - return self._output.f.shape[0], self._output.g.shape[0] - return self._output.geom.shape - - @property - def transport_matrix(self) -> ArrayLike: # noqa: D102 - return self._output.matrix - - @property - def is_linear(self) -> bool: # noqa: D102 - return isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)) - def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102 if isinstance(device, str) and ":" in device: device, ix = device.split(":") @@ -182,6 +168,51 @@ def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102 return OTTOutput(jax.device_put(self._output, device)) + def subset( # noqa: D102 + self, src_ixs: Optional[ArrayLike] = None, tgt_ixs: Optional[ArrayLike] = None + ) -> "OTTOutput": + from ott.geometry import pointcloud + from ott.problems.linear import linear_problem + from ott.solvers.linear import sinkhorn + + assert self.is_linear, "Quadratic is not yet available." + assert not self.is_low_rank, "Low-rank is not yet available." + + f, g = self.potentials # type: ignore[misc] + prob = self._output.ot_prob + assert isinstance(prob.geom, pointcloud.PointCloud), "Only available for point clouds." + + x, y, a, b = prob.geom.x, prob.geom.y, prob.a, prob.b + eps = prob.epsilon + + if src_ixs is not None: + f, x, a = f[src_ixs], x[src_ixs], a[src_ixs] + if tgt_ixs is not None: + g, y, g = g[tgt_ixs], y[tgt_ixs], b[tgt_ixs] + + # TODO(michalk8): use flatten to pass other params (except batch_size) + geom = pointcloud.PointCloud(x, y, epsilon=eps) + prob = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=prob.tau_a, tau_b=prob.tau_b) + out = sinkhorn.SinkhornOutput( + f=f, g=g, errors=self._output.errors, reg_ot_cost=self._output.reg_ot_cost, ot_prob=prob + ) + + return type(self)(out) + + @property + def shape(self) -> Tuple[int, int]: # noqa: D102 + if isinstance(self._output, OTTSinkhornOutput): + return self._output.f.shape[0], self._output.g.shape[0] + return self._output.geom.shape + + @property + def transport_matrix(self) -> ArrayLike: # noqa: D102 + return self._output.matrix + + @property + def is_linear(self) -> bool: # noqa: D102 + return isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)) + @property def cost(self) -> float: # noqa: D102 if isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)): diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index bbf45dd23..c64e16124 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -87,6 +87,10 @@ def is_low_rank(self) -> bool: def _ones(self, n: int) -> ArrayLike: pass + @abstractmethod + def subset(self, src_ixs: Optional[ArrayLike] = None, tgt_ixs: Optional[ArrayLike] = None) -> "BaseSolverOutput": + """TODO.""" + def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike: """Push mass through the :attr:`transport_matrix`. @@ -231,10 +235,10 @@ def _min_row(batch: int) -> float: return float(res.max(axis=1).min()) def _min_row_with_thr( - batch: int, threshold: float, k: int, func: Callable[[np.ndarray, bool], np.ndarray] + batch: int, threshold: float, k: int, func: Callable[[ArrayLike, bool], ArrayLike] ) -> Dict[int, sp.csr_matrix]: x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) - res = np.array(func(x, scale_by_marginals=False)) + res = np.array(func(x, scale_by_marginals=False)) # type: ignore[call-arg] res[res < threshold] = 0.0 return sp.csr_matrix(res.T if n < m else res) @@ -375,6 +379,11 @@ def to( # noqa: D102 obj._transport_matrix = obj.transport_matrix.astype(dtype) return obj + def subset( # noqa: D102 + self, src_ixs: Optional[ArrayLike] = None, tgt_ixs: Optional[ArrayLike] = None + ) -> "BaseSolverOutput": + raise NotImplementedError("Not yet implemented.") + @property def cost(self) -> float: # noqa: D102 return self._cost From ac67b753d429a863139c870c42bb15d6f6dd9514 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 10:52:52 +0200 Subject: [PATCH 05/12] Fix not passing cost --- src/moscot/backends/ott/output.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 211a1a8db..745ba4d95 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -182,7 +182,8 @@ def subset( # noqa: D102 prob = self._output.ot_prob assert isinstance(prob.geom, pointcloud.PointCloud), "Only available for point clouds." - x, y, a, b = prob.geom.x, prob.geom.y, prob.a, prob.b + (x, y, *_, cost_fn), aux_data = prob.geom.tree_flatten() + a, b = prob.a, prob.b eps = prob.epsilon if src_ixs is not None: @@ -190,8 +191,9 @@ def subset( # noqa: D102 if tgt_ixs is not None: g, y, g = g[tgt_ixs], y[tgt_ixs], b[tgt_ixs] - # TODO(michalk8): use flatten to pass other params (except batch_size) - geom = pointcloud.PointCloud(x, y, epsilon=eps) + _ = aux_data.pop("batch_size", None) + geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn, **aux_data) + prob = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=prob.tau_a, tau_b=prob.tau_b) out = sinkhorn.SinkhornOutput( f=f, g=g, errors=self._output.errors, reg_ot_cost=self._output.reg_ot_cost, ot_prob=prob From 5cac06c9591d8e709dda8c6376c65f4c2472034b Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:31:22 +0200 Subject: [PATCH 06/12] Interface new sparsification --- src/moscot/backends/ott/output.py | 2 +- src/moscot/base/output.py | 54 +++++++++++++++---------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 745ba4d95..894dfeb33 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -169,7 +169,7 @@ def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102 return OTTOutput(jax.device_put(self._output, device)) def subset( # noqa: D102 - self, src_ixs: Optional[ArrayLike] = None, tgt_ixs: Optional[ArrayLike] = None + self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None ) -> "OTTOutput": from ott.geometry import pointcloud from ott.problems.linear import linear_problem diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index c64e16124..55dad0bf0 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from copy import copy from functools import partial -from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple +from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Union from joblib import Parallel, delayed from tqdm import tqdm @@ -88,7 +88,9 @@ def _ones(self, n: int) -> ArrayLike: pass @abstractmethod - def subset(self, src_ixs: Optional[ArrayLike] = None, tgt_ixs: Optional[ArrayLike] = None) -> "BaseSolverOutput": + def subset( # noqa: D102 + self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None + ) -> "BaseSolverOutput": """TODO.""" def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike: @@ -184,7 +186,7 @@ def sparsify( batch_size: int = 1024, n_samples: Optional[int] = None, seed: Optional[int] = None, - n_jobs: int = -1, + n_jobs: int = 1, ) -> "MatrixSolverOutput": """Sparsify the :attr:`transport_matrix`. @@ -227,21 +229,20 @@ def sparsify( ------- Solve output with a sparsified transport matrix. """ - n, m = self.shape def _min_row(batch: int) -> float: - x = np.eye(m, min(batch_size, m - batch), -(min(batch, m))) - res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors + res = self.subset(slice(batch, batch + batch_size)).transport_matrix return float(res.max(axis=1).min()) - def _min_row_with_thr( - batch: int, threshold: float, k: int, func: Callable[[ArrayLike, bool], ArrayLike] - ) -> Dict[int, sp.csr_matrix]: - x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) - res = np.array(func(x, scale_by_marginals=False)) # type: ignore[call-arg] + def _sparsify(batch: int, threshold: float) -> sp.csr_matrix: + res = self.subset(slice(batch, batch + batch_size)).transport_matrix + if not isinstance(res, np.ndarray): + res = np.array(res) + res[res < threshold] = 0.0 - return sp.csr_matrix(res.T if n < m else res) + return sp.csr_matrix(res) + n, m = self.shape if mode == "threshold": if value is None: raise ValueError("If `mode = 'threshold'`, `threshold` cannot be `None`.") @@ -249,7 +250,7 @@ def _min_row_with_thr( elif mode == "percentile": if value is None: raise ValueError("If `mode = 'percentile'`, `threshold` cannot be `None`.") - rng = np.random.RandomState(seed=seed) + rng = np.random.default_rng(seed=seed) n_samples = n_samples if n_samples is not None else batch_size k = min(n_samples, n) x = np.zeros((m, k)) @@ -265,26 +266,19 @@ def _min_row_with_thr( # thr = min(thr, float(res.max(axis=1).min())) results = Parallel(n_jobs=n_jobs, verbose=3)( - delayed(_min_row)(batch) for batch in tqdm(range(0, m, batch_size)) + delayed(_min_row)(batch) for batch in tqdm(range(0, n, batch_size)) ) thr = np.min(results) - else: raise NotImplementedError(mode) - k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack) - # tmaps_sparse: List[sp.csr_matrix] = [] - # for batch in tqdm(range(0, k, batch_size)): - # x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) - # res = np.array(func(x, scale_by_marginals=False)) - # res[res < thr] = 0.0 - # tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res)) - + n, _ = self.shape results = Parallel(n_jobs=n_jobs, verbose=3)( - delayed(_min_row_with_thr)(batch, thr, k, func) for batch in tqdm(range(0, k, batch_size)) + delayed(_sparsify)(batch, thr) for batch in tqdm(range(0, n, batch_size)) ) + return MatrixSolverOutput( - transport_matrix=fn_stack(results), cost=self.cost, converged=self.converged, is_linear=self.is_linear + transport_matrix=sp.vstack(results), cost=self.cost, converged=self.converged, is_linear=self.is_linear ) @property @@ -380,9 +374,15 @@ def to( # noqa: D102 return obj def subset( # noqa: D102 - self, src_ixs: Optional[ArrayLike] = None, tgt_ixs: Optional[ArrayLike] = None + self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None ) -> "BaseSolverOutput": - raise NotImplementedError("Not yet implemented.") + mat = self.transport_matrix + if src_ixs is not None: + mat = mat[src_ixs] + if tgt_ixs is not None: + mat = mat[:, tgt_ixs] + + return type(self)(mat, cost=self.cost, converged=self.converged, is_linear=self._is_linear) @property def cost(self) -> float: # noqa: D102 From 6d1ac469276bba1b7bcc1c1a8c63d46bbdd02abc Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:38:23 +0200 Subject: [PATCH 07/12] Update computation for `percentile` --- src/moscot/base/output.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 55dad0bf0..34935c373 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -252,19 +252,9 @@ def _sparsify(batch: int, threshold: float) -> sp.csr_matrix: raise ValueError("If `mode = 'percentile'`, `threshold` cannot be `None`.") rng = np.random.default_rng(seed=seed) n_samples = n_samples if n_samples is not None else batch_size - k = min(n_samples, n) - x = np.zeros((m, k)) - rows = rng.choice(m, size=k) - x[rows, np.arange(k)] = 1.0 - res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors - thr = np.percentile(res, value) + ixs = rng.choice(np.arange(n), size=n_samples, replace=False) + thr = np.percentile(self.subset(ixs).transport_matrix, value) elif mode == "min_row": - # thr = np.inf - # for batch in tqdm(range(0, m, batch_size)): - # x = np.eye(m, min(batch_size, m - batch), -(min(batch, m))) - # res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors - # thr = min(thr, float(res.max(axis=1).min())) - results = Parallel(n_jobs=n_jobs, verbose=3)( delayed(_min_row)(batch) for batch in tqdm(range(0, n, batch_size)) ) From c4006be453a5fb797542c9ebf25d2233a14d31f4 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:39:53 +0200 Subject: [PATCH 08/12] Update docs --- src/moscot/base/output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 34935c373..e1f69431d 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -223,7 +223,7 @@ def sparsify( seed Random seed needed for sampling if ``mode = 'percentile'``. n_jobs - TODO + Number of concurrent jobs to use. If :math:`-1`, use all cores. Returns ------- From 02447e459d010791d21f6d7d7a0bc3a34e980cf9 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:41:16 +0200 Subject: [PATCH 09/12] Update docs v2 --- src/moscot/base/output.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index e1f69431d..7a7e138a9 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -91,7 +91,19 @@ def _ones(self, n: int) -> ArrayLike: def subset( # noqa: D102 self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None ) -> "BaseSolverOutput": - """TODO.""" + """Subset the transport matrix. + + Parameters + ---------- + src_ixs + Source indices. If :obj:`None`, don't subset the rows. + tgt_ixs + Target indices. If :obj:`None`, don't subset the columns. + + Returns + ------- + The subset of a transport matrix. + """ def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike: """Push mass through the :attr:`transport_matrix`. From a18d2d764d6c85d14b60ded40735d580cac9b27c Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:52:02 +0200 Subject: [PATCH 10/12] Fix `OTTOutput` sparsification, add logger --- src/moscot/backends/ott/output.py | 2 ++ src/moscot/base/output.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 894dfeb33..1a15404a4 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -192,6 +192,8 @@ def subset( # noqa: D102 g, y, g = g[tgt_ixs], y[tgt_ixs], b[tgt_ixs] _ = aux_data.pop("batch_size", None) + # TODO(michalk8): remove this in the new ott-jax release + _ = aux_data.pop("epsilon", None) geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn, **aux_data) prob = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=prob.tau_a, tau_b=prob.tau_b) diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 7a7e138a9..48964f5ee 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Union from joblib import Parallel, delayed -from tqdm import tqdm +from tqdm.auto import trange import numpy as np import scipy.sparse as sp @@ -199,6 +199,7 @@ def sparsify( n_samples: Optional[int] = None, seed: Optional[int] = None, n_jobs: int = 1, + **kwargs: Any, ) -> "MatrixSolverOutput": """Sparsify the :attr:`transport_matrix`. @@ -236,6 +237,8 @@ def sparsify( Random seed needed for sampling if ``mode = 'percentile'``. n_jobs Number of concurrent jobs to use. If :math:`-1`, use all cores. + kwargs + Keyword arguments for :class:`~joblib.Parallel`. Returns ------- @@ -267,16 +270,17 @@ def _sparsify(batch: int, threshold: float) -> sp.csr_matrix: ixs = rng.choice(np.arange(n), size=n_samples, replace=False) thr = np.percentile(self.subset(ixs).transport_matrix, value) elif mode == "min_row": - results = Parallel(n_jobs=n_jobs, verbose=3)( - delayed(_min_row)(batch) for batch in tqdm(range(0, n, batch_size)) + logger.info("Computing threshold for `mode='min_row'`") + results = Parallel(n_jobs=n_jobs, **kwargs)( + delayed(_min_row)(batch) for batch in trange(0, n, batch_size, unit="batch") ) thr = np.min(results) else: raise NotImplementedError(mode) - n, _ = self.shape - results = Parallel(n_jobs=n_jobs, verbose=3)( - delayed(_sparsify)(batch, thr) for batch in tqdm(range(0, n, batch_size)) + logger.info(f"Using `threshold={thr:.6}` for sparsification") + results = Parallel(n_jobs=n_jobs, **kwargs)( + delayed(_sparsify)(batch, thr) for batch in trange(0, n, batch_size, unit="batch") ) return MatrixSolverOutput( @@ -380,7 +384,7 @@ def subset( # noqa: D102 ) -> "BaseSolverOutput": mat = self.transport_matrix if src_ixs is not None: - mat = mat[src_ixs] + mat = mat[src_ixs, :] if tgt_ixs is not None: mat = mat[:, tgt_ixs] From d20a7af7b680ec36979ff8e115a27f771d72339a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:53:14 +0200 Subject: [PATCH 11/12] More apt test class name --- tests/solvers/test_base_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/solvers/test_base_solver.py b/tests/solvers/test_base_solver.py index 536d51863..ef298da2d 100644 --- a/tests/solvers/test_base_solver.py +++ b/tests/solvers/test_base_solver.py @@ -11,7 +11,7 @@ from tests._utils import ATOL, RTOL, MockSolverOutput -class TestBaseSolverOutput: +class TestSparsification: @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("threshold", [0.0, 1e-1, 1.0]) @pytest.mark.parametrize("shape", [(7, 2), (91, 103)]) From 0e7b5b7769a3f341d659c79858a3ccb8103cd562 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 8 Jun 2023 12:06:11 +0200 Subject: [PATCH 12/12] Add joblib to intersphinx --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index 5959a1ab5..a89cd10d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -54,6 +54,7 @@ "anndata": ("https://anndata.readthedocs.io/en/latest/", None), "scanpy": ("https://scanpy.readthedocs.io/en/latest/", None), "squidpy": ("https://squidpy.readthedocs.io/en/latest/", None), + "joblib": ("https://joblib.readthedocs.io/en/latest/", None), } master_doc = "index" pygments_style = "tango"