From 76da2370466b1c0d4250405ae48bc72e16f596f6 Mon Sep 17 00:00:00 2001 From: shaikh58 Date: Tue, 4 Feb 2025 07:20:45 +0000 Subject: [PATCH 1/2] - bug fix implemented --- dreem/datasets/sleap_dataset.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/dreem/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py index 7a98fa4..08cf2bd 100644 --- a/dreem/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -143,7 +143,11 @@ def __init__( # for label in self.labels: # label.remove_empty_instances(keep_empty_frames=False) - self.frame_idx = [torch.arange(len(labels)) for labels in self.labels] + # note if slp is missing frames, taking last frame idx is safer than len(labels) + # as there will be fewer labeledframes than actual frames + self.frame_idx = [torch.arange(labels[-1].frame_idx + 1) for labels in self.labels] + self.skipped_frame_ct = [0 for labels in self.labels] + # self.frame_idx = [torch.arange(len(labels)) for labels in self.labels] # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be # used in call to get_instances() self.create_chunks() @@ -168,7 +172,6 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram """ video = self.labels[label_idx] - video_name = self.video_files[label_idx] # get the correct crop size based on the video @@ -201,7 +204,16 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frame_ind = int(frame_ind) - lf = video[frame_ind] + # if slp is missing instances in some frames, frame_ind will be smaller than lf.frame_idx + lf = video[frame_ind - self.skipped_frame_ct[label_idx]] + if frame_ind < lf.frame_idx: + logger.warning( + f"Frame index {frame_ind} is trying to access frame {lf.frame_idx} of the slp file {video_name}. " + f"This likely means there are no labelled instances in this frame. Skipping frame." + ) + self.skipped_frame_ct[label_idx] += 1 + continue + try: img = vid_reader.get_data(int(lf.frame_idx)) From 2dfefe2a6c474b75efb33893706c83f1d491b28c Mon Sep 17 00:00:00 2001 From: shaikh58 Date: Tue, 4 Feb 2025 07:33:06 +0000 Subject: [PATCH 2/2] lint --- dreem/datasets/sleap_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dreem/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py index 08cf2bd..4e47107 100644 --- a/dreem/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -145,7 +145,9 @@ def __init__( # note if slp is missing frames, taking last frame idx is safer than len(labels) # as there will be fewer labeledframes than actual frames - self.frame_idx = [torch.arange(labels[-1].frame_idx + 1) for labels in self.labels] + self.frame_idx = [ + torch.arange(labels[-1].frame_idx + 1) for labels in self.labels + ] self.skipped_frame_ct = [0 for labels in self.labels] # self.frame_idx = [torch.arange(len(labels)) for labels in self.labels] # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be @@ -214,7 +216,6 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram self.skipped_frame_ct[label_idx] += 1 continue - try: img = vid_reader.get_data(int(lf.frame_idx)) except IndexError as e: