Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tqdm in sparsification #550

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"anndata": ("https://anndata.readthedocs.io/en/latest/", None),
"scanpy": ("https://scanpy.readthedocs.io/en/latest/", None),
"squidpy": ("https://squidpy.readthedocs.io/en/latest/", None),
"joblib": ("https://joblib.readthedocs.io/en/latest/", None),
}
master_doc = "index"
pygments_style = "tango"
Expand Down
63 changes: 49 additions & 14 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,6 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
return self._output.apply(x, axis=1 - forward)
return self._output.apply(x.T, axis=1 - forward).T # convert to batch first

@property
def shape(self) -> Tuple[int, int]:  # noqa: D102
    """Return the two-dimensional shape of this output.

    For an ``OTTSinkhornOutput`` the shape is built from the first-axis lengths of the potentials
    ``f`` and ``g``; for any other output type the shape of its geometry is returned.
    """
    out = self._output
    if not isinstance(out, OTTSinkhornOutput):
        # Non-Sinkhorn outputs expose the shape through their geometry.
        return out.geom.shape
    return out.f.shape[0], out.g.shape[0]

@property
def transport_matrix(self) -> ArrayLike:  # noqa: D102
    """Return the ``matrix`` attribute of the wrapped solver output."""
    wrapped = self._output
    return wrapped.matrix

@property
def is_linear(self) -> bool:  # noqa: D102
    """Whether the wrapped output is an ``OTTSinkhornOutput`` or an ``OTTLRSinkhornOutput``."""
    linear_types = (OTTSinkhornOutput, OTTLRSinkhornOutput)
    return isinstance(self._output, linear_types)

def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102
if isinstance(device, str) and ":" in device:
device, ix = device.split(":")
Expand All @@ -182,6 +168,55 @@ def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102

return OTTOutput(jax.device_put(self._output, device))

def subset(  # noqa: D102
    self, src_ixs: Optional[Union[slice, ArrayLike]] = None, tgt_ixs: Optional[Union[slice, ArrayLike]] = None
) -> "OTTOutput":
    """Restrict a linear, full-rank point-cloud solution to a subset of points.

    The potentials, the point-cloud coordinates and the marginals are indexed consistently and a
    new Sinkhorn output is assembled from them; no solver is re-run.

    Parameters
    ----------
    src_ixs
        Source indices. If :obj:`None`, the source side is left untouched.
    tgt_ixs
        Target indices. If :obj:`None`, the target side is left untouched.

    Returns
    -------
    A new output of the same type wrapping the subset potentials and problem.

    Raises
    ------
    AssertionError
        If the output is not linear, is low-rank, or its geometry is not a point cloud.
    """
    from ott.geometry import pointcloud
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    assert self.is_linear, "Quadratic is not yet available."
    assert not self.is_low_rank, "Low-rank is not yet available."

    f, g = self.potentials  # type: ignore[misc]
    prob = self._output.ot_prob
    assert isinstance(prob.geom, pointcloud.PointCloud), "Only available for point clouds."

    (x, y, *_, cost_fn), aux_data = prob.geom.tree_flatten()
    a, b = prob.a, prob.b
    eps = prob.epsilon

    if src_ixs is not None:
        f, x, a = f[src_ixs], x[src_ixs], a[src_ixs]
    if tgt_ixs is not None:
        # The third target must be `b`: assigning it to `g` would replace the target potential
        # with the marginal and leave `b` at its full length, inconsistent with the subset `y`.
        g, y, b = g[tgt_ixs], y[tgt_ixs], b[tgt_ixs]

    # These keys are dropped so they are not passed twice / unintentionally to the constructor.
    aux_data.pop("batch_size", None)
    # TODO(michalk8): remove this in the new ott-jax release
    aux_data.pop("epsilon", None)
    geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn, **aux_data)

    prob = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=prob.tau_a, tau_b=prob.tau_b)
    # `errors` and `reg_ot_cost` are carried over unchanged from the original output.
    out = sinkhorn.SinkhornOutput(
        f=f, g=g, errors=self._output.errors, reg_ot_cost=self._output.reg_ot_cost, ot_prob=prob
    )

    return type(self)(out)

@property
def shape(self) -> Tuple[int, int]:  # noqa: D102
    """Return the two-dimensional shape of this output.

    For an ``OTTSinkhornOutput`` the shape is built from the first-axis lengths of the potentials
    ``f`` and ``g``; for any other output type the shape of its geometry is returned.
    """
    out = self._output
    if not isinstance(out, OTTSinkhornOutput):
        # Non-Sinkhorn outputs expose the shape through their geometry.
        return out.geom.shape
    return out.f.shape[0], out.g.shape[0]

@property
def transport_matrix(self) -> ArrayLike:  # noqa: D102
    """Return the ``matrix`` attribute of the wrapped solver output."""
    wrapped = self._output
    return wrapped.matrix

@property
def is_linear(self) -> bool:  # noqa: D102
    """Whether the wrapped output is an ``OTTSinkhornOutput`` or an ``OTTLRSinkhornOutput``."""
    linear_types = (OTTSinkhornOutput, OTTLRSinkhornOutput)
    return isinstance(self._output, linear_types)

@property
def cost(self) -> float: # noqa: D102
if isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)):
Expand Down
87 changes: 66 additions & 21 deletions src/moscot/base/output.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from copy import copy
from functools import partial
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple
from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Union

from joblib import Parallel, delayed
from tqdm.auto import trange

import numpy as np
import scipy.sparse as sp
Expand Down Expand Up @@ -84,6 +87,24 @@ def is_low_rank(self) -> bool:
def _ones(self, n: int) -> ArrayLike:
pass

@abstractmethod
def subset(  # noqa: D102
    self,
    src_ixs: Optional[Union[slice, ArrayLike]] = None,
    tgt_ixs: Optional[Union[slice, ArrayLike]] = None,
) -> "BaseSolverOutput":
    """Restrict the transport matrix to selected rows and/or columns.

    Parameters
    ----------
    src_ixs
        Source indices selecting the rows to keep. If :obj:`None`, the rows are not subset.
    tgt_ixs
        Target indices selecting the columns to keep. If :obj:`None`, the columns are not subset.

    Returns
    -------
    The output restricted to the requested subset of the transport matrix.
    """

def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
"""Push mass through the :attr:`transport_matrix`.

Expand Down Expand Up @@ -177,6 +198,8 @@ def sparsify(
batch_size: int = 1024,
n_samples: Optional[int] = None,
seed: Optional[int] = None,
n_jobs: int = 1,
**kwargs: Any,
) -> "MatrixSolverOutput":
"""Sparsify the :attr:`transport_matrix`.

Expand Down Expand Up @@ -212,11 +235,28 @@ def sparsify(
``batch_size``.
seed
Random seed needed for sampling if ``mode = 'percentile'``.
n_jobs
Number of concurrent jobs to use. If :math:`-1`, use all cores.
kwargs
Keyword arguments for :class:`~joblib.Parallel`.

Returns
-------
Solve output with a sparsified transport matrix.
"""

def _min_row(batch: int) -> float:
    """Return the smallest row-wise maximum among the rows ``batch : batch + batch_size``."""
    rows = slice(batch, batch + batch_size)
    block = self.subset(rows).transport_matrix
    row_maxima = block.max(axis=1)
    return float(row_maxima.min())

def _sparsify(batch: int, threshold: float) -> sp.csr_matrix:
    """Zero out entries below ``threshold`` in one batch of rows and return it as CSR.

    Parameters
    ----------
    batch
        Index of the first row of the batch; rows ``batch : batch + batch_size`` are processed.
    threshold
        Entries strictly smaller than this value are set to :math:`0`.

    Returns
    -------
    The thresholded batch as a sparse CSR matrix.
    """
    res = self.subset(slice(batch, batch + batch_size)).transport_matrix
    if sp.issparse(res):
        # `np.array` does not densify SciPy sparse matrices, so convert explicitly.
        res = res.toarray()
    else:
        # Always copy: a subset obtained by slicing may be a view of the original transport
        # matrix, and thresholding in place would then silently modify the original.
        res = np.array(res)

    res[res < threshold] = 0.0
    return sp.csr_matrix(res)

n, m = self.shape
if mode == "threshold":
if value is None:
Expand All @@ -225,32 +265,26 @@ def sparsify(
elif mode == "percentile":
if value is None:
raise ValueError("If `mode = 'percentile'`, `threshold` cannot be `None`.")
rng = np.random.RandomState(seed=seed)
rng = np.random.default_rng(seed=seed)
n_samples = n_samples if n_samples is not None else batch_size
k = min(n_samples, n)
x = np.zeros((m, k))
rows = rng.choice(m, size=k)
x[rows, np.arange(k)] = 1.0
res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors
thr = np.percentile(res, value)
ixs = rng.choice(np.arange(n), size=n_samples, replace=False)
thr = np.percentile(self.subset(ixs).transport_matrix, value)
elif mode == "min_row":
thr = np.inf
for batch in range(0, m, batch_size):
x = np.eye(m, min(batch_size, m - batch), -(min(batch, m)))
res = self.pull(x, scale_by_marginals=False) # tmap @ indicator_vectors
thr = min(thr, float(res.max(axis=1).min()))
logger.info("Computing threshold for `mode='min_row'`")
results = Parallel(n_jobs=n_jobs, **kwargs)(
delayed(_min_row)(batch) for batch in trange(0, n, batch_size, unit="batch")
)
thr = np.min(results)
else:
raise NotImplementedError(mode)

k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack)
tmaps_sparse: List[sp.csr_matrix] = []
for batch in range(0, k, batch_size):
x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float)
res = np.array(func(x, scale_by_marginals=False))
res[res < thr] = 0.0
tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res))
logger.info(f"Using `threshold={thr:.6}` for sparsification")
results = Parallel(n_jobs=n_jobs, **kwargs)(
delayed(_sparsify)(batch, thr) for batch in trange(0, n, batch_size, unit="batch")
)

return MatrixSolverOutput(
transport_matrix=fn_stack(tmaps_sparse), cost=self.cost, converged=self.converged, is_linear=self.is_linear
transport_matrix=sp.vstack(results), cost=self.cost, converged=self.converged, is_linear=self.is_linear
)

@property
Expand Down Expand Up @@ -345,6 +379,17 @@ def to( # noqa: D102
obj._transport_matrix = obj.transport_matrix.astype(dtype)
return obj

def subset(  # noqa: D102
    self,
    src_ixs: Optional[Union[slice, ArrayLike]] = None,
    tgt_ixs: Optional[Union[slice, ArrayLike]] = None,
) -> "BaseSolverOutput":
    """Return a new output whose transport matrix keeps only the selected rows/columns.

    Parameters
    ----------
    src_ixs
        Row (source) indices. If :obj:`None`, the rows are not subset.
    tgt_ixs
        Column (target) indices. If :obj:`None`, the columns are not subset.

    Returns
    -------
    An instance of the same class holding the subset matrix; ``cost``, ``converged`` and the
    linearity flag are copied from this instance unchanged.
    """
    sub = self.transport_matrix
    # Rows and columns are indexed in two separate steps, one axis at a time.
    sub = sub if src_ixs is None else sub[src_ixs, :]
    sub = sub if tgt_ixs is None else sub[:, tgt_ixs]

    cls = type(self)
    return cls(sub, cost=self.cost, converged=self.converged, is_linear=self._is_linear)

@property
def cost(self) -> float: # noqa: D102
return self._cost
Expand Down
2 changes: 1 addition & 1 deletion tests/solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tests._utils import ATOL, RTOL, MockSolverOutput


class TestBaseSolverOutput:
class TestSparsification:
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("threshold", [0.0, 1e-1, 1.0])
@pytest.mark.parametrize("shape", [(7, 2), (91, 103)])
Expand Down