Skip to content

Commit

Permalink
ENH add distributed version of get_max_error_patch (#51)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com>
  • Loading branch information
rprimet and tomMoral authored Feb 23, 2022
1 parent 2588788 commit 1e4bc06
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 22 deletions.
6 changes: 3 additions & 3 deletions dicodile/_dicodile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@


from .update_d.update_d import update_d
from .utils.dictionary import prox_d
from .utils.dictionary import get_lambda_max
from .utils.dictionary import get_max_error_dict

from .update_z.distributed_sparse_encoder import DistributedSparseEncoder

Expand Down Expand Up @@ -185,8 +185,8 @@ def dicodile(X, D_init, reg=.1, n_iter=100, eps=1e-5, window=False,
null_atom_indices = np.where(z_nnz == 0)[0]
if len(null_atom_indices) > 0:
k0 = null_atom_indices[0]
z_hat = encoder.get_z_hat()
D_hat[k0] = get_max_error_dict(X, z_hat, D_hat, window=window)[0]
d0 = encoder.compute_and_get_max_error_patch(window=window)
D_hat[k0] = prox_d(d0)
if verbose > 1:
print('[INFO:{}] Resampled atom {}'.format(name, k0))

Expand Down
29 changes: 29 additions & 0 deletions dicodile/tests/test_dicodile.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import pytest

import numpy as np

from dicodile import dicodile
from dicodile.data.simulate import simulate_data

Expand All @@ -13,3 +17,28 @@ def test_dicodile():
X, D, reg=.1, z_positive=True, n_iter=10, eps=1e-4,
n_workers=1, verbose=2, tol=1e-10)
assert is_deacreasing(pobj)


@pytest.mark.parametrize("n_workers", [1])  # XXX [1,2,3]
def test_dicodile_greedy(n_workers):
    """Check that the objective decreases when atoms must be resampled.

    The signal is zero everywhere except for a short constant burst, and
    every atom but the first one is scaled down to almost nothing, so the
    run starts from what is effectively a single atom.
    """
    n_channels, n_atoms = 3, 2
    n_times_atom, n_times = 10, 100

    # Only the simulated dictionary is kept: the signal is rebuilt by hand
    # just below.
    _, D, _ = simulate_data(
        n_times=n_times, n_times_atom=n_times_atom, n_atoms=n_atoms,
        n_channels=n_channels, noise_level=1e-5, random_state=42
    )

    # Burst of 6 samples, with a different amplitude on each channel.
    channel_scale = np.array([1, 0.5, 0.25]).reshape(3, 1)
    X = np.zeros((n_channels, n_times))
    X[:, 45:51] = np.ones((n_channels, 6)) * channel_scale

    # Starts with a single random atom, expect to learn others
    # from the largest reconstruction error patch
    D[1:] *= 1e-6

    _, _, pobj, _ = dicodile(
        X, D, reg=.1, z_positive=True, n_iter=2, eps=1e-4,
        n_workers=n_workers, verbose=2, tol=1e-10
    )

    assert is_deacreasing(pobj)
5 changes: 5 additions & 0 deletions dicodile/update_z/dicod.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,5 +390,10 @@ def recv_cost(comm):
return cost[0]


def recv_max_error_patches(comm):
    """Gather the max error patch reported by each worker.

    Parameters
    ----------
    comm : MPI communicator
        Communicator on which the gather is performed, with this process
        acting as ``MPI.ROOT``.

    Returns
    -------
    max_error_patches : list
        Whatever ``comm.gather`` collects on the root side.
        NOTE(review): presumably one ``[patch, error]`` entry per worker,
        as indexed by the caller -- confirm against the worker code.
    """
    # This side contributes no payload (``None``); it only receives.
    return comm.gather(None, root=MPI.ROOT)


# Update the docstring
dicod.__doc__.format(STRATEGIES)
19 changes: 18 additions & 1 deletion dicodile/update_z/distributed_sparse_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from .dicod import recv_z_hat, recv_z_nnz
from .dicod import _gather_run_statistics
from .dicod import _send_task, _send_D, _send_signal
from .dicod import recv_cost, recv_sufficient_statistics
from .dicod import recv_cost, recv_max_error_patches,\
recv_sufficient_statistics


class DistributedSparseEncoder:
Expand Down Expand Up @@ -146,6 +147,22 @@ def get_sufficient_statistics(self):
verbose=self.verbose)
return recv_sufficient_statistics(self.workers.comm, self.D_shape)

def compute_and_get_max_error_patch(self, window=False):
    """Ask the workers for their max error patch and return the largest.

    Parameters
    ----------
    window : bool
        Forwarded to the workers in a broadcast ``{'window': window}``
        dictionary.

    Returns
    -------
    patch
        First element of the worker entry whose second element (the
        error value) is the largest.
    """
    comm = self.workers.comm

    # Trigger the computation on the workers, then broadcast the window
    # parameter they need for it.
    self.workers.send_command(
        constants.TAG_DICODILE_GET_MAX_ERROR_PATCH,
        verbose=self.verbose
    )
    comm.bcast({'window': window}, root=MPI.ROOT)

    # One entry per worker, indexed as (patch, error).
    per_worker = recv_max_error_patches(comm)

    # Keep the patch associated with the largest error.
    errors = [entry[1] for entry in per_worker]
    best_entry = per_worker[np.argmax(errors)]
    return best_entry[0]

def release_workers(self):
self.workers.send_command(
constants.TAG_DICODILE_STOP)
Expand Down
38 changes: 37 additions & 1 deletion dicodile/update_z/tests/test_distributed_sparse_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from dicodile.utils import check_random_state
from dicodile.utils.dictionary import compute_DtD, get_D
from dicodile.utils.dictionary import compute_DtD, get_D, get_max_error_patch
from dicodile.utils.csc import compute_objective
from dicodile.utils.csc import compute_ztX, compute_ztz

Expand Down Expand Up @@ -92,3 +92,39 @@ def test_pre_computed_DtD_should_always_be_passed_to_set_worker_D():

with pytest.raises(ValueError, match=r"pre-computed value DtD"):
encoder.set_worker_D(D)


@pytest.mark.parametrize("n_workers", [1, 2, 3])
def test_compute_max_error_patch(n_workers):
    """Compare the distributed max error patch with the reference one.

    The encoder computes ``z_hat`` on random data, then the patch returned
    by ``compute_and_get_max_error_patch`` is checked against the one
    returned by ``get_max_error_patch`` on the full signal.
    """
    rng = check_random_state(42)

    n_atoms = 2
    n_channels = 3
    n_times_atom = 10
    n_times = 10 * n_times_atom
    reg = 5e-1

    params = dict(tol=1e-2, n_seg='auto', timing=False, timeout=None,
                  verbose=100, strategy='greedy', max_iter=100000,
                  soft_lock='border', z_positive=True, return_ztz=False,
                  freeze_support=False, warm_start=False, random_state=27)

    X = rng.randn(n_channels, n_times)
    D = rng.randn(n_atoms, n_channels, n_times_atom)
    # Normalize each atom to unit norm.
    sum_axis = tuple(range(1, D.ndim))
    D /= np.sqrt(np.sum(D * D, axis=sum_axis, keepdims=True))

    encoder = DistributedSparseEncoder(n_workers=n_workers)

    encoder.init_workers(X, D, reg, params, DtD=None)

    # Shut the workers down even when a call or an assertion fails, so a
    # failing test does not leave worker processes behind.
    try:
        encoder.process_z_hat()
        z_hat = encoder.get_z_hat()

        max_error_patch = encoder.compute_and_get_max_error_patch()
        assert max_error_patch.shape == (n_channels, n_times_atom)

        reference_patch, _ = get_max_error_patch(X, z_hat, D)
        assert np.allclose(max_error_patch, reference_patch)
    finally:
        encoder.shutdown_workers()
2 changes: 1 addition & 1 deletion dicodile/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TAG_DICODILE_SET_SIGNAL = 23
TAG_DICODILE_SET_PARAMS = 24
TAG_DICODILE_SET_TASK = 25

TAG_DICODILE_GET_MAX_ERROR_PATCH = 26

# inter-process message size
SIZE_MSG = 4
Expand Down
26 changes: 19 additions & 7 deletions dicodile/utils/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .shape_helpers import get_valid_support


def get_max_error_dict(X, z, D, window=False):
def get_max_error_patch(X, z, D, window=False, local_segments=None):
"""Get the maximal reconstruction error patch from the data as a new atom
This idea is used for instance in [Yellin2017]
Expand All @@ -32,7 +32,9 @@ def get_max_error_dict(X, z, D, window=False):
IMAGING BY CONVOLUTIONAL SPARSE DICTIONARY LEARNING AND CODING.
"""
atom_support = D.shape[2:]
patch_rec_error = _patch_reconstruction_error(X, z, D, window=window)
patch_rec_error, X = _patch_reconstruction_error(
X, z, D, window=window, local_segments=local_segments
)
i0 = patch_rec_error.argmax()
pt0 = np.unravel_index(i0, patch_rec_error.shape)

Expand All @@ -41,9 +43,7 @@ def get_max_error_dict(X, z, D, window=False):
])
d0 = X[d0_slice]

d0 = prox_d(d0)

return d0
return d0, patch_rec_error[i0]


def prox_d(D):
Expand All @@ -53,13 +53,25 @@ def prox_d(D):
return D


def _patch_reconstruction_error(X, z, D, window=False):
def _patch_reconstruction_error(X, z, D, window=False, local_segments=None):
"""Return the reconstruction error for each patches of size (P, L)."""
n_trials, n_channels, *sig_support = X.shape
atom_support = D.shape[2:]

X_hat = reconstruct(z, D)

# When computing a distributed patch reconstruction error,
# we take the bounds into account.
# ``local_segments=None`` is used when computing the reconstruction
# error on the full signal.
if local_segments is not None:
X_slice = (Ellipsis,) + tuple([
slice(start, end + size_atom_ax - 1)
for (start, end), size_atom_ax in zip(
local_segments.inner_bounds, atom_support)
])
X, X_hat = X[X_slice], X_hat[X_slice]

diff = (X - X_hat)
diff *= diff

Expand All @@ -74,7 +86,7 @@ def _patch_reconstruction_error(X, z, D, window=False):
convolution_op = signal.convolve

return np.sum([convolution_op(patch, diff_p, mode='valid')
for diff_p in diff], axis=0)
for diff_p in diff], axis=0), X


def get_lambda_max(X, D_hat):
Expand Down
6 changes: 3 additions & 3 deletions dicodile/utils/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, n_seg=None, seg_support=None, signal_support=None,
for size_full_ax, (_, end) in zip(self.full_support,
self.inner_bounds)])

# compute the size of each segments and the number of segments
# compute the size of each segment and the number of segments
if seg_support is not None:
if isinstance(seg_support, int):
seg_support = [seg_support] * self.n_axis
Expand Down Expand Up @@ -84,7 +84,7 @@ def compute_n_seg(self):
self.effective_n_seg = 1
self.n_seg_per_axis = []
for size_ax, size_seg_ax in zip(self.signal_support, self.seg_support):
# Make sure that n_seg_ax is of type in (and not np.int*)
# Make sure that n_seg_ax is of type int (and not np.int*)
n_seg_ax = max(1, int(size_ax // size_seg_ax))
self.n_seg_per_axis.append(n_seg_ax)
self.effective_n_seg *= n_seg_ax
Expand All @@ -95,7 +95,7 @@ def compute_seg_support(self):
self.effective_n_seg = 1
self.seg_support = []
for size_ax, n_seg_ax in zip(self.signal_support, self.n_seg_per_axis):
# Make sure that n_seg_ax is of type in (and not np.int*)
# Make sure that n_seg_ax is of type int (and not np.int*)
size_seg_ax = size_ax // n_seg_ax
size_seg_ax += (size_ax % n_seg_ax >= n_seg_ax // 2)
self.seg_support.append(size_seg_ax)
Expand Down
31 changes: 25 additions & 6 deletions dicodile/workers/dicod_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from dicodile.utils.csc import compute_ztz, compute_ztX
from dicodile.utils.shape_helpers import get_full_support
from dicodile.utils.order_iterator import get_order_iterator
from dicodile.utils.dictionary import D_shape, compute_DtD,\
norm_atoms_from_DtD_reshaped
from dicodile.utils.dictionary import D_shape, compute_DtD
from dicodile.utils.dictionary import get_max_error_patch
from dicodile.utils.dictionary import norm_atoms_from_DtD_reshaped

from dicodile.update_z.coordinate_descent import _select_coordinate
from dicodile.update_z.coordinate_descent import _check_convergence
Expand Down Expand Up @@ -199,7 +200,7 @@ def compute_z_hat(self):
# else:
# time.sleep(.001)

# Check is we reach the timeout
# Check if we reach the timeout
if deadline is not None and time.time() >= deadline:
self.stop_before_convergence(
"Reached timeout", ii + 1, n_coordinate_updates
Expand Down Expand Up @@ -538,16 +539,18 @@ def compute_cost(self):
cost = .5 * np.dot(diff, diff)
return cost + self.reg * abs(self.z_hat[inner_slice]).sum()

def return_z_hat(self):
def _get_z_hat(self):
if flags.GET_OVERLAP_Z_HAT:
res_slice = (Ellipsis,)
else:
res_slice = (Ellipsis,) + tuple([
slice(start, end)
for start, end in self.local_segments.inner_bounds
])
z_worker = self.z_hat[res_slice].ravel()
self.return_array(z_worker)
return self.z_hat[res_slice].ravel()

def return_z_hat(self):
self.return_array(self._get_z_hat())

def return_z_nnz(self):
res_slice = (Ellipsis,) + tuple([
Expand Down Expand Up @@ -577,6 +580,22 @@ def return_run_statistics(self, ii, n_coordinate_updates, runtime,
t_select_coord, t_update_coord]
self.gather_array(arr)

def compute_and_return_max_error_patch(self):
    """Compute this worker's max error patch and send it back.

    The ``window`` parameter is first received through a broadcast on the
    parent communicator. The patch and its error value are then computed
    with ``get_max_error_patch``, restricted to this worker's local
    segments, and handed to ``self.gather_array`` as a
    ``[patch, error]`` pair.
    """
    # receive window param
    # cutting through abstractions here, refactor if needed
    assert self._backend == "mpi"
    comm = MPI.Comm.Get_parent()
    params = comm.bcast(None, root=0)
    assert 'window' in params

    max_error_patch, max_error = get_max_error_patch(
        self.X_worker, self.z_hat, self.D, window=params['window'],
        local_segments=self.local_segments
    )
    self.gather_array([max_error_patch, max_error])

###########################################################################
# Display utilities
###########################################################################
Expand Down
2 changes: 2 additions & 0 deletions dicodile/workers/dicodile_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ def dicodile_worker():
dicod_worker.recv_signal()
if tag == constants.TAG_DICODILE_SET_TASK:
dicod_worker.recv_task()
if tag == constants.TAG_DICODILE_GET_MAX_ERROR_PATCH:
dicod_worker.compute_and_return_max_error_patch()
tag = wait_message()

0 comments on commit 1e4bc06

Please sign in to comment.