Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move stitching code to prediction module #21

Merged
merged 4 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 10 additions & 277 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Prediction,
increase_spatial_resolution,
infer_prediction_modality_from_X_t,
stitch_clipped_predictions,
)
from deepsensor.data.task import Task

Expand Down Expand Up @@ -672,7 +673,7 @@ def predict_patchwise(
List of tasks containing context data. Tasks for patchwise prediction must be generated by a task loader using the "sliding" patching strategy.
X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
Target locations to predict at. Can be an xarray object
containingon-grid locations or a pandas object containing off-grid locations.
containing on-grid locations or a pandas object containing off-grid locations.
X_t_mask: :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional
2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
to the same grid as ``X_t`` and patched in the same way. Default None (no mask).
Expand Down Expand Up @@ -810,281 +811,6 @@ def overlap_index(
),
)

def get_coordinate_extent(
ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool
) -> tuple:
"""Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions.

Parameters
----------
ds : Data object
The dataset or data array to determine coordinate extent for.

x1_ascend : bool
Whether the x1 coordinates ascend (increase) from top to bottom.

x2_ascend : bool
Whether the x2 coordinates ascend (increase) from left to right.

Returns:
-------
tuple of tuples:
Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)).
"""
if x1_ascend:
ds_x1_coords = (
ds.coords[orig_x1_name].min().values,
ds.coords[orig_x1_name].max().values,
)
else:
ds_x1_coords = (
ds.coords[orig_x1_name].max().values,
ds.coords[orig_x1_name].min().values,
)
if x2_ascend:
ds_x2_coords = (
ds.coords[orig_x2_name].min().values,
ds.coords[orig_x2_name].max().values,
)
else:
ds_x2_coords = (
ds.coords[orig_x2_name].max().values,
ds.coords[orig_x2_name].min().values,
)
return ds_x1_coords, ds_x2_coords

def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]:
"""Convert coordinates into pixel row/column (index).

Parameters
----------
args : tuple
If one argument (numeric), it represents the coordinate value.
If two arguments (lists), they represent lists of coordinate values.

x1 : bool, optional
If True, compute index for x1 (default is True).

Returns:
-------
Union[int, Tuple[List[int], List[int]]]
If one argument is provided and x1 is True or False, returns the index position.
If two arguments are provided, returns a tuple containing two lists:
- First list: indices corresponding to x1 coordinates.
- Second list: indices corresponding to x2 coordinates.

"""
if len(args) == 1:
patch_coord = args
if x1:
coord_index = np.argmin(
np.abs(X_t.coords[orig_x1_name].values - patch_coord)
)
else:
coord_index = np.argmin(
np.abs(X_t.coords[orig_x2_name].values - patch_coord)
)
return coord_index

elif len(args) == 2:
patch_x1, patch_x2 = args
x1_index = [
np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1))
for target_x1 in patch_x1
]
x2_index = [
np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2))
for target_x2 in patch_x2
]
return (x1_index, x2_index)

def stitch_clipped_predictions(
patch_preds: List[Prediction],
patch_overlap: int,
patches_per_row: int,
x1_ascend: bool = True,
x2_ascend: bool = True,
) -> Prediction:
"""Stitch patchwise predictions to form prediction at original extent.

Parameters
----------
patch_preds : list (class:`~.model.pred.Prediction`)
List of patchwise predictions

patch_overlap: int
Overlap between adjacent patches in pixels.

patches_per_row: int
Number of patchwise predictions in each row.

x1_ascend : bool
Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True.

x2_ascend : bool
Boolean defining whether the x2 coords ascend (increase) from left to right, default = True.

Returns:
-------
combined: dict
Dictionary object containing the stitched model predictions.
"""
# Get row/col index values of X_t.
data_x1_coords, data_x2_coords = get_coordinate_extent(
X_t, x1_ascend, x2_ascend
)
data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords)

# Iterate through patchwise predictions and slice edges prior to stitchin.
patches_clipped = []
for i, patch_pred in enumerate(patch_preds):
# get one variable name to use for coordinates and extent
first_key = list(patch_pred.keys())[0]
# Get row/col index values of each patch.
patch_x1_coords, patch_x2_coords = get_coordinate_extent(
patch_pred[first_key], x1_ascend, x2_ascend
)
patch_x1_index, patch_x2_index = get_index(
patch_x1_coords, patch_x2_coords
)

# Calculate size of border to slice of each edge of patchwise predictions.
# Initially set the size of all borders to the size of the overlap.
b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]

# Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
if patch_x2_index[0] == data_x2_index[0]:
b_x2_min = 0
b_x2_max = b_x2_max

# At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
elif patch_x2_index[1] == data_x2_index[1]:
b_x2_max = 0
patch_row_prev = patch_preds[i - 1]

# If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
if x2_ascend:
prev_patch_x2_max = get_index(
patch_row_prev[first_key].coords[orig_x2_name].max(),
x1=False,
)
b_x2_min = (
prev_patch_x2_max - patch_x2_index[0]
) - patch_overlap[1]

# If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
else:
prev_patch_x2_min = get_index(
patch_row_prev[first_key].coords[orig_x2_name].min(),
x1=False,
)
b_x2_min = (
patch_x2_index[0] - prev_patch_x2_min
) - patch_overlap[1]
else:
b_x2_max = b_x2_max

# Repeat process as above for x1 coordinates.
if patch_x1_index[0] == data_x1_index[0]:
b_x1_min = 0

elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
b_x1_max = 0
b_x1_max = b_x1_max
patch_prev = patch_preds[i - patches_per_row]
if x1_ascend:
prev_patch_x1_max = get_index(
patch_prev[first_key].coords[orig_x1_name].max(),
x1=True,
)
b_x1_min = (
prev_patch_x1_max - patch_x1_index[0]
) - patch_overlap[0]
else:
prev_patch_x1_min = get_index(
patch_prev[first_key].coords[orig_x1_name].min(),
x1=True,
)

b_x1_min = (
prev_patch_x1_min - patch_x1_index[0]
) - patch_overlap[0]
else:
b_x1_max = b_x1_max

patch_clip_x1_min = int(b_x1_min)
patch_clip_x1_max = int(
patch_pred[first_key].sizes[orig_x1_name] - b_x1_max
)
patch_clip_x2_min = int(b_x2_min)
patch_clip_x2_max = int(
patch_pred[first_key].sizes[orig_x2_name] - b_x2_max
)

# Define slicing parameters
slicing_params = {
orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max),
orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max),
}

# Slice patchwise predictions
patch_clip = {
key: dataset.isel(**slicing_params)
for key, dataset in patch_pred.items()
}

patches_clipped.append(patch_clip)

# Create blank prediction object to stitch prediction values onto.
stitched_prediction = copy.deepcopy(patch_preds[0])
# Set prediction object extent to the same as X_t.
for var_name, data_array in stitched_prediction.items():
blank_ds = xr.Dataset(
coords={
orig_x1_name: X_t[orig_x1_name],
orig_x2_name: X_t[orig_x2_name],
"time": stitched_prediction[0]["time"],
}
)

# Set data variable names e.g. mean, std to those in patched prediction. Make all values Nan.
for data_var in data_array.data_vars:
blank_ds[data_var] = data_array[data_var]
blank_ds[data_var][:] = np.nan
stitched_prediction[var_name] = blank_ds

# Restructure prediction objects for merging
restructured_patches = {
key: [item[key] for item in patches_clipped]
for key in patches_clipped[0].keys()
}

# Merge patchwise predictions to create final stiched prediction.
# Iterate over each variable (key) in the prediction dictionary
for var_name, patches in restructured_patches.items():
# Retrieve the blank dataset for the current variable
prediction_array = stitched_prediction[var_name]

# Merge each patch into the combined dataset
for patch in patches:
for var in patch.data_vars:
# Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset
reindexed_patch = patch[var].reindex_like(
prediction_array[var], method="nearest", tolerance=1e-6
)

# Combine data, prioritizing non-NaN values from patches
prediction_array[var] = prediction_array[var].where(
np.isnan(reindexed_patch), reindexed_patch
)

# Update the dictionary with the merged dataset
stitched_prediction[var_name] = prediction_array
return stitched_prediction

# load patch_size and stride from task
patch_size = tasks[0]["patch_size"]
stride = tasks[0]["stride"]
Expand Down Expand Up @@ -1179,7 +905,14 @@ def stitch_clipped_predictions(

patches_per_row = get_patches_per_row(preds)
prediction = stitch_clipped_predictions(
preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending
preds,
patch_overlap_unnorm,
patches_per_row,
X_t,
orig_x1_name,
orig_x2_name,
x1_ascending,
x2_ascending,
)

return prediction
Expand Down
Loading
Loading