diff --git a/docs/conf.py b/docs/conf.py index 5959a1ab5..a89cd10d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -54,6 +54,7 @@ "anndata": ("https://anndata.readthedocs.io/en/latest/", None), "scanpy": ("https://scanpy.readthedocs.io/en/latest/", None), "squidpy": ("https://squidpy.readthedocs.io/en/latest/", None), + "joblib": ("https://joblib.readthedocs.io/en/latest/", None), } master_doc = "index" pygments_style = "tango" diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 83166e240..1a15404a4 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -153,20 +153,6 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike: return self._output.apply(x, axis=1 - forward) return self._output.apply(x.T, axis=1 - forward).T # convert to batch first - @property - def shape(self) -> Tuple[int, int]: # noqa: D102 - if isinstance(self._output, OTTSinkhornOutput): - return self._output.f.shape[0], self._output.g.shape[0] - return self._output.geom.shape - - @property - def transport_matrix(self) -> ArrayLike: # noqa: D102 - return self._output.matrix - - @property - def is_linear(self) -> bool: # noqa: D102 - return isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)) - def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102 if isinstance(device, str) and ":" in device: device, ix = device.split(":") @@ -182,6 +168,55 @@ def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102 return OTTOutput(jax.device_put(self._output, device)) + def subset( # noqa: D102 + self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None + ) -> "OTTOutput": + from ott.geometry import pointcloud + from ott.problems.linear import linear_problem + from ott.solvers.linear import sinkhorn + + assert self.is_linear, "Quadratic is not yet available." + assert not self.is_low_rank, "Low-rank is not yet available." + + f, g = self.potentials # type: ignore[misc] + prob = self._output.ot_prob + assert isinstance(prob.geom, pointcloud.PointCloud), "Only available for point clouds." + + (x, y, *_, cost_fn), aux_data = prob.geom.tree_flatten() + a, b = prob.a, prob.b + eps = prob.epsilon + + if src_ixs is not None: + f, x, a = f[src_ixs], x[src_ixs], a[src_ixs] + if tgt_ixs is not None: + g, y, g = g[tgt_ixs], y[tgt_ixs], b[tgt_ixs] + + _ = aux_data.pop("batch_size", None) + # TODO(michalk8): remove this in the new ott-jax release + _ = aux_data.pop("epsilon", None) + geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn, **aux_data) + + prob = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=prob.tau_a, tau_b=prob.tau_b) + out = sinkhorn.SinkhornOutput( + f=f, g=g, errors=self._output.errors, reg_ot_cost=self._output.reg_ot_cost, ot_prob=prob + ) + + return type(self)(out) + + @property + def shape(self) -> Tuple[int, int]: # noqa: D102 + if isinstance(self._output, OTTSinkhornOutput): + return self._output.f.shape[0], self._output.g.shape[0] + return self._output.geom.shape + + @property + def transport_matrix(self) -> ArrayLike: # noqa: D102 + return self._output.matrix + + @property + def is_linear(self) -> bool: # noqa: D102 + return isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)) + @property def cost(self) -> float: # noqa: D102 if isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)): diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 9cf53bbaa..48964f5ee 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -1,7 +1,10 @@ from abc import ABC, abstractmethod from copy import copy from functools import partial -from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple +from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Union + +from joblib import Parallel, delayed +from tqdm.auto import trange import numpy as np import scipy.sparse as sp @@ -84,6 +87,24 @@ def is_low_rank(self) -> bool: def _ones(self, n: int) -> ArrayLike: pass + @abstractmethod + def subset( # noqa: D102 + self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None + ) -> "BaseSolverOutput": + """Subset the transport matrix. + + Parameters + ---------- + src_ixs + Source indices. If :obj:`None`, don't subset the rows. + tgt_ixs + Target indices. If :obj:`None`, don't subset the columns. + + Returns + ------- + The subset of a transport matrix. + """ + def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike: """Push mass through the :attr:`transport_matrix`. @@ -177,6 +198,8 @@ def sparsify( batch_size: int = 1024, n_samples: Optional[int] = None, seed: Optional[int] = None, + n_jobs: int = 1, + **kwargs: Any, ) -> "MatrixSolverOutput": """Sparsify the :attr:`transport_matrix`. @@ -212,11 +235,28 @@ def sparsify( ``batch_size``. seed Random seed needed for sampling if ``mode = 'percentile'``. + n_jobs + Number of concurrent jobs to use. If :math:`-1`, use all cores. + kwargs + Keyword arguments for :class:`~joblib.Parallel`. Returns ------- Solve output with a sparsified transport matrix. """ + + def _min_row(batch: int) -> float: + res = self.subset(slice(batch, batch + batch_size)).transport_matrix + return float(res.max(axis=1).min()) + + def _sparsify(batch: int, threshold: float) -> sp.csr_matrix: + res = self.subset(slice(batch, batch + batch_size)).transport_matrix + if not isinstance(res, np.ndarray): + res = np.array(res) + + res[res < threshold] = 0.0 + return sp.csr_matrix(res) + n, m = self.shape if mode == "threshold": if value is None: @@ -225,32 +265,26 @@ def sparsify( elif mode == "percentile": if value is None: raise ValueError("If `mode = 'percentile'`, `threshold` cannot be `None`.") - rng = np.random.RandomState(seed=seed) + rng = np.random.default_rng(seed=seed) n_samples = n_samples if n_samples is not None else batch_size - k = min(n_samples, n) - x = np.zeros((m, k)) - rows = rng.choice(m, size=k) - x[rows, np.arange(k)] = 1.0 - res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors - thr = np.percentile(res, value) + ixs = rng.choice(np.arange(n), size=n_samples, replace=False) + thr = np.percentile(self.subset(ixs).transport_matrix, value) elif mode == "min_row": - thr = np.inf - for batch in range(0, m, batch_size): - x = np.eye(m, min(batch_size, m - batch), -(min(batch, m))) - res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors - thr = min(thr, float(res.max(axis=1).min())) + logger.info("Computing threshold for `mode='min_row'`") + results = Parallel(n_jobs=n_jobs, **kwargs)( + delayed(_min_row)(batch) for batch in trange(0, n, batch_size, unit="batch") + ) + thr = np.min(results) else: raise NotImplementedError(mode) - k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack) - tmaps_sparse: List[sp.csr_matrix] = [] - for batch in range(0, k, batch_size): - x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) - res = np.array(func(x, scale_by_marginals=False)) - res[res < thr] = 0.0 - tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res)) + logger.info(f"Using `threshold={thr:.6}` for sparsification") + results = Parallel(n_jobs=n_jobs, **kwargs)( + delayed(_sparsify)(batch, thr) for batch in trange(0, n, batch_size, unit="batch") + ) + return MatrixSolverOutput( - transport_matrix=fn_stack(tmaps_sparse), cost=self.cost, converged=self.converged, is_linear=self.is_linear + transport_matrix=sp.vstack(results), cost=self.cost, converged=self.converged, is_linear=self.is_linear ) @property @@ -345,6 +379,17 @@ def to( # noqa: D102 obj._transport_matrix = obj.transport_matrix.astype(dtype) return obj + def subset( # noqa: D102 + self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None + ) -> "BaseSolverOutput": + mat = self.transport_matrix + if src_ixs is not None: + mat = mat[src_ixs, :] + if tgt_ixs is not None: + mat = mat[:, tgt_ixs] + + return type(self)(mat, cost=self.cost, converged=self.converged, is_linear=self._is_linear) + @property def cost(self) -> float: # noqa: D102 return self._cost diff --git a/tests/solvers/test_base_solver.py b/tests/solvers/test_base_solver.py index 536d51863..ef298da2d 100644 --- a/tests/solvers/test_base_solver.py +++ b/tests/solvers/test_base_solver.py @@ -11,7 +11,7 @@ from tests._utils import ATOL, RTOL, MockSolverOutput -class TestBaseSolverOutput: +class TestSparsification: @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("threshold", [0.0, 1e-1, 1.0]) @pytest.mark.parametrize("shape", [(7, 2), (91, 103)])