Skip to content

Commit

Permalink
address montonic and prediction size issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Rogers committed Sep 17, 2024
1 parent 6f0e2e6 commit c0cd17e
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,29 +918,38 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a
"""
if patch_x2_index[0] == data_x2_index[0]:
b_x2_min = 0
# This line, as well as 933, 940 and 950 address the different shapes between the input and prediction
# TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square.
b_x2_max = b_x2_max + 1
elif patch_x2_index[1] == data_x2_index[1]:
b_x2_max = 0
patch_row_prev = preds[i-1]
if x2_ascend:
prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[orig_x2_name].max()), x1 = False)
prev_patch_x2_max = get_index(patch_row_prev[var_name].coords[orig_x2_name].max(), x1 = False)
b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1]
else:
prev_patch_x2_min = get_index(int(patch_row_prev[var_name].coords[orig_x2_name].min()), x1 = False)
prev_patch_x2_min = get_index(patch_row_prev[var_name].coords[orig_x2_name].min(), x1 = False)
b_x2_min = (patch_x2_index[0] -prev_patch_x2_min)-patch_overlap[1]

else:
b_x2_max = b_x2_max + 1


if patch_x1_index[0] == data_x1_index[0]:
b_x1_min = 0
# TODO: ensure this elif statement is robust to multiple patch sizes.
elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
b_x1_max = 0
b_x1_max = b_x1_max + 1
patch_prev = preds[i-patches_per_row]
if x1_ascend:
prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[orig_x1_name].max()), x1 = True)
prev_patch_x1_max = get_index(patch_prev[var_name].coords[orig_x1_name].max(), x1 = True)
b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0]
else:
prev_patch_x1_min = get_index(int(patch_prev[var_name].coords[orig_x1_name].min()), x1 = True)
prev_patch_x1_min = get_index(patch_prev[var_name].coords[orig_x1_name].min(), x1 = True)

b_x1_min = (prev_patch_x1_min- patch_x1_index[0])- patch_overlap[0]

else:
b_x1_max = b_x1_max + 1

patch_clip_x1_min = int(b_x1_min)
patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max)
Expand Down

0 comments on commit c0cd17e

Please sign in to comment.