avoid infinite loop in gapfill strategy #12

Draft
wants to merge 1 commit into
base: patchwise_train
34 changes: 20 additions & 14 deletions deepsensor/data/loader.py
@@ -1365,9 +1365,11 @@ def sample_variable(var, sampling_strat, seed):
rng = np.random.default_rng(split_seed)

# Keep trying until we get a target set with at least one target point
keep_searching = True
while keep_searching:
added_mask_date = rng.choice(self.context[context_idx].time)
# Iterate through a randomly ordered list of dates
added_mask_dates = rng.choice(self.context[context_idx].time,
[Review comment, Collaborator] This seems like a sensible approach to me; happy with this amendment.

size=len(self.context[context_idx].time),
replace=False)
for i, added_mask_date in enumerate(added_mask_dates):
added_mask = (
self.context[context_idx].sel(time=added_mask_date).isnull()
)
@@ -1379,19 +1381,23 @@ def sample_variable(var, sampling_strat, seed):
# TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
# when we could just slice the target values here
target_mask = added_mask & ~curr_mask
if isinstance(target_var, xr.Dataset):
keep_searching = np.all(target_mask.to_array().data == False)
target_mask_data = target_mask.to_array().data if isinstance(target_var, xr.Dataset) else target_mask.data
# if all elements in target_mask_data are False (i.e. this date yields no target points)
if np.all(target_mask_data == False):
if i == len(added_mask_dates) - 1:
# if the last date in the list has been reached and no suitable mask is found
raise ValueError("")
[Review comment, Collaborator] Should there be content within the "" here?

[Reply, Owner/Author] There should indeed. I don't think I understand enough to write an informative error message here; can you suggest one?

else:
# otherwise continue on with next random date
continue
else:
keep_searching = np.all(target_mask.data == False)
if keep_searching:
continue # No target points -- use a different `added_mask`
target_var = target_var.where(
target_mask
) # Only keep target locations

target_var = target_var.where(
target_mask
) # Only keep target locations

context_slices[context_idx] = context_var
target_slices[target_idx] = target_var
context_slices[context_idx] = context_var
target_slices[target_idx] = target_var
break

for i, (var, sampling_strat) in enumerate(
zip(context_slices, context_sampling)
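On the open question of the error message, here is a minimal sketch of the bounded search with one possible wording. It is not the code in this PR: plain NumPy boolean arrays stand in for the xarray masks, and `pick_gapfill_mask`, `masks_by_date` and the message text are hypothetical names and wording chosen for illustration. It also raises after the loop finishes instead of checking for the last index inside the loop, which sidesteps the need to compare `i` against the list length.

```python
import numpy as np


def pick_gapfill_mask(masks_by_date, curr_mask, seed=0):
    """Return (date, target_mask) for the first shuffled date that yields targets.

    masks_by_date: dict mapping a date label to a boolean array that is True
        where the context is missing on that date (a stand-in, assumed here,
        for `self.context[context_idx].sel(time=date).isnull()`).
    curr_mask: boolean array, True where the current slice is already missing.
    """
    rng = np.random.default_rng(seed)
    dates = list(masks_by_date)
    # Visit every date exactly once, in random order, so the search is bounded.
    for idx in rng.permutation(len(dates)):
        date = dates[idx]
        target_mask = masks_by_date[date] & ~curr_mask
        if target_mask.any():
            return date, target_mask
    # Reached only when no date produced a single target point.
    raise ValueError(
        f"Gapfill sampling found no target points: none of the {len(dates)} "
        "candidate dates has missing values where the current context slice "
        "is observed."
    )


# Usage: "d0" overlaps entirely with the existing gap, so only "d1" qualifies.
curr = np.array([True, False, False])
masks = {
    "d0": np.array([True, False, False]),
    "d1": np.array([False, True, False]),
}
date, target = pick_gapfill_mask(masks, curr)
```

Because every date is tried at most once, the function either returns a usable mask or raises, whatever order the shuffle produces.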