-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
avoid infinite loop in gapfill strategy #12
base: patchwise_train
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1365,9 +1365,11 @@ def sample_variable(var, sampling_strat, seed): | |
rng = np.random.default_rng(split_seed) | ||
|
||
# Keep trying until we get a target set with at least one target point | ||
keep_searching = True | ||
while keep_searching: | ||
added_mask_date = rng.choice(self.context[context_idx].time) | ||
# Iterate through a randomly ordered list of dates | ||
added_mask_dates = rng.choice(self.context[context_idx].time, | ||
size=len(self.context[context_idx].time), | ||
replace=False) | ||
for i, added_mask_date in enumerate(added_mask_dates): | ||
added_mask = ( | ||
self.context[context_idx].sel(time=added_mask_date).isnull() | ||
) | ||
|
@@ -1379,19 +1381,23 @@ def sample_variable(var, sampling_strat, seed): | |
# TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs | ||
# when we could just slice the target values here | ||
target_mask = added_mask & ~curr_mask | ||
if isinstance(target_var, xr.Dataset): | ||
keep_searching = np.all(target_mask.to_array().data == False) | ||
target_mask_data = target_mask.to_array().data if isinstance(target_var, xr.Dataset) else target_mask.data | ||
# if all elements in target_mask_data are NaN | ||
if np.all(target_mask_data == False): | ||
if i == len(added_mask_dates): | ||
# if the last date in the list has been reached and no suitable mask is found | ||
raise ValueError("") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be content within the "" here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should indeed, I don't think I understand enough to write an informative error message here, can you suggest one? |
||
else: | ||
# otherwise continue on with next random date | ||
continue | ||
else: | ||
keep_searching = np.all(target_mask.data == False) | ||
if keep_searching: | ||
continue # No target points -- use a different `added_mask` | ||
target_var = target_var.where( | ||
target_mask | ||
) # Only keep target locations | ||
|
||
target_var = target_var.where( | ||
target_mask | ||
) # Only keep target locations | ||
|
||
context_slices[context_idx] = context_var | ||
target_slices[target_idx] = target_var | ||
context_slices[context_idx] = context_var | ||
target_slices[target_idx] = target_var | ||
break | ||
|
||
for i, (var, sampling_strat) in enumerate( | ||
zip(context_slices, context_sampling) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a sensible approach to me- happy with this amendment.