diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 659f436c..b09cfb82 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -786,15 +786,14 @@ def sample_offgrid_aux( return Y_t_aux def _compute_global_coordinate_bounds(self) -> List[float]: - """ - Compute global coordinate bounds in order to sample spatial bounds if desired. + """Compute global coordinate bounds in order to sample spatial bounds if desired. - Returns + Returns: ------- bbox: List[float] sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max] """ - x1_min, x1_max, x2_min, x2_max = np.PINF, np.NINF, np.PINF, np.NINF + x1_min, x1_max, x2_min, x2_max = np.inf, -np.inf, np.inf, -np.inf for var in itertools.chain(self.context, self.target): if isinstance(var, (xr.Dataset, xr.DataArray)): @@ -821,58 +820,52 @@ def _compute_global_coordinate_bounds(self) -> List[float]: x2_max = var_x2_max return [x1_min, x1_max, x2_min, x2_max] - + def _compute_x1x2_direction(self) -> str: - """ - Compute whether the x1 and x2 coords are ascending or descending. + """Compute whether the x1 and x2 coords are ascending or descending. - Returns + Returns: ------- coord_directions: dict(str) Dictionary containing two keys: x1 and x2, with boolean values - defining if these coordings increase or decrease from top left corner. - - """ + defining if these coordings increase or decrease from top left corner. + """ for var in itertools.chain(self.context, self.target): if isinstance(var, (xr.Dataset, xr.DataArray)): - coord_x1_left= var.x1[0] - coord_x1_right= var.x1[-1] - coord_x2_top= var.x2[0] - coord_x2_bottom= var.x2[-1] - + coord_x1_left = var.x1[0] + coord_x1_right = var.x1[-1] + coord_x2_top = var.x2[0] + coord_x2_bottom = var.x2[-1] + x1_ascend = True if coord_x1_left <= coord_x1_right else False x2_ascend = True if coord_x2_top <= coord_x2_bottom else False coord_directions = { - "x1": x1_ascend, - "x2": x2_ascend, - } + "x1": x1_ascend, + "x2": x2_ascend, + } - #TODO- what to input for pd.dataframe + # TODO- what to input for pd.dataframe elif isinstance(var, (pd.DataFrame, pd.Series)): # var_x1_min = var.index.get_level_values("x1").min() # var_x1_max = var.index.get_level_values("x1").max() # var_x2_min = var.index.get_level_values("x2").min() # var_x2_max = var.index.get_level_values("x2").max() - coord_directions = { - "x1": None, - "x2": None - } + coord_directions = {"x1": None, "x2": None} - return coord_directions + return coord_directions def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: - """ - Sample random window uniformly from global coordinates to slice data. + """Sample random window uniformly from global coordinates to slice data. Parameters ---------- patch_size : Tuple[float] Tuple of window extent - Returns + Returns: ------- bbox: List[float] sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] @@ -928,8 +921,7 @@ def time_slice_variable(self, var, date, delta_t=0): return var def spatial_slice_variable(self, var, window: List[float]): - """ - Slice a variable by a given window size. + """Slice a variable by a given window size. Args: var (...): @@ -996,8 +988,7 @@ def task_generation( # noqa: D102 datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: - """ - Generate a task for a given date. + """Generate a task for a given date. There are several sampling strategies available for the context and target data: @@ -1040,7 +1031,7 @@ def task_generation( # noqa: D102 Override the seed for random sampling. This can be used to use the same random sampling at different ``date``. Default is None. - Returns + Returns: ------- task : :class:`~.data.task.Task` Task object containing the context and target data. @@ -1222,7 +1213,9 @@ def sample_variable(var, sampling_strat, seed): task["time"] = date task["ops"] = [] task["bbox"] = bbox - task["patch_size"] = patch_size # store patch_size and stride in task for use in stitching in prediction + task["patch_size"] = ( + patch_size # store patch_size and stride in task for use in stitching in prediction + ) task["stride"] = stride task["X_c"] = [] task["Y_c"] = [] @@ -1243,7 +1236,6 @@ def sample_variable(var, sampling_strat, seed): for var, delta_t in zip(self.target, self.target_delta_t) ] - # TODO move to method if ( self.links is not None @@ -1363,7 +1355,7 @@ def sample_variable(var, sampling_strat, seed): context_slices[context_idx] = context_var target_slices[target_idx] = target_var - + # check bbox size if bbox is not None: assert ( @@ -1430,19 +1422,19 @@ def sample_variable(var, sampling_strat, seed): def sample_sliding_window( self, patch_size: Tuple[float], stride: Tuple[int] ) -> Sequence[float]: - """ - Sample data using sliding window from global coordinates to slice data. - Parameters + """Sample data using sliding window from global coordinates to slice data. + Parameters. ---------- patch_size : Tuple[float] Tuple of window extent Stride : Tuple[float] Tuple of step size between each patch along x1 and x2 axis. - Returns + + Returns: ------- - bbox: List[float] ## check type of return. - sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] + bbox: List[float] + Sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max]. """ # define patch size in x1/x2 x1_extend, x2_extend = patch_size @@ -1458,7 +1450,7 @@ def sample_sliding_window( patch_list = [] # Todo: simplify these elif statements - if self.coord_directions['x1'] == False and self.coord_directions['x2'] == True: + if self.coord_directions["x1"] == False and self.coord_directions["x2"] == True: for y in np.arange(x1_max, x1_min, -dy): for x in np.arange(x2_min, x2_max, dx): if y - x1_extend < x1_min: @@ -1474,7 +1466,10 @@ def sample_sliding_window( bbox = [y0 - x1_extend, y0, x0, x0 + x2_extend] patch_list.append(bbox) - elif self.coord_directions['x1'] == False and self.coord_directions['x2'] == False: + elif ( + self.coord_directions["x1"] == False + and self.coord_directions["x2"] == False + ): for y in np.arange(x1_max, x1_min, -dy): for x in np.arange(x2_max, x2_min, -dx): if y - x1_extend < x1_min: @@ -1490,7 +1485,9 @@ def sample_sliding_window( bbox = [y0 - x1_extend, y0, x0 - x2_extend, x0] patch_list.append(bbox) - elif self.coord_directions['x1'] == True and self.coord_directions['x2'] == False: + elif ( + self.coord_directions["x1"] == True and self.coord_directions["x2"] == False + ): for y in np.arange(x1_min, x1_max, dy): for x in np.arange(x2_max, x2_min, -dx): if y + x1_extend > x1_max: @@ -1662,7 +1659,6 @@ def __call__( ) elif patch_strategy == "random": - if patch_size is None: raise ValueError( "Patch size must be specified for random patch sampling" @@ -1751,7 +1747,7 @@ def __call__( datewise_deterministic=datewise_deterministic, seed_override=seed_override, patch_size=patch_size, - stride=stride + stride=stride, ) for bbox in bboxes ] @@ -1768,7 +1764,7 @@ def __call__( datewise_deterministic=datewise_deterministic, seed_override=seed_override, patch_size=patch_size, - stride=stride + stride=stride, ) for bbox in bboxes ] diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 9228ee0b..941117ae 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -680,8 +680,7 @@ def predict_patch( progress_bar: int = 0, verbose: bool = False, ) -> Prediction: - """ - Predict on a regular grid or at off-grid locations. + """Predict on a regular grid or at off-grid locations. Args: tasks (List[Task] | Task): @@ -758,11 +757,11 @@ def predict_patch( # Get coordinate names of original unnormalised dataset. orig_x1_name = data_processor.x1_name orig_x2_name = data_processor.x2_name - + def get_patches_per_row(preds) -> int: - """ - Calculate number of patches per row. + """Calculate number of patches per row. Required to stitch patches back together. + Args: preds (List[class:`~.model.pred.Prediction`]): A list of `dict`-like objects containing patchwise predictions. @@ -773,20 +772,19 @@ def get_patches_per_row(preds) -> int: """ patches_per_row = 0 vars = list(preds[0][0].data_vars) - var = vars[0] - x1_val = preds[0][0][var].coords[orig_x1_name].min() - + var = vars[0] + x1_val = preds[0][0][var].coords[orig_x1_name].min() + for pred in preds: if pred[0][var].coords[orig_x1_name].min() == x1_val: - patches_per_row = patches_per_row + 1 + patches_per_row = patches_per_row + 1 return patches_per_row - - - def get_patch_overlap(overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend) -> int: - """ - Calculate overlap between adjacent patches in pixels. + def get_patch_overlap( + overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend + ) -> int: + """Calculate overlap between adjacent patches in pixels. Parameters ---------- @@ -797,20 +795,20 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend Used for unnormalising the coordinates of the bounding boxes of patches. X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): - Data array containing target locations to predict at. - + Data array containing target locations to predict at. + x1_ascend : str: - Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. - + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + x2_ascend : str: - Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. - - Returns + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + + Returns: ------- patch_overlap : tuple (int) Unnormalised size of overlap between adjacent patches. """ - # Todo- check if there is simplier and more robust way to convert overlap into pixels. + # Todo- check if there is simplier and more robust way to convert overlap into pixels. # Place x1/x2 overlap values in Xarray to pass into unnormalise() overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims="x1", name="x1") @@ -825,21 +823,82 @@ def get_patch_overlap(overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend # Find size of overlap for x1/x2 in pixels if x1_ascend: - x1_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values - unnorm_overlap_x1))/2))) + x1_overlap_index = int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[orig_x1_name].values + - unnorm_overlap_x1 + ) + ) + / 2 + ) + ) + ) else: - x1_overlap_index = int(np.floor((X_t_ds.coords[orig_x1_name].values.size- int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x1_name].values- unnorm_overlap_x1))))))/2)) + x1_overlap_index = int( + np.floor( + ( + X_t_ds.coords[orig_x1_name].values.size + - int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[orig_x1_name].values + - unnorm_overlap_x1 + ) + ) + ) + ) + ) + ) + / 2 + ) + ) if x2_ascend: - x2_overlap_index = int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values - unnorm_overlap_x2))/2))) + x2_overlap_index = int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[orig_x2_name].values + - unnorm_overlap_x2 + ) + ) + / 2 + ) + ) + ) else: - x2_overlap_index = int(np.floor((X_t_ds.coords[orig_x2_name].values.size- int(np.ceil((np.argmin(np.abs(X_t_ds.coords[orig_x2_name].values- unnorm_overlap_x2))))))/2)) + x2_overlap_index = int( + np.floor( + ( + X_t_ds.coords[orig_x2_name].values.size + - int( + np.ceil( + ( + np.argmin( + np.abs( + X_t_ds.coords[orig_x2_name].values + - unnorm_overlap_x2 + ) + ) + ) + ) + ) + ) + / 2 + ) + ) x1_x2_overlap = (x1_overlap_index, x2_overlap_index) return x1_x2_overlap def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: - """ - Convert coordinates into pixel row/column (index). + """Convert coordinates into pixel row/column (index). Parameters ---------- @@ -850,7 +909,7 @@ def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: x1 : bool, optional If True, compute index for x1 (default is True). - Returns + Returns: ------- Union[int, Tuple[List[int], List[int]]] If one argument is provided and x1 is True or False, returns the index position. @@ -862,21 +921,31 @@ def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: if len(args) == 1: patch_coord = args if x1: - coord_index = np.argmin(np.abs(X_t.coords[orig_x1_name].values - patch_coord)) + coord_index = np.argmin( + np.abs(X_t.coords[orig_x1_name].values - patch_coord) + ) else: - coord_index = np.argmin(np.abs(X_t.coords[orig_x2_name].values - patch_coord)) + coord_index = np.argmin( + np.abs(X_t.coords[orig_x2_name].values - patch_coord) + ) return coord_index elif len(args) == 2: - patch_x1, patch_x2 = args - x1_index = [np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) for target_x1 in patch_x1] - x2_index = [np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) for target_x2 in patch_x2] + patch_x1, patch_x2 = args + x1_index = [ + np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) + for target_x1 in patch_x1 + ] + x2_index = [ + np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) + for target_x2 in patch_x2 + ] return (x1_index, x2_index) - - - def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_ascend=True, x2_ascend=True) -> dict: - """ - Stitch patchwise predictions to form prediction at original extent. + + def stitch_clipped_predictions( + patch_preds, patch_overlap, patches_per_row, x1_ascend=True, x2_ascend=True + ) -> dict: + """Stitch patchwise predictions to form prediction at original extent. Parameters ---------- @@ -888,28 +957,39 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a patches_per_row: int Number of patchwise predictions in each row. - + x1_ascend : str - Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. - + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + x2_ascend : str - Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. - - Returns + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + + Returns: ------- combined: dict Dictionary object containing the stitched model predictions. """ - # Get row/col index values of X_t. Order depends on whether coordinate is ascending or descending. if x1_ascend: - data_x1 = X_t.coords[orig_x1_name].min().values, X_t.coords[orig_x1_name].max().values - else: - data_x1 = X_t.coords[orig_x1_name].max().values, X_t.coords[orig_x1_name].min().values + data_x1 = ( + X_t.coords[orig_x1_name].min().values, + X_t.coords[orig_x1_name].max().values, + ) + else: + data_x1 = ( + X_t.coords[orig_x1_name].max().values, + X_t.coords[orig_x1_name].min().values, + ) if x2_ascend: - data_x2 = X_t.coords[orig_x2_name].min().values, X_t.coords[orig_x2_name].max().values + data_x2 = ( + X_t.coords[orig_x2_name].min().values, + X_t.coords[orig_x2_name].max().values, + ) else: - data_x2 = X_t.coords[orig_x2_name].max().values, X_t.coords[orig_x2_name].min().values + data_x2 = ( + X_t.coords[orig_x2_name].max().values, + X_t.coords[orig_x2_name].min().values, + ) data_x1_index, data_x2_index = get_index(data_x1, data_x2) patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()} @@ -918,16 +998,28 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a for var_name, data_array in patch_pred.items(): if var_name in patch_pred: # Get row/col index values of each patch. Order depends on whether coordinate is ascending or descending. - if x1_ascend: - patch_x1 = data_array.coords[orig_x1_name].min().values, data_array.coords[orig_x1_name].max().values + if x1_ascend: + patch_x1 = ( + data_array.coords[orig_x1_name].min().values, + data_array.coords[orig_x1_name].max().values, + ) else: - patch_x1 = data_array.coords[orig_x1_name].max().values, data_array.coords[orig_x1_name].min().values + patch_x1 = ( + data_array.coords[orig_x1_name].max().values, + data_array.coords[orig_x1_name].min().values, + ) if x2_ascend: - patch_x2 = data_array.coords[orig_x2_name].min().values, data_array.coords[orig_x2_name].max().values + patch_x2 = ( + data_array.coords[orig_x2_name].min().values, + data_array.coords[orig_x2_name].max().values, + ) else: - patch_x2 = data_array.coords[orig_x2_name].max().values, data_array.coords[orig_x2_name].min().values - patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) - + patch_x2 = ( + data_array.coords[orig_x2_name].max().values, + data_array.coords[orig_x2_name].min().values, + ) + patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2) + b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] @@ -947,46 +1039,76 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a if patch_x2_index[0] == data_x2_index[0]: b_x2_min = 0 # The +1 operations here and elsewhere in this block address the different shapes between the input and prediction - # TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square. + # TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square. b_x2_max = b_x2_max + 1 elif patch_x2_index[1] == data_x2_index[1]: b_x2_max = 0 - patch_row_prev = preds[i-1] + patch_row_prev = preds[i - 1] if x2_ascend: - prev_patch_x2_max = get_index(patch_row_prev[var_name].coords[orig_x2_name].max(), x1 = False) - b_x2_min = (prev_patch_x2_max - patch_x2_index[0])-patch_overlap[1] + prev_patch_x2_max = get_index( + patch_row_prev[var_name].coords[orig_x2_name].max(), + x1=False, + ) + b_x2_min = ( + prev_patch_x2_max - patch_x2_index[0] + ) - patch_overlap[1] else: - prev_patch_x2_min = get_index(patch_row_prev[var_name].coords[orig_x2_name].min(), x1 = False) - b_x2_min = (patch_x2_index[0] -prev_patch_x2_min)-patch_overlap[1] + prev_patch_x2_min = get_index( + patch_row_prev[var_name].coords[orig_x2_name].min(), + x1=False, + ) + b_x2_min = ( + patch_x2_index[0] - prev_patch_x2_min + ) - patch_overlap[1] else: b_x2_max = b_x2_max + 1 - - + if patch_x1_index[0] == data_x1_index[0]: b_x1_min = 0 # TODO: ensure this elif statement is robust to multiple patch sizes. elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: b_x1_max = 0 b_x1_max = b_x1_max + 1 - patch_prev = preds[i-patches_per_row] + patch_prev = preds[i - patches_per_row] if x1_ascend: - prev_patch_x1_max = get_index(patch_prev[var_name].coords[orig_x1_name].max(), x1 = True) - b_x1_min = (prev_patch_x1_max - patch_x1_index[0])- patch_overlap[0] + prev_patch_x1_max = get_index( + patch_prev[var_name].coords[orig_x1_name].max(), + x1=True, + ) + b_x1_min = ( + prev_patch_x1_max - patch_x1_index[0] + ) - patch_overlap[0] else: - prev_patch_x1_min = get_index(patch_prev[var_name].coords[orig_x1_name].min(), x1 = True) + prev_patch_x1_min = get_index( + patch_prev[var_name].coords[orig_x1_name].min(), + x1=True, + ) - b_x1_min = (prev_patch_x1_min- patch_x1_index[0])- patch_overlap[0] + b_x1_min = ( + prev_patch_x1_min - patch_x1_index[0] + ) - patch_overlap[0] else: b_x1_max = b_x1_max + 1 patch_clip_x1_min = int(b_x1_min) - patch_clip_x1_max = int(data_array.sizes[orig_x1_name] - b_x1_max) + patch_clip_x1_max = int( + data_array.sizes[orig_x1_name] - b_x1_max + ) patch_clip_x2_min = int(b_x2_min) - patch_clip_x2_max = int(data_array.sizes[orig_x2_name] - b_x2_max) - - patch_clip = data_array.isel(**{orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), - orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max)}) + patch_clip_x2_max = int( + data_array.sizes[orig_x2_name] - b_x2_max + ) + patch_clip = data_array.isel( + **{ + orig_x1_name: slice( + patch_clip_x1_min, patch_clip_x1_max + ), + orig_x2_name: slice( + patch_clip_x2_min, patch_clip_x2_max + ), + } + ) patches_clipped[var_name].append(patch_clip) @@ -1043,8 +1165,14 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a x2 = xr.DataArray([bbox[2], bbox[3]], dims="x2", name="x2") bbox_norm = xr.Dataset(coords={"x1": x1, "x2": x2}) bbox_unnorm = data_processor.unnormalise(bbox_norm) - unnorm_bbox_x1 = bbox_unnorm[orig_x1_name].values.min(), bbox_unnorm[orig_x1_name].values.max() - unnorm_bbox_x2 = bbox_unnorm[orig_x2_name].values.min(), bbox_unnorm[orig_x2_name].values.max() + unnorm_bbox_x1 = ( + bbox_unnorm[orig_x1_name].values.min(), + bbox_unnorm[orig_x1_name].values.max(), + ) + unnorm_bbox_x2 = ( + bbox_unnorm[orig_x2_name].values.min(), + bbox_unnorm[orig_x2_name].values.max(), + ) # Determine X_t for patch, however, cannot assume min/max ordering of slice coordinates # Check the order of coordinates in X_t, sometimes they are increasing or decreasing in order. @@ -1067,27 +1195,37 @@ def stitch_clipped_predictions(patch_preds, patch_overlap, patches_per_row, x1_a # Determine X_t for patch with correct slice direction task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice}) - + # Patchwise prediction pred = self.predict(task, task_X_t) # Append patchwise DeepSensor prediction object to list preds.append(pred) - overlap_norm = tuple(patch - stride for patch, stride in zip(patch_size, stride)) - patch_overlap_unnorm = get_patch_overlap(overlap_norm, data_processor, X_t, x1_ascending, x2_ascending) + overlap_norm = tuple( + patch - stride for patch, stride in zip(patch_size, stride) + ) + patch_overlap_unnorm = get_patch_overlap( + overlap_norm, data_processor, X_t, x1_ascending, x2_ascending + ) patches_per_row = get_patches_per_row(preds) - stitched_prediction = stitch_clipped_predictions(preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending) - + stitched_prediction = stitch_clipped_predictions( + preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending + ) + ## Cast prediction into DeepSensor.Prediction object. # TODO make this into seperate method. prediction = copy.deepcopy(preds[0]) # Generate new blank DeepSensor.prediction object in original coordinate system. for var_name_copy, data_array_copy in prediction.items(): - # set x and y coords - stitched_preds = xr.Dataset(coords={orig_x1_name: X_t[orig_x1_name], orig_x2_name: X_t[orig_x2_name]}) + stitched_preds = xr.Dataset( + coords={ + orig_x1_name: X_t[orig_x1_name], + orig_x2_name: X_t[orig_x2_name], + } + ) # Set time to same as patched prediction stitched_preds["time"] = data_array_copy["time"]