Skip to content

Commit

Permalink
Fix max iter (#54)
Browse files Browse the repository at this point in the history
* fix local n_seg

* fix max_iter count according to alphacsc implementation

* fix linting

* Update dicodile/workers/dicod_worker.py

Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com>

* fix segmentation

* fix number of segments

* fix seg size

* revert pdb line

* revert accumulator change

* consider remainder while calculating number of segments

* revert local_seg_support

* CLN more compact n_seg computet

* Update dicodile/utils/segmentation.py

* Update dicodile/utils/segmentation.py

Co-authored-by: Hande Gözükan <2099645+hndgzkn@users.noreply.github.com>

* adds a test to test the number of segments

Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com>
  • Loading branch information
hndgzkn and tomMoral authored Sep 27, 2022
1 parent 1e4bc06 commit 5b9a943
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
8 changes: 4 additions & 4 deletions dicodile/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ def test_dicodile_home(monkeypatch):
_set_env(monkeypatch, {
"DICODILE_DATA_HOME": "/home/unittest/dicodile"
})
assert(get_data_home() == Path("/home/unittest/dicodile/dicodile"))
assert get_data_home() == Path("/home/unittest/dicodile/dicodile")


def test_XDG_DATA_home(monkeypatch):
_set_env(monkeypatch, {
"DICODILE_DATA_HOME": None,
"XDG_DATA_HOME": "/home/unittest/data"
})
assert(get_data_home() == Path("/home/unittest/data/dicodile"))
assert get_data_home() == Path("/home/unittest/data/dicodile")


def test_default_home(monkeypatch):
Expand All @@ -24,15 +24,15 @@ def test_default_home(monkeypatch):
"DICODILE_DATA_HOME": None,
"XDG_DATA_HOME": None,
})
assert(get_data_home() == Path("/home/default/data/dicodile"))
assert get_data_home() == Path("/home/default/data/dicodile")


def test_dicodile_home_has_priority_over_xdg_data_home(monkeypatch):
_set_env(monkeypatch, {
"DICODILE_DATA_HOME": "/home/unittest/dicodile",
"XDG_DATA_HOME": "/home/unittest/data"
})
assert(get_data_home() == Path("/home/unittest/dicodile/dicodile"))
assert get_data_home() == Path("/home/unittest/dicodile/dicodile")


def _set_env(monkeypatch, d):
Expand Down
3 changes: 2 additions & 1 deletion dicodile/utils/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def compute_n_seg(self):
self.n_seg_per_axis = []
for size_ax, size_seg_ax in zip(self.signal_support, self.seg_support):
# Make sure that n_seg_ax is of type int (and not np.int*)
n_seg_ax = max(1, int(size_ax // size_seg_ax))
n_seg_ax = max(1, int(size_ax // size_seg_ax) +
((size_ax % size_seg_ax) != 0))
self.n_seg_per_axis.append(n_seg_ax)
self.effective_n_seg *= n_seg_ax

Expand Down
26 changes: 23 additions & 3 deletions dicodile/utils/tests/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def test_inner_coordinate():

if w_rank == n_seg[1] - 1:
assert segments.is_contained_coordinate(
i_seg,
(seg_support[0] - overlap[0] - 1, seg_support[1] - 1),
inner=True)
i_seg,
(seg_support[0] - overlap[0] - 1, seg_support[1] - 1),
inner=True)
else:
assert not segments.is_contained_coordinate(
i_seg, (seg_support[0] - overlap[0] - 1,
Expand Down Expand Up @@ -267,3 +267,23 @@ def test_padding_to_overlap():
overlap = seg.get_padding_to_overlap(i_seg)
z = np.pad(z, overlap, mode='constant')
assert z.shape == seg_support_all


def test_segments():
"""Tests if the number of segments is computed correctly."""
seg_support = [9]
inner_bounds = [[0, 252]]
full_support = (252,)

seg = Segmentation(n_seg=None, seg_support=seg_support,
inner_bounds=inner_bounds, full_support=full_support)
seg.compute_n_seg()

assert seg.effective_n_seg == 28

seg_support = [10]
seg = Segmentation(n_seg=None, seg_support=seg_support,
inner_bounds=inner_bounds, full_support=full_support)
seg.compute_n_seg()

assert seg.effective_n_seg == 26
5 changes: 3 additions & 2 deletions dicodile/workers/dicod_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def compute_z_hat(self):
deadline = t_start + self.timeout
else:
deadline = None

for ii in range(self.max_iter):
# Display the progress of the algorithm
self.progress(ii, max_ii=self.max_iter, unit="iterations",
Expand All @@ -110,7 +111,6 @@ def compute_z_hat(self):
t_run += selection_duration
else:
k0, pt0, dz = None, None, 0

# update the accumulator for 'random' strategy
accumulator = max(abs(dz), accumulator)

Expand Down Expand Up @@ -249,7 +249,7 @@ def init_cd_variables(self):
constants['DtD'],
n_atoms,
atom_support
)
)
self.constants = constants

# List of all pending messages sent
Expand Down Expand Up @@ -764,6 +764,7 @@ def recv_signal(self):
inner_bounds=inner_bounds,
full_support=worker_support)

self.max_iter *= self.local_segments.effective_n_seg
self.synchronize_workers(with_main=True)

return X_worker, z0
Expand Down

0 comments on commit 5b9a943

Please sign in to comment.