From 5b9a9434a4331680a653cda1405b3f67b3e02470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hande=20G=C3=B6z=C3=BCkan?= <2099645+hndgzkn@users.noreply.github.com> Date: Tue, 27 Sep 2022 17:49:02 +0200 Subject: [PATCH] Fix max iter (#54) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix local n_seg * fix max_iter count according to alphacsc implementation * fix linting * Update dicodile/workers/dicod_worker.py Co-authored-by: Thomas Moreau * fix segmentation * fix number of segments * fix seg size * revert pdb line * revert accumulator change * consider remainder while calculating number of segments * revert local_seg_support * CLN more compact n_seg computet * Update dicodile/utils/segmentation.py * Update dicodile/utils/segmentation.py Co-authored-by: Hande Gözükan <2099645+hndgzkn@users.noreply.github.com> * adds a test to test the number of segments Co-authored-by: Thomas Moreau --- dicodile/tests/test_config.py | 8 +++---- dicodile/utils/segmentation.py | 3 ++- dicodile/utils/tests/test_segmentation.py | 26 ++++++++++++++++++++--- dicodile/workers/dicod_worker.py | 5 +++-- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/dicodile/tests/test_config.py b/dicodile/tests/test_config.py index 550f251d..6ee3e5c3 100644 --- a/dicodile/tests/test_config.py +++ b/dicodile/tests/test_config.py @@ -7,7 +7,7 @@ def test_dicodile_home(monkeypatch): _set_env(monkeypatch, { "DICODILE_DATA_HOME": "/home/unittest/dicodile" }) - assert(get_data_home() == Path("/home/unittest/dicodile/dicodile")) + assert get_data_home() == Path("/home/unittest/dicodile/dicodile") def test_XDG_DATA_home(monkeypatch): @@ -15,7 +15,7 @@ def test_XDG_DATA_home(monkeypatch): "DICODILE_DATA_HOME": None, "XDG_DATA_HOME": "/home/unittest/data" }) - assert(get_data_home() == Path("/home/unittest/data/dicodile")) + assert get_data_home() == Path("/home/unittest/data/dicodile") def test_default_home(monkeypatch): @@ -24,7 +24,7 @@ def test_default_home(monkeypatch): "DICODILE_DATA_HOME": None, "XDG_DATA_HOME": None, }) - assert(get_data_home() == Path("/home/default/data/dicodile")) + assert get_data_home() == Path("/home/default/data/dicodile") def test_dicodile_home_has_priority_over_xdg_data_home(monkeypatch): @@ -32,7 +32,7 @@ def test_dicodile_home_has_priority_over_xdg_data_home(monkeypatch): "DICODILE_DATA_HOME": "/home/unittest/dicodile", "XDG_DATA_HOME": "/home/unittest/data" }) - assert(get_data_home() == Path("/home/unittest/dicodile/dicodile")) + assert get_data_home() == Path("/home/unittest/dicodile/dicodile") def _set_env(monkeypatch, d): diff --git a/dicodile/utils/segmentation.py b/dicodile/utils/segmentation.py index 7c3ced7c..239ad7bb 100644 --- a/dicodile/utils/segmentation.py +++ b/dicodile/utils/segmentation.py @@ -85,7 +85,8 @@ def compute_n_seg(self): self.n_seg_per_axis = [] for size_ax, size_seg_ax in zip(self.signal_support, self.seg_support): # Make sure that n_seg_ax is of type int (and not np.int*) - n_seg_ax = max(1, int(size_ax // size_seg_ax)) + n_seg_ax = max(1, int(size_ax // size_seg_ax) + + ((size_ax % size_seg_ax) != 0)) self.n_seg_per_axis.append(n_seg_ax) self.effective_n_seg *= n_seg_ax diff --git a/dicodile/utils/tests/test_segmentation.py b/dicodile/utils/tests/test_segmentation.py index 74f5b361..be23ffb1 100644 --- a/dicodile/utils/tests/test_segmentation.py +++ b/dicodile/utils/tests/test_segmentation.py @@ -187,9 +187,9 @@ def test_inner_coordinate(): if w_rank == n_seg[1] - 1: assert segments.is_contained_coordinate( - i_seg, - (seg_support[0] - overlap[0] - 1, seg_support[1] - 1), - inner=True) + i_seg, + (seg_support[0] - overlap[0] - 1, seg_support[1] - 1), + inner=True) else: assert not segments.is_contained_coordinate( i_seg, (seg_support[0] - overlap[0] - 1, @@ -267,3 +267,23 @@ def test_padding_to_overlap(): overlap = seg.get_padding_to_overlap(i_seg) z = np.pad(z, overlap, mode='constant') assert z.shape == seg_support_all + + +def test_segments(): + """Tests if the number of segments is computed correctly.""" + seg_support = [9] + inner_bounds = [[0, 252]] + full_support = (252,) + + seg = Segmentation(n_seg=None, seg_support=seg_support, + inner_bounds=inner_bounds, full_support=full_support) + seg.compute_n_seg() + + assert seg.effective_n_seg == 28 + + seg_support = [10] + seg = Segmentation(n_seg=None, seg_support=seg_support, + inner_bounds=inner_bounds, full_support=full_support) + seg.compute_n_seg() + + assert seg.effective_n_seg == 26 diff --git a/dicodile/workers/dicod_worker.py b/dicodile/workers/dicod_worker.py index 8509cbf6..d9b67b3e 100644 --- a/dicodile/workers/dicod_worker.py +++ b/dicodile/workers/dicod_worker.py @@ -90,6 +90,7 @@ def compute_z_hat(self): deadline = t_start + self.timeout else: deadline = None + for ii in range(self.max_iter): # Display the progress of the algorithm self.progress(ii, max_ii=self.max_iter, unit="iterations", @@ -110,7 +111,6 @@ def compute_z_hat(self): t_run += selection_duration else: k0, pt0, dz = None, None, 0 - # update the accumulator for 'random' strategy accumulator = max(abs(dz), accumulator) @@ -249,7 +249,7 @@ def init_cd_variables(self): constants['DtD'], n_atoms, atom_support - ) + ) self.constants = constants # List of all pending messages sent @@ -764,6 +764,7 @@ def recv_signal(self): inner_bounds=inner_bounds, full_support=worker_support) + self.max_iter *= self.local_segments.effective_n_seg self.synchronize_workers(with_main=True) return X_worker, z0