diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 07426968..1f393c27 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -9,6 +9,7 @@ Prediction, increase_spatial_resolution, infer_prediction_modality_from_X_t, + stitch_clipped_predictions, ) from deepsensor.data.task import Task @@ -672,7 +673,7 @@ def predict_patchwise( List of tasks containing context data. Tasks for patchwise prediction must be generated by a task loader using the "sliding" patching strategy. X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): Target locations to predict at. Can be an xarray object - containingon-grid locations or a pandas object containing off-grid locations. + containing on-grid locations or a pandas object containing off-grid locations. X_t_mask: :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional 2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated to the same grid as ``X_t`` and patched in the same way. Default None (no mask). @@ -810,281 +811,6 @@ def overlap_index( ), ) - def get_coordinate_extent( - ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool - ) -> tuple: - """Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions. - - Parameters - ---------- - ds : Data object - The dataset or data array to determine coordinate extent for. - - x1_ascend : bool - Whether the x1 coordinates ascend (increase) from top to bottom. - - x2_ascend : bool - Whether the x2 coordinates ascend (increase) from left to right. - - Returns: - ------- - tuple of tuples: - Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)). 
- """ - if x1_ascend: - ds_x1_coords = ( - ds.coords[orig_x1_name].min().values, - ds.coords[orig_x1_name].max().values, - ) - else: - ds_x1_coords = ( - ds.coords[orig_x1_name].max().values, - ds.coords[orig_x1_name].min().values, - ) - if x2_ascend: - ds_x2_coords = ( - ds.coords[orig_x2_name].min().values, - ds.coords[orig_x2_name].max().values, - ) - else: - ds_x2_coords = ( - ds.coords[orig_x2_name].max().values, - ds.coords[orig_x2_name].min().values, - ) - return ds_x1_coords, ds_x2_coords - - def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: - """Convert coordinates into pixel row/column (index). - - Parameters - ---------- - args : tuple - If one argument (numeric), it represents the coordinate value. - If two arguments (lists), they represent lists of coordinate values. - - x1 : bool, optional - If True, compute index for x1 (default is True). - - Returns: - ------- - Union[int, Tuple[List[int], List[int]]] - If one argument is provided and x1 is True or False, returns the index position. - If two arguments are provided, returns a tuple containing two lists: - - First list: indices corresponding to x1 coordinates. - - Second list: indices corresponding to x2 coordinates. 
- - """ - if len(args) == 1: - patch_coord = args - if x1: - coord_index = np.argmin( - np.abs(X_t.coords[orig_x1_name].values - patch_coord) - ) - else: - coord_index = np.argmin( - np.abs(X_t.coords[orig_x2_name].values - patch_coord) - ) - return coord_index - - elif len(args) == 2: - patch_x1, patch_x2 = args - x1_index = [ - np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) - for target_x1 in patch_x1 - ] - x2_index = [ - np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) - for target_x2 in patch_x2 - ] - return (x1_index, x2_index) - - def stitch_clipped_predictions( - patch_preds: List[Prediction], - patch_overlap: int, - patches_per_row: int, - x1_ascend: bool = True, - x2_ascend: bool = True, - ) -> Prediction: - """Stitch patchwise predictions to form prediction at original extent. - - Parameters - ---------- - patch_preds : list (class:`~.model.pred.Prediction`) - List of patchwise predictions - - patch_overlap: int - Overlap between adjacent patches in pixels. - - patches_per_row: int - Number of patchwise predictions in each row. - - x1_ascend : bool - Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. - - x2_ascend : bool - Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. - - Returns: - ------- - combined: dict - Dictionary object containing the stitched model predictions. - """ - # Get row/col index values of X_t. - data_x1_coords, data_x2_coords = get_coordinate_extent( - X_t, x1_ascend, x2_ascend - ) - data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords) - - # Iterate through patchwise predictions and slice edges prior to stitchin. - patches_clipped = [] - for i, patch_pred in enumerate(patch_preds): - # get one variable name to use for coordinates and extent - first_key = list(patch_pred.keys())[0] - # Get row/col index values of each patch. 
- patch_x1_coords, patch_x2_coords = get_coordinate_extent( - patch_pred[first_key], x1_ascend, x2_ascend - ) - patch_x1_index, patch_x2_index = get_index( - patch_x1_coords, patch_x2_coords - ) - - # Calculate size of border to slice of each edge of patchwise predictions. - # Initially set the size of all borders to the size of the overlap. - b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] - b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] - - # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column. - if patch_x2_index[0] == data_x2_index[0]: - b_x2_min = 0 - b_x2_max = b_x2_max - - # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch. - elif patch_x2_index[1] == data_x2_index[1]: - b_x2_max = 0 - patch_row_prev = patch_preds[i - 1] - - # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels. - # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels - if x2_ascend: - prev_patch_x2_max = get_index( - patch_row_prev[first_key].coords[orig_x2_name].max(), - x1=False, - ) - b_x2_min = ( - prev_patch_x2_max - patch_x2_index[0] - ) - patch_overlap[1] - - # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels. - # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels - else: - prev_patch_x2_min = get_index( - patch_row_prev[first_key].coords[orig_x2_name].min(), - x1=False, - ) - b_x2_min = ( - patch_x2_index[0] - prev_patch_x2_min - ) - patch_overlap[1] - else: - b_x2_max = b_x2_max - - # Repeat process as above for x1 coordinates. 
- if patch_x1_index[0] == data_x1_index[0]: - b_x1_min = 0 - - elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: - b_x1_max = 0 - b_x1_max = b_x1_max - patch_prev = patch_preds[i - patches_per_row] - if x1_ascend: - prev_patch_x1_max = get_index( - patch_prev[first_key].coords[orig_x1_name].max(), - x1=True, - ) - b_x1_min = ( - prev_patch_x1_max - patch_x1_index[0] - ) - patch_overlap[0] - else: - prev_patch_x1_min = get_index( - patch_prev[first_key].coords[orig_x1_name].min(), - x1=True, - ) - - b_x1_min = ( - prev_patch_x1_min - patch_x1_index[0] - ) - patch_overlap[0] - else: - b_x1_max = b_x1_max - - patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int( - patch_pred[first_key].sizes[orig_x1_name] - b_x1_max - ) - patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int( - patch_pred[first_key].sizes[orig_x2_name] - b_x2_max - ) - - # Define slicing parameters - slicing_params = { - orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), - orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max), - } - - # Slice patchwise predictions - patch_clip = { - key: dataset.isel(**slicing_params) - for key, dataset in patch_pred.items() - } - - patches_clipped.append(patch_clip) - - # Create blank prediction object to stitch prediction values onto. - stitched_prediction = copy.deepcopy(patch_preds[0]) - # Set prediction object extent to the same as X_t. - for var_name, data_array in stitched_prediction.items(): - blank_ds = xr.Dataset( - coords={ - orig_x1_name: X_t[orig_x1_name], - orig_x2_name: X_t[orig_x2_name], - "time": stitched_prediction[0]["time"], - } - ) - - # Set data variable names e.g. mean, std to those in patched prediction. Make all values Nan. 
- for data_var in data_array.data_vars: - blank_ds[data_var] = data_array[data_var] - blank_ds[data_var][:] = np.nan - stitched_prediction[var_name] = blank_ds - - # Restructure prediction objects for merging - restructured_patches = { - key: [item[key] for item in patches_clipped] - for key in patches_clipped[0].keys() - } - - # Merge patchwise predictions to create final stiched prediction. - # Iterate over each variable (key) in the prediction dictionary - for var_name, patches in restructured_patches.items(): - # Retrieve the blank dataset for the current variable - prediction_array = stitched_prediction[var_name] - - # Merge each patch into the combined dataset - for patch in patches: - for var in patch.data_vars: - # Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset - reindexed_patch = patch[var].reindex_like( - prediction_array[var], method="nearest", tolerance=1e-6 - ) - - # Combine data, prioritizing non-NaN values from patches - prediction_array[var] = prediction_array[var].where( - np.isnan(reindexed_patch), reindexed_patch - ) - - # Update the dictionary with the merged dataset - stitched_prediction[var_name] = prediction_array - return stitched_prediction - # load patch_size and stride from task patch_size = tasks[0]["patch_size"] stride = tasks[0]["stride"] @@ -1179,7 +905,14 @@ def stitch_clipped_predictions( patches_per_row = get_patches_per_row(preds) prediction = stitch_clipped_predictions( - preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending + preds, + patch_overlap_unnorm, + patches_per_row, + X_t, + orig_x1_name, + orig_x2_name, + x1_ascending, + x2_ascending, ) return prediction diff --git a/deepsensor/model/pred.py b/deepsensor/model/pred.py index 4c560a77..9c7dc247 100644 --- a/deepsensor/model/pred.py +++ b/deepsensor/model/pred.py @@ -1,4 +1,5 @@ -from typing import Union, List, Optional +import copy +from typing import Union, List, Optional, Tuple import numpy as np 
import pandas as pd @@ -364,3 +365,340 @@ def infer_prediction_modality_from_X_t( f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}." ) return mode + + +def _get_coordinate_extent( + ds: Union[xr.DataArray, xr.Dataset], + orig_x1_name: str, + orig_x2_name: str, + x1_ascend: bool, + x2_ascend: bool, +) -> Tuple: + """Get coordinate extent of dataset. + Coordinate extent is defined as maximum and minimum value of x1 and x2. + + Parameters + ---------- + ds : Data object + The dataset or data array to determine coordinate extent for. + + x1_ascend : bool + Whether the x1 coordinates ascend (increase) from top to bottom. + + x2_ascend : bool + Whether the x2 coordinates ascend (increase) from left to right. + + Returns: + ------- + tuple of tuples: + Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)). + """ + if x1_ascend: + ds_x1_coords = ( + ds.coords[orig_x1_name].min().values, + ds.coords[orig_x1_name].max().values, + ) + else: + ds_x1_coords = ( + ds.coords[orig_x1_name].max().values, + ds.coords[orig_x1_name].min().values, + ) + if x2_ascend: + ds_x2_coords = ( + ds.coords[orig_x2_name].min().values, + ds.coords[orig_x2_name].max().values, + ) + else: + ds_x2_coords = ( + ds.coords[orig_x2_name].max().values, + ds.coords[orig_x2_name].min().values, + ) + return ds_x1_coords, ds_x2_coords + + +def _get_index( + *args, + X_t: Union[ + xr.Dataset, + xr.DataArray, + pd.DataFrame, + pd.Series, + pd.Index, + np.ndarray, + ], + orig_x1_name: str, + orig_x2_name: str, + x1: bool = True, +) -> Union[int, Tuple[List[int], List[int]]]: + """Convert coordinates into pixel row/column (index). + + Parameters + ---------- + args : tuple + If one argument (numeric), it represents the coordinate value. + If two arguments (lists), they represent lists of coordinate values. 
+ + X_t : (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`) + Target locations to predict at. Can be an xarray object + containing on-grid locations or a pandas object containing off-grid locations. + + x1 : bool, optional + If True, compute index for x1 (default is True). + + Returns: + ------- + Union[int, Tuple[List[int], List[int]]] + If one argument is provided and x1 is True or False, returns the index position. + If two arguments are provided, returns a tuple containing two lists: + - First list: indices corresponding to x1 coordinates. + - Second list: indices corresponding to x2 coordinates. + + """ + if len(args) == 1: + patch_coord = args + if x1: + coord_index = np.argmin( + np.abs(X_t.coords[orig_x1_name].values - patch_coord) + ) + else: + coord_index = np.argmin( + np.abs(X_t.coords[orig_x2_name].values - patch_coord) + ) + return coord_index + + elif len(args) == 2: + patch_x1, patch_x2 = args + x1_index = [ + np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) + for target_x1 in patch_x1 + ] + x2_index = [ + np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) + for target_x2 in patch_x2 + ] + return (x1_index, x2_index) + + +def stitch_clipped_predictions( + patch_preds: List[Prediction], + patch_overlap: int, + patches_per_row: int, + X_t: Union[ + xr.Dataset, + xr.DataArray, + pd.DataFrame, + pd.Series, + pd.Index, + np.ndarray, + ], + orig_x1_name: str, + orig_x2_name: str, + x1_ascend: bool = True, + x2_ascend: bool = True, +) -> Prediction: + """Stitch patchwise predictions to form prediction at original extent of X_t. + + Parameters + ---------- + patch_preds : list (class:`~.model.pred.Prediction`) + List of patchwise predictions + + patch_overlap: int + Overlap between adjacent patches in pixels. + + patches_per_row: int + Number of patchwise predictions in each row. 
+ + X_t : (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`) + Target locations to predict at. Can be an xarray object + containing on-grid locations or a pandas object containing off-grid locations. + + orig_x1_name : str + Name of the x1 coordinate in the original unnormalised dataset. + + orig_x2_name : str + Name of the x2 coordinate in the original unnormalised dataset. + + x1_ascend : bool + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + + x2_ascend : bool + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + + Returns: + ------- + combined: dict + Dictionary object containing the stitched model predictions. + """ + # Get row/col index values of X_t. + data_x1_coords, data_x2_coords = _get_coordinate_extent( + X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + x1_ascend=x1_ascend, + x2_ascend=x2_ascend, + ) + data_x1_index, data_x2_index = _get_index( + data_x1_coords, + data_x2_coords, + X_t=X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + ) + + # Iterate through patchwise predictions and slice edges prior to stitching. + patches_clipped = [] + for i, patch_pred in enumerate(patch_preds): + # Get one variable name to use for coordinates and extent + first_key = list(patch_pred.keys())[0] + # Get row/col index values of each patch. + patch_x1_coords, patch_x2_coords = _get_coordinate_extent( + patch_pred[first_key], + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + x1_ascend=x1_ascend, + x2_ascend=x2_ascend, + ) + patch_x1_index, patch_x2_index = _get_index( + patch_x1_coords, + patch_x2_coords, + X_t=X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + ) + + # Calculate size of border to slice off each edge of patchwise predictions. + # Initially set the size of all borders to the size of the overlap. 
+ b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] + b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] + + # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column. + if patch_x2_index[0] == data_x2_index[0]: + b_x2_min = 0 + b_x2_max = b_x2_max + + # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch. + elif patch_x2_index[1] == data_x2_index[1]: + b_x2_max = 0 + patch_row_prev = patch_preds[i - 1] + + # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels. + # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels + if x2_ascend: + prev_patch_x2_max = _get_index( + patch_row_prev[first_key].coords[orig_x2_name].max(), + X_t=X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + x1=False, + ) + b_x2_min = (prev_patch_x2_max - patch_x2_index[0]) - patch_overlap[1] + + # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels. + # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels + else: + prev_patch_x2_min = _get_index( + patch_row_prev[first_key].coords[orig_x2_name].min(), + X_t=X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + x1=False, + ) + b_x2_min = (patch_x2_index[0] - prev_patch_x2_min) - patch_overlap[1] + else: + b_x2_max = b_x2_max + + # Repeat process as above for x1 coordinates. 
+ if patch_x1_index[0] == data_x1_index[0]: + b_x1_min = 0 + + elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: + b_x1_max = 0 + b_x1_max = b_x1_max + patch_prev = patch_preds[i - patches_per_row] + if x1_ascend: + prev_patch_x1_max = _get_index( + patch_prev[first_key].coords[orig_x1_name].max(), + X_t=X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + x1=True, + ) + b_x1_min = (prev_patch_x1_max - patch_x1_index[0]) - patch_overlap[0] + else: + prev_patch_x1_min = _get_index( + patch_prev[first_key].coords[orig_x1_name].min(), + X_t=X_t, + orig_x1_name=orig_x1_name, + orig_x2_name=orig_x2_name, + x1=True, + ) + + b_x1_min = (prev_patch_x1_min - patch_x1_index[0]) - patch_overlap[0] + else: + b_x1_max = b_x1_max + + patch_clip_x1_min = int(b_x1_min) + patch_clip_x1_max = int(patch_pred[first_key].sizes[orig_x1_name] - b_x1_max) + patch_clip_x2_min = int(b_x2_min) + patch_clip_x2_max = int(patch_pred[first_key].sizes[orig_x2_name] - b_x2_max) + + # Define slicing parameters + slicing_params = { + orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), + orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max), + } + + # Slice patchwise predictions + patch_clip = { + key: dataset.isel(**slicing_params) for key, dataset in patch_pred.items() + } + + patches_clipped.append(patch_clip) + + # Create blank prediction object to stitch prediction values onto. + stitched_prediction = copy.deepcopy(patch_preds[0]) + # Set prediction object extent to the same as X_t. + for var_name, data_array in stitched_prediction.items(): + blank_ds = xr.Dataset( + coords={ + orig_x1_name: X_t[orig_x1_name], + orig_x2_name: X_t[orig_x2_name], + "time": stitched_prediction[0]["time"], + } + ) + + # Set data variable names e.g. mean, std to those in patched prediction. Make all values NaN. 
+ for data_var in data_array.data_vars: + blank_ds[data_var] = data_array[data_var] + blank_ds[data_var][:] = np.nan + stitched_prediction[var_name] = blank_ds + + # Restructure prediction objects for merging + restructured_patches = { + key: [item[key] for item in patches_clipped] + for key in patches_clipped[0].keys() + } + + # Merge patchwise predictions to create final stitched prediction. + # Iterate over each variable (key) in the prediction dictionary + for var_name, patches in restructured_patches.items(): + # Retrieve the blank dataset for the current variable + prediction_array = stitched_prediction[var_name] + + # Merge each patch into the combined dataset + for patch in patches: + for var in patch.data_vars: + # Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset + reindexed_patch = patch[var].reindex_like( + prediction_array[var], method="nearest", tolerance=1e-6 + ) + + # Combine data, prioritizing non-NaN values from patches + prediction_array[var] = prediction_array[var].where( + np.isnan(reindexed_patch), reindexed_patch + ) + + # Update the dictionary with the merged dataset + stitched_prediction[var_name] = prediction_array + return stitched_prediction