Skip to content

Commit

Permalink
Merge branch 'patchwise_train' into update_notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwilby committed Nov 5, 2024
2 parents 53f238f + 3a34ed3 commit 9e4254f
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 133 deletions.
92 changes: 44 additions & 48 deletions deepsensor/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,15 +786,14 @@ def sample_offgrid_aux(
return Y_t_aux

def _compute_global_coordinate_bounds(self) -> List[float]:
"""
Compute global coordinate bounds in order to sample spatial bounds if desired.
"""Compute global coordinate bounds in order to sample spatial bounds if desired.
Returns
Returns:
-------
bbox: List[float]
sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max]
"""
x1_min, x1_max, x2_min, x2_max = np.PINF, np.NINF, np.PINF, np.NINF
x1_min, x1_max, x2_min, x2_max = np.inf, -np.inf, np.inf, -np.inf

for var in itertools.chain(self.context, self.target):
if isinstance(var, (xr.Dataset, xr.DataArray)):
Expand All @@ -821,58 +820,52 @@ def _compute_global_coordinate_bounds(self) -> List[float]:
x2_max = var_x2_max

return [x1_min, x1_max, x2_min, x2_max]

def _compute_x1x2_direction(self) -> str:
"""
Compute whether the x1 and x2 coords are ascending or descending.
"""Compute whether the x1 and x2 coords are ascending or descending.
Returns
Returns:
-------
coord_directions: dict(str)
Dictionary containing two keys: x1 and x2, with boolean values
defining if these coordings increase or decrease from top left corner.
"""
defining if these coordings increase or decrease from top left corner.
"""
for var in itertools.chain(self.context, self.target):
if isinstance(var, (xr.Dataset, xr.DataArray)):
coord_x1_left= var.x1[0]
coord_x1_right= var.x1[-1]
coord_x2_top= var.x2[0]
coord_x2_bottom= var.x2[-1]
coord_x1_left = var.x1[0]
coord_x1_right = var.x1[-1]
coord_x2_top = var.x2[0]
coord_x2_bottom = var.x2[-1]

x1_ascend = True if coord_x1_left <= coord_x1_right else False
x2_ascend = True if coord_x2_top <= coord_x2_bottom else False

coord_directions = {
"x1": x1_ascend,
"x2": x2_ascend,
}
"x1": x1_ascend,
"x2": x2_ascend,
}

#TODO- what to input for pd.dataframe
# TODO- what to input for pd.dataframe
elif isinstance(var, (pd.DataFrame, pd.Series)):
# var_x1_min = var.index.get_level_values("x1").min()
# var_x1_max = var.index.get_level_values("x1").max()
# var_x2_min = var.index.get_level_values("x2").min()
# var_x2_max = var.index.get_level_values("x2").max()

coord_directions = {
"x1": None,
"x2": None
}
coord_directions = {"x1": None, "x2": None}

return coord_directions
return coord_directions

def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]:
"""
Sample random window uniformly from global coordinates to slice data.
"""Sample random window uniformly from global coordinates to slice data.
Parameters
----------
patch_size : Tuple[float]
Tuple of window extent
Returns
Returns:
-------
bbox: List[float]
sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max]
Expand Down Expand Up @@ -928,8 +921,7 @@ def time_slice_variable(self, var, date, delta_t=0):
return var

def spatial_slice_variable(self, var, window: List[float]):
"""
Slice a variable by a given window size.
"""Slice a variable by a given window size.
Args:
var (...):
Expand Down Expand Up @@ -996,8 +988,7 @@ def task_generation( # noqa: D102
datewise_deterministic: bool = False,
seed_override: Optional[int] = None,
) -> Task:
"""
Generate a task for a given date.
"""Generate a task for a given date.
There are several sampling strategies available for the context and
target data:
Expand Down Expand Up @@ -1040,7 +1031,7 @@ def task_generation( # noqa: D102
Override the seed for random sampling. This can be used to use the
same random sampling at different ``date``. Default is None.
Returns
Returns:
-------
task : :class:`~.data.task.Task`
Task object containing the context and target data.
Expand Down Expand Up @@ -1222,7 +1213,9 @@ def sample_variable(var, sampling_strat, seed):
task["time"] = date
task["ops"] = []
task["bbox"] = bbox
task["patch_size"] = patch_size # store patch_size and stride in task for use in stitching in prediction
task["patch_size"] = (
patch_size # store patch_size and stride in task for use in stitching in prediction
)
task["stride"] = stride
task["X_c"] = []
task["Y_c"] = []
Expand All @@ -1243,7 +1236,6 @@ def sample_variable(var, sampling_strat, seed):
for var, delta_t in zip(self.target, self.target_delta_t)
]


# TODO move to method
if (
self.links is not None
Expand Down Expand Up @@ -1363,7 +1355,7 @@ def sample_variable(var, sampling_strat, seed):

context_slices[context_idx] = context_var
target_slices[target_idx] = target_var

# check bbox size
if bbox is not None:
assert (
Expand Down Expand Up @@ -1430,19 +1422,19 @@ def sample_variable(var, sampling_strat, seed):
def sample_sliding_window(
self, patch_size: Tuple[float], stride: Tuple[int]
) -> Sequence[float]:
"""
Sample data using sliding window from global coordinates to slice data.
Parameters
"""Sample data using sliding window from global coordinates to slice data.
Parameters.
----------
patch_size : Tuple[float]
Tuple of window extent
Stride : Tuple[float]
Tuple of step size between each patch along x1 and x2 axis.
Returns
Returns:
-------
bbox: List[float] ## check type of return.
sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max]
bbox: List[float]
Sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max].
"""
# define patch size in x1/x2
x1_extend, x2_extend = patch_size
Expand All @@ -1458,7 +1450,7 @@ def sample_sliding_window(
patch_list = []

# Todo: simplify these elif statements
if self.coord_directions['x1'] == False and self.coord_directions['x2'] == True:
if self.coord_directions["x1"] == False and self.coord_directions["x2"] == True:
for y in np.arange(x1_max, x1_min, -dy):
for x in np.arange(x2_min, x2_max, dx):
if y - x1_extend < x1_min:
Expand All @@ -1474,7 +1466,10 @@ def sample_sliding_window(
bbox = [y0 - x1_extend, y0, x0, x0 + x2_extend]
patch_list.append(bbox)

elif self.coord_directions['x1'] == False and self.coord_directions['x2'] == False:
elif (
self.coord_directions["x1"] == False
and self.coord_directions["x2"] == False
):
for y in np.arange(x1_max, x1_min, -dy):
for x in np.arange(x2_max, x2_min, -dx):
if y - x1_extend < x1_min:
Expand All @@ -1490,7 +1485,9 @@ def sample_sliding_window(
bbox = [y0 - x1_extend, y0, x0 - x2_extend, x0]
patch_list.append(bbox)

elif self.coord_directions['x1'] == True and self.coord_directions['x2'] == False:
elif (
self.coord_directions["x1"] == True and self.coord_directions["x2"] == False
):
for y in np.arange(x1_min, x1_max, dy):
for x in np.arange(x2_max, x2_min, -dx):
if y + x1_extend > x1_max:
Expand Down Expand Up @@ -1662,7 +1659,6 @@ def __call__(
)

elif patch_strategy == "random":

if patch_size is None:
raise ValueError(
"Patch size must be specified for random patch sampling"
Expand Down Expand Up @@ -1751,7 +1747,7 @@ def __call__(
datewise_deterministic=datewise_deterministic,
seed_override=seed_override,
patch_size=patch_size,
stride=stride
stride=stride,
)
for bbox in bboxes
]
Expand All @@ -1768,7 +1764,7 @@ def __call__(
datewise_deterministic=datewise_deterministic,
seed_override=seed_override,
patch_size=patch_size,
stride=stride
stride=stride,
)
for bbox in bboxes
]
Expand Down
Loading

0 comments on commit 9e4254f

Please sign in to comment.