diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index e5d10d2f..06ed6332 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1365,9 +1365,11 @@ def sample_variable(var, sampling_strat, seed): rng = np.random.default_rng(split_seed) # Keep trying until we get a target set with at least one target point - keep_searching = True - while keep_searching: - added_mask_date = rng.choice(self.context[context_idx].time) + # Iterate through a randomly ordered list of dates + added_mask_dates = rng.choice(self.context[context_idx].time, + size=len(self.context[context_idx].time), + replace=False) + for i, added_mask_date in enumerate(added_mask_dates): added_mask = ( self.context[context_idx].sel(time=added_mask_date).isnull() ) @@ -1379,19 +1381,23 @@ def sample_variable(var, sampling_strat, seed): # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs # when we could just slice the target values here target_mask = added_mask & ~curr_mask - if isinstance(target_var, xr.Dataset): - keep_searching = np.all(target_mask.to_array().data == False) + target_mask_data = target_mask.to_array().data if isinstance(target_var, xr.Dataset) else target_mask.data + # if all elements in target_mask_data are NaN + if np.all(target_mask_data == False): + if i == len(added_mask_dates): + # if the last date in the list has been reached and no suitable mask is found + raise ValueError("") + else: + # otherwise continue on with next random date + continue else: - keep_searching = np.all(target_mask.data == False) - if keep_searching: - continue # No target points -- use a different `added_mask` + target_var = target_var.where( + target_mask + ) # Only keep target locations - target_var = target_var.where( - target_mask - ) # Only keep target locations - - context_slices[context_idx] = context_var - target_slices[target_idx] = target_var + context_slices[context_idx] = context_var + target_slices[target_idx] = target_var + break for i, (var, sampling_strat) in enumerate( zip(context_slices, context_sampling)