diff --git a/dreem/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py index 7a98fa4..4e47107 100644 --- a/dreem/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -143,7 +143,13 @@ def __init__( # for label in self.labels: # label.remove_empty_instances(keep_empty_frames=False) - self.frame_idx = [torch.arange(len(labels)) for labels in self.labels] + # note if slp is missing frames, taking last frame idx is safer than len(labels) + # as there will be fewer labeledframes than actual frames + self.frame_idx = [ + torch.arange(labels[-1].frame_idx + 1) for labels in self.labels + ] + self.skipped_frame_ct = [0 for labels in self.labels] + # self.frame_idx = [torch.arange(len(labels)) for labels in self.labels] # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be # used in call to get_instances() self.create_chunks() @@ -168,7 +174,6 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram """ video = self.labels[label_idx] - video_name = self.video_files[label_idx] # get the correct crop size based on the video @@ -201,7 +206,15 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frame_ind = int(frame_ind) - lf = video[frame_ind] + # if slp is missing instances in some frames, frame_ind will be smaller than lf.frame_idx + lf = video[frame_ind - self.skipped_frame_ct[label_idx]] + if frame_ind < lf.frame_idx: + logger.warning( + f"Frame index {frame_ind} is trying to access frame {lf.frame_idx} of the slp file {video_name}. " + f"This likely means there are no labelled instances in this frame. Skipping frame." + ) + self.skipped_frame_ct[label_idx] += 1 + continue try: img = vid_reader.get_data(int(lf.frame_idx))