From c0cd17e5eec37983e97dbe56020944d3e83a1f0f Mon Sep 17 00:00:00 2001 From: Martin Rogers Date: Tue, 17 Sep 2024 14:53:25 +0100 Subject: [PATCH] address montonic and prediction size issues --- deepsensor/model/model.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 566f21f3..a2ac84eb 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -918,29 +918,38 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a """ if patch_x2_index[0] == data_x2_index[0]: b_x2_min = 0 + # This line, as well as 933, 940 and 950 address the different shapes between the input and prediction + # TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square. + b_x2_max = b_x2_max + 1 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 patch_row_prev = preds[i-1] if x2_ascend: - prev_patch_x2_max = get_index(int(patch_row_prev[var_name].coords[orig_x2_name].max()), x1 = False) + prev_patch_x2_max = get_index(patch_row_prev[var_name].coords[orig_x2_name].max(), x1 = False) b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] else: - prev_patch_x2_min = get_index(int(patch_row_prev[var_name].coords[orig_x2_name].min()), x1 = False) + prev_patch_x2_min = get_index(patch_row_prev[var_name].coords[orig_x2_name].min(), x1 = False) b_x2_min = (patch_x2_index[0] -prev_patch_x2_min)-patch_overlap[1] - + else: + b_x2_max = b_x2_max + 1 + + if patch_x1_index[0] == data_x1_index[0]: b_x1_min = 0 + # TODO: ensure this elif statement is robust to multiple patch sizes. elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: b_x1_max = 0 + b_x1_max = b_x1_max + 1 patch_prev = preds[i-patches_per_row] if x1_ascend: - prev_patch_x1_max = get_index(int(patch_prev[var_name].coords[orig_x1_name].max()), x1 = True) + prev_patch_x1_max = get_index(patch_prev[var_name].coords[orig_x1_name].max(), x1 = True) b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] else: - prev_patch_x1_min = get_index(int(patch_prev[var_name].coords[orig_x1_name].min()), x1 = True) + prev_patch_x1_min = get_index(patch_prev[var_name].coords[orig_x1_name].min(), x1 = True) b_x1_min = (prev_patch_x1_min- patch_x1_index[0])- patch_overlap[0] - + else: + b_x1_max = b_x1_max + 1 patch_clip_x1_min = int(b_x1_min) patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max)