
Potential bug: TaskLoader does not run when patching and using the 'gapfill' sampling strategy #9

Open
MartinSJRogers opened this issue Oct 11, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@MartinSJRogers
Collaborator

MartinSJRogers commented Oct 11, 2024

Description

The `TaskLoader` seems to freeze when asked to generate patched tasks using the gapfill sampling strategy.

Reproduction steps

This first example runs fine:

```python
from deepsensor.data import TaskLoader
from tqdm import tqdm

# modis_ds, amsr_ds, dates and progress are defined earlier in the script
train_tasks = []

# Instantiate task loader
task_loader = TaskLoader(
    context=[modis_ds, amsr_ds] * 5,
    target=modis_ds,
    context_delta_t=[-2, -2, -1, -1, 0, 0, 1, 1, 2, 2],
    target_delta_t=0,
    # links=[(4, 0)],
)
for date in tqdm(dates, disable=not progress):
    # patch_strategy, patch_size and stride come from the Patchwise_train fork
    tasks_per_date = task_loader(
        date,
        context_sampling=["all", "all", "all", "all", "all",
                          "all", "all", "all", "all", "all"],
        target_sampling="all",
        patch_strategy="sliding",
        patch_size=0.5,
        stride=0.25,
    )
    for task in tasks_per_date:
        task.remove_context_nans().remove_target_nans()
    train_tasks.extend(tasks_per_date)
```

But if you run this example, which links context set 4 to target set 0 and uses gapfill sampling, nothing happens (the code appears to hang):

```python
# Same modis_ds, amsr_ds, dates, progress and train_tasks as above
# Instantiate task loader
task_loader = TaskLoader(
    context=[modis_ds, amsr_ds] * 5,
    target=modis_ds,
    context_delta_t=[-2, -2, -1, -1, 0, 0, 1, 1, 2, 2],
    target_delta_t=0,
    links=[(4, 0)],  # link context set 4 (modis_ds) to target set 0 (modis_ds)
)
# Code runs down to here and then stops:
for date in tqdm(dates, disable=not progress):
    tasks_per_date = task_loader(
        date,
        context_sampling=["all", "all", "all", "all", "gapfill",
                          "all", "all", "all", "all", "all"],
        target_sampling="gapfill",
        patch_strategy="sliding",
        patch_size=0.5,
        stride=0.25,
    )
    for task in tasks_per_date:
        task.remove_context_nans().remove_target_nans()
    train_tasks.extend(tasks_per_date)
```

Version

Patchwise_train fork, monotonic errors branch


OS

Windows

MartinSJRogers added the bug (Something isn't working) label on Oct 11, 2024
davidwilby self-assigned this on Oct 11, 2024
@davidwilby
Owner

Could you send your code (privately) so I can check whether this is related to the datasets used? I haven't been able to replicate this behaviour so far, though I expect I'm using a different dataset.

Have you run this code with a debugger? I'd like to check: (1) is it really hanging indefinitely, or just taking a long time; and (2) where is the freezing happening? Is it actually in `TaskLoader.__call__`?
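For question (2), one option that needs no debugger is Python's standard-library `faulthandler` module: `faulthandler.dump_traceback_later` writes the traceback of every thread once a timeout expires, which shows the line a long-running call is currently stuck on. A minimal sketch follows; `call_with_hang_report` is a hypothetical helper, not part of DeepSensor:

```python
import faulthandler
import sys


def call_with_hang_report(fn, timeout_s, *args, report_file=None, **kwargs):
    """Call fn(*args, **kwargs). While it is still running, write the traceback
    of every thread to report_file (stderr by default) every timeout_s seconds."""
    if report_file is None:
        report_file = sys.stderr
    # Start a watchdog thread that dumps all threads' tracebacks on each timeout
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=report_file)
    try:
        return fn(*args, **kwargs)
    finally:
        # Stop the watchdog once fn returns or raises
        faulthandler.cancel_dump_traceback_later()
```

For example, `call_with_hang_report(task_loader, 120, date, **sampling_kwargs)`, where `sampling_kwargs` is a hypothetical dict holding the keyword arguments from the reproduction, would print a traceback every 120 seconds for as long as the call runs; the timeout is an arbitrary choice. `faulthandler` writes to a real file descriptor, so in environments where `sys.stderr` is not backed by one, passing an open log file as `report_file` may be needed.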

@MartinSJRogers
Collaborator Author

Thanks @davidwilby, I'll send a private gist now. I haven't run this code with a debugger; I've just used my trusted print statements all over the place again. The print statements show that the code hangs indefinitely (or takes a very long time) while generating tasks, but I guess a debugger is needed to identify where in the DeepSensor codebase it is hanging?
In the gist I've sent you, everything runs fine if patching is set to False.

@davidwilby
Owner

Running your example, @MartinSJRogers, I've tracked the hanging behaviour down to this section of `deepsensor.data.TaskLoader.task_generation`:

```python
# Keep trying until we get a target set with at least one target point
keep_searching = True
while keep_searching:
    added_mask_date = rng.choice(self.context[context_idx].time)
    added_mask = (
        self.context[context_idx].sel(time=added_mask_date).isnull()
    )
    curr_mask = context_var.isnull()

    # Mask out added missing values
    context_var = context_var.where(~added_mask)

    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
    # when we could just slice the target values here
    target_mask = added_mask & ~curr_mask
    if isinstance(target_var, xr.Dataset):
        keep_searching = np.all(target_mask.to_array().data == False)
    else:
        keep_searching = np.all(target_mask.data == False)
    if keep_searching:
        continue  # No target points -- use a different `added_mask`

    target_var = target_var.where(
        target_mask
    )  # Only keep target locations

    context_slices[context_idx] = context_var
    target_slices[target_idx] = target_var
```

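In this section, I think `keep_searching` is never set to `False` for some or all of your patches, so this `while` loop runs indefinitely. That happens when, within a patch, no randomly drawn `added_mask_date` has a missing value at a location that is observed on the current date, so `target_mask` is all `False` on every attempt: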

```python
target_mask = added_mask & ~curr_mask
if isinstance(target_var, xr.Dataset):
    keep_searching = np.all(target_mask.to_array().data == False)
else:
    keep_searching = np.all(target_mask.data == False)
if keep_searching:
    continue  # No target points -- use a different `added_mask`
```
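The failure condition can be reproduced without DeepSensor. Below is a simplified NumPy re-creation of the quoted loop; it is an assumption, not the library code: it uses plain boolean arrays instead of xarray objects, and adds a hypothetical `max_attempts` cap so it can report the problem instead of hanging. If every date's gaps fall outside a patch, `target_mask` is all `False` on every draw, which is exactly when the uncapped loop never exits:

```python
import numpy as np


def gapfill_target_mask(nan_masks, current_nan_mask, rng, max_attempts=1000):
    """Toy version of the gapfill search loop quoted above, with an attempt cap.

    nan_masks: bool array (time, y, x), True where the context variable is missing.
    current_nan_mask: bool array (y, x), True where the current date is missing.
    Returns the target mask, or None if no drawn date gives a target point.
    """
    for _ in range(max_attempts):
        # Draw a random date and borrow its missing-value pattern
        added_mask = nan_masks[rng.integers(len(nan_masks))]
        # Targets: points hidden by the added mask but observed on the current date
        target_mask = added_mask & ~current_nan_mask
        if target_mask.any():
            return target_mask
    # The quoted `while` loop has no cap, so this case hangs instead
    return None


rng = np.random.default_rng(0)
# Every date is missing only the top-left 4x4 corner of an 8x8 grid
nan_masks = np.zeros((5, 8, 8), dtype=bool)
nan_masks[:, :4, :4] = True
current = np.zeros((8, 8), dtype=bool)

print(gapfill_target_mask(nan_masks, current, rng).sum())  # 16
print(gapfill_target_mask(nan_masks[:, 4:, 4:], current[4:, 4:], rng))  # None
```

Over the full grid every draw yields the 16 top-left target points, while the bottom-right patch never yields any.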

@davidwilby
Owner

Resolved by #11, which moves the spatial slicing operation after the gapfill operation.

This will also be addressed by #12, but at a later date.
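The ordering change described for #11 can be illustrated with toy NumPy arrays. This is a hypothetical sketch of the idea, not the code from #11: the gapfill search runs over the full domain, where gaps exist, and the patch is cut out afterwards, so a patch without gaps no longer stalls the search. The uncapped search would still loop forever if no date had a usable gap anywhere in the domain, just as the non-patched path would.

```python
import numpy as np


def draw_target_mask(nan_masks, current_nan_mask, rng):
    """Uncapped gapfill search, mirroring the quoted loop (toy NumPy version)."""
    while True:
        added_mask = nan_masks[rng.integers(len(nan_masks))]
        target_mask = added_mask & ~current_nan_mask
        if target_mask.any():
            return target_mask


def gapfill_then_slice(nan_masks, current_nan_mask, patch, rng):
    """Hypothetical helper: run the search on the full grid, then cut out `patch`."""
    target_mask = draw_target_mask(nan_masks, current_nan_mask, rng)
    # The patch may end up with zero target points, but the search has terminated
    return target_mask[patch]


rng = np.random.default_rng(0)
nan_masks = np.zeros((5, 8, 8), dtype=bool)
nan_masks[:, :4, :4] = True  # gaps only in the top-left corner
current = np.zeros((8, 8), dtype=bool)

bottom_right = (slice(4, None), slice(4, None))
print(gapfill_then_slice(nan_masks, current, bottom_right, rng).any())  # False
```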
