diff --git a/dreem/datasets/base_dataset.py b/dreem/datasets/base_dataset.py index 53bea40..901cb38 100644 --- a/dreem/datasets/base_dataset.py +++ b/dreem/datasets/base_dataset.py @@ -6,6 +6,9 @@ from typing import Union import numpy as np import torch +import logging + +logger = logging.getLogger("dreem.datasets") class BaseDataset(Dataset): @@ -110,6 +113,23 @@ def create_chunks(self) -> None: self.label_idx = [self.label_idx[i] for i in sample_idx] + # workaround for empty batch bug (needs to be changed). Check for batch with with only 1/10 size of clip length. Arbitrary thresholds + remove_idx = [] + for i, frame_chunk in enumerate(self.chunked_frame_idx): + if ( + len(frame_chunk) + <= min(int(self.clip_length / 10), 5) + # and frame_chunk[-1] % self.clip_length == 0 + ): + logger.warning( + f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading" + ) + remove_idx.append(i) + if len(remove_idx) > 0: + for i in sorted(remove_idx, reverse=True): + self.chunked_frame_idx.pop(i) + self.label_idx.pop(i) + else: self.chunked_frame_idx = self.frame_idx self.label_idx = [i for i in range(len(self.labels))] diff --git a/dreem/inference/post_processing.py b/dreem/inference/post_processing.py index 9b78a95..a29a178 100644 --- a/dreem/inference/post_processing.py +++ b/dreem/inference/post_processing.py @@ -160,7 +160,8 @@ def filter_max_center_dist( valid = dist.squeeze() < max_center_dist # n_k x n_nonk # handle case where id_inds and valid is a single value # handle this better - if valid.ndim == 0: valid = valid.unsqueeze(0) + if valid.ndim == 0: + valid = valid.unsqueeze(0) if valid.ndim == 1: if id_inds.shape[0] == 1: valid_mult = valid.float().unsqueeze(-1)