From e4386d88d021797c3f7125512182f033343a1ce1 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 2 Nov 2023 10:44:19 -0700 Subject: [PATCH 01/40] fix issue tracking across gaps of no instances fix bug where tracks dont persist across chunks add verbosity for debugging(temp) --- biogtr/inference/tracker.py | 102 ++++++++++++++++++++++++++---------- 1 file changed, 74 insertions(+), 28 deletions(-) diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 01625873..eaefbe26 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -1,6 +1,7 @@ """Module containing logic for going from association -> assignment.""" import torch import pandas as pd +import warnings from biogtr.models import model_utils from biogtr.inference import post_processing from biogtr.inference.boxes import Boxes @@ -22,7 +23,9 @@ def __init__( decay_time: float = None, iou: str = None, max_center_dist: float = None, - persistent_tracking: bool = False + persistent_tracking: bool = False, + max_gap: int = -1, + verbose = False ): """Initialize a tracker to run inference. @@ -47,6 +50,14 @@ def __init__( self.iou = iou self.max_center_dist = max_center_dist self.persistent_tracking = persistent_tracking + self.verbose = verbose + + self.max_gap = max_gap + self.curr_gap = 0 + if self.max_gap >=0 and self.max_gap <= self.window_size: + self.max_gap = self.window_size + + self.id_count = 0 def __call__(self, model: GlobalTrackingTransformer, instances: list[dict], all_instances: list = None): """Wrapper around `track` to enable `tracker()` instead of `tracker.track()`. @@ -109,8 +120,9 @@ def track(self, model: GlobalTrackingTransformer, instances: list[dict], all_ins ) if not self.persistent_tracking: - # print(f'Clearing Queue after tracking') + if self.verbose: warnings.warn(f'Clearing Queue after tracking') self.track_queue.clear() + self.id_count = 0 return instances_pred @@ -151,46 +163,80 @@ def sliding_inference(self, model: GlobalTrackingTransformer, instances, window_ # W: width. video_len = len(instances) - id_count = 0 + id_count = self.id_count for batch_idx in range(video_len): - - if (self.persistent_tracking and instances[batch_idx]['frame_id'] == 0): + + if self.verbose: + warnings.warn(f"Current number of tracks is {id_count}") + + if (self.persistent_tracking and instances[batch_idx]['frame_id'] == 0): #check for new video and clear queue self.track_queue.clear() + self.id_count = 0 + ''' + Initialize tracks on first frame of video or first instance of detections. + ''' if len(self.track_queue) == 0 or sum([len(frame["pred_track_ids"]) for frame in self.track_queue]) == 0: - # print(f'Initializing track on batch {batch_idx} frame {instances[batch_idx]["frame_id"]}') + + if self.verbose: warnings.warn(f'Initializing track on batch {batch_idx} frame {instances[batch_idx]["frame_id"].item()}') + instances[batch_idx]["pred_track_ids"] = torch.arange( 0, len(instances[batch_idx]["bboxes"]) ) id_count = len(instances[batch_idx]["bboxes"]) - # print(f'Initial tracks are {instances[batch_idx]["pred_track_ids"]}') - self.track_queue.append(instances[batch_idx]) + + if self.verbose: warnings.warn(f'Initial tracks are {instances[batch_idx]["pred_track_ids"].cpu().tolist()}') + + if instances[batch_idx]['num_detected'] > 0: + + self.track_queue.append(instances[batch_idx]) + self.curr_gap = 0 + else: + self.curr_gap += 1 + if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {instances[batch_idx]['frame_id'].item()}. Skipping frame in queue. 
Current gap size: {self.curr_gap}") - else: - instances_to_track = (list(self.track_queue) + [instances[batch_idx]])[-window_size:] + else: - if sum([frame['num_detected'] for frame in instances_to_track]) == 0: - print("No detections to track!") + if instances[batch_idx]['num_detected'] == 0: #Check if there are detections. If there are skip and increment gap count instances[batch_idx]["pred_track_ids"] = torch.arange( - 0, len(instances[batch_idx]["bboxes"]) - ) + 0, len(instances[batch_idx]["bboxes"]) + ) + self.curr_gap += 1 + if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {instances[batch_idx]['frame_id'].item()}. Skipping frame in queue. Current gap size: {self.curr_gap}") + + + else: #detections found. Track and reset gap counter + self.curr_gap = 0 + + instances_to_track = (list(self.track_queue) + [instances[batch_idx]])[-window_size:] + + if len(self.track_queue) == self.track_queue.maxlen: + tracked_frame = self.track_queue.pop() + tracked_frame["tracked"] = True + self.track_queue.append(instances[batch_idx]) - continue + + query_ind = min(window_size - 1, len(instances_to_track) - 1) + + instances[batch_idx], id_count = self._run_global_tracker( + model, + instances_to_track, + query_frame=query_ind, + id_count=id_count, + overlap_thresh=self.overlap_thresh, + mult_thresh=self.mult_thresh, + ) - query_ind = min(window_size - 1, len(instances_to_track) - 1) + if self.curr_gap == self.max_gap: #Check if we've reached the max gap size and reset tracks. - instances[batch_idx], id_count = self._run_global_tracker( - model, - instances_to_track, - query_frame=query_ind, - id_count=id_count, - overlap_thresh=self.overlap_thresh, - mult_thresh=self.mult_thresh, - ) + if self.verbose: warnings.warn(f"Number of consecutive frames with missing detections has exceeded threshold of {self.max_gap}!") + + self.track_queue.clear() + self.curr_gap = 0 """ # If first frame. @@ -214,10 +260,10 @@ def sliding_inference(self, model: GlobalTrackingTransformer, instances, window_ # instances[frame_id - window_size]["features"] = None # TODO: Insert postprocessing. 
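+        # Persist the running track-id count on the tracker instance so that, when
+        # `persistent_tracking` is enabled, ids assigned in the next chunk continue
+        # from this count instead of restarting at 0.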
- - for frame in instances[:len(instances)-window_size]: - frame["features"] = frame["features"].cpu() - + # for frame in instances: + # if "tracked" in frame.keys(): + # frame['features'] = frame['features'].cpu() + self.id_count = id_count return instances def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query_frame, id_count, overlap_thresh, mult_thresh): From 91d8b4154947d989680df993ab6d1551a51625e7 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 2 Nov 2023 10:46:03 -0700 Subject: [PATCH 02/40] add notebooks folder to gitignore update base train yaml --- biogtr/training/configs/base.yaml | 33 ++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/biogtr/training/configs/base.yaml b/biogtr/training/configs/base.yaml index 5088b1c8..f7069f40 100644 --- a/biogtr/training/configs/base.yaml +++ b/biogtr/training/configs/base.yaml @@ -55,30 +55,35 @@ tracker: max_center_dist: null runner: - train_metrics: [""] - val_metrics: ["sw_cnt"] - test_metrics: ["sw_cnt"] - + metrics: + train: ['num_switches'] + val: ['num_switches'] + test: ['num_switches'] + persistent_tracking: + train: false + val: true + test: true + dataset: train_dataset: - slp_files: ['190612_110405_wt_18159111_rig2.2@11730.slp'] - video_files: ['190612_110405_wt_18159111_rig2.2@11730.mp4'] + slp_files: ["../../tests/data/sleap/two_flies.slp"] + video_files: ["../../tests/data/sleap/two_flies.mp4"] padding: 5 crop_size: 128 chunk: true clip_length: 32 val_dataset: - slp_files: ['190612_110405_wt_18159111_rig2.2@11730.slp'] - video_files: ['190612_110405_wt_18159111_rig2.2@11730.mp4'] + slp_files: ["../../tests/data/sleap/two_flies.slp"] + video_files: ["../../tests/data/sleap/two_flies.mp4"] padding: 5 crop_size: 128 chunk: True clip_length: 32 test_dataset: - slp_files: ['190612_110405_wt_18159111_rig2.2@11730.slp'] - video_files: ['190612_110405_wt_18159111_rig2.2@11730.mp4'] + slp_files: ["../../tests/data/sleap/two_flies.slp"] + video_files: ["../../tests/data/sleap/two_flies.mp4"] padding: 5 crop_size: 128 chunk: True @@ -96,6 +101,7 @@ dataloader: num_workers: 0 logging: + logger_type: null name: "example_train" entity: null job_type: "train" @@ -116,7 +122,7 @@ early_stopping: divergence_threshold: 30 checkpointing: - monitor: ["val_loss","val_sw_cnt"] + monitor: ["val_loss","val_num_switches"] verbose: true save_last: true dirpath: null @@ -133,3 +139,8 @@ trainer: log_every_n_steps: 1 max_epochs: 100 min_epochs: 10 + +view_batch: + enable: False + num_frames: 0 + no_train: False From 51200a0c860610d66cd39eb77645e2b036f2c0d1 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Wed, 8 Nov 2023 17:35:20 -0800 Subject: [PATCH 03/40] add `notebooks` folder to git ignore(temp) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b24819fd..1e8f6ba9 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +notebooks/ # IPython profile_default/ From b37d1999052eb653db90f2bd28b963172540c3b0 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Wed, 8 Nov 2023 17:36:15 -0800 Subject: [PATCH 04/40] Add `Frame` and `Instance` objects for data manipulation --- biogtr/data_structures.py | 708 +++++++++++++++++++ biogtr/datasets/base_dataset.py | 16 +- biogtr/datasets/cell_tracking_dataset.py | 63 +- biogtr/datasets/eval_dataset.py | 40 +- biogtr/datasets/microscopy_dataset.py | 65 +- biogtr/datasets/sleap_dataset.py | 86 +-- biogtr/inference/metrics.py | 75 +- biogtr/inference/track.py | 27 
+- biogtr/inference/tracker.py | 194 ++--- biogtr/models/global_tracking_transformer.py | 31 +- biogtr/models/model_utils.py | 9 +- biogtr/models/transformer.py | 36 +- tests/test_datasets.py | 32 +- tests/test_inference.py | 99 ++- tests/test_models.py | 62 +- 15 files changed, 1054 insertions(+), 489 deletions(-) create mode 100644 biogtr/data_structures.py diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py new file mode 100644 index 00000000..c104b71e --- /dev/null +++ b/biogtr/data_structures.py @@ -0,0 +1,708 @@ +"""Module containing data classes such as Instances and Frames.""" +import torch +from numpy.typing import ArrayLike +from typing import Union, List + +class Instance: + """Class representing a single instance to be tracked.""" + + def __init__( + self, + gt_track_id: int = None, + pred_track_id: int = -1, + bbox: ArrayLike = torch.empty((0, 4)), + crop: ArrayLike = torch.tensor([]), + features: ArrayLike = torch.tensor([]), + device: str = None, + ): + """Initialize Instance. + + Args: + gt_track_id: Ground truth track id - only used for train/eval. + pred_track_id: Predicted track id. Untracked instance is represented by -1. + bbox: The bounding box coordinate of the instance. Defaults to an empty tensor. + crop: The crop of the instance. + features: The reid features extracted from the CNN backbone used in the transformer. + device: String representation of the device the instance should be on. + """ + if gt_track_id is not None: + self._gt_track_id = torch.tensor([gt_track_id]) + else: + self._gt_track_id = torch.tensor([]) + + if pred_track_id is not None: + self._pred_track_id = torch.tensor([pred_track_id]) + else: + self._pred_track_id = torch.tensor([]) + + if not isinstance(bbox, torch.Tensor): + self._bbox = torch.tensor(bbox) + else: + self._bbox = bbox + + if self._bbox.shape[0] and len(self._bbox.shape) == 1: + self._bbox = self._bbox.unsqueeze(0) + + if not isinstance(crop, torch.Tensor): + self._crop = torch.tensor(crop) + else: + self._crop = crop + + if len(self._crop.shape) == 2: + self._crop = self._crop.unsqueeze(0).unsqueeze(0) + elif len(self._crop.shape) == 3: + self._crop = self._crop.unsqueeze(0) + + if not isinstance(crop, torch.Tensor): + self._features = torch.tensor(features) + else: + self._features = features + + if self._features.shape[0] and len(self._features.shape) == 1: + self._features = self._features.unsqueeze(0) + + self._device = device + self.to(self._device) + + def __repr__(self) -> str: + """String representation of the Instance.""" + return ( + "Instance(" + f"gt_track_id={self._gt_track_id.item()}, " + f"pred_track_id={self._pred_track_id.item()}, " + f"bbox={self._bbox}, " + f"crop={self._crop.shape}, " + f"features={self._features.shape}, " + f"device={self._device}" + ")" + ) + + def to(self, map_location): + """Move instance to different device or change dtype. (See `torch.to` for more info) + + Args: + map_location: Either the device or dtype for the instance to be moved. + Returns: + self: reference to the instance moved to correct device/dtype. + """ + self._gt_track_id = self._gt_track_id.to(map_location) + self._pred_track_id = self._pred_track_id.to(map_location) + self._bbox = self._bbox.to(map_location) + self._crop = self._crop.to(map_location) + self._features = self._features.to(map_location) + self.device = map_location + return self + + @property + def device(self) -> str: + """The device the instance is on. + + Returns: + The str representation of the device the gpu is on. 
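+
+        Example: `"cuda"`, `"cpu"`, or `None` if the instance has not been moved to a device.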
+ """ + return self._device + + @device.setter + def device(self, device) -> None: + """Setter for the device property. + + Args: + device: The str representation of the device. + """ + self._device = device + + @property + def gt_track_id(self) -> torch.Tensor: + """The ground truth track id of the instance. + + Returns: + A tensor containing the ground truth track id + """ + return self._gt_track_id + + @gt_track_id.setter + def gt_track_id(self, track: int): + """Function to set the instance ground-truth track id. + + Args: + track: An int representing the ground-truth track id. + """ + if track is not None: + self._gt_track_id = torch.tensor([track]) + else: + self._gt_track_id = torch.tensor([]) + + def has_gt_track_id(self) -> bool: + """Function for determining if instance has a gt track assignment. + + Returns: + True if the gt track id is set, otherwise False. + """ + if self._gt_track_id.shape[0] == 0: + return False + else: + return True + + @property + def pred_track_id(self) -> torch.Tensor: + """The track id predicted by the tracker using asso_output from model. + + Returns: + A tensor containing the predicted track id. + """ + return self._pred_track_id + + @pred_track_id.setter + def pred_track_id(self, track: int) -> None: + """Function to set predicted track id. + + Args: + track: an int representing the predicted track id. + """ + if track is not None: + self._pred_track_id = torch.tensor([track]) + else: + self._pred_track_id = torch.tensor([]) + + def has_pred_track_id(self) -> bool: + """Function to determine whether instance has predicted track id + + Returns: + True if instance has a pred track id, False otherwise. + Note that `-1` represents no assigned pred_track_id while + `[]` represents assigned track id of empty instance. + """ + if self._pred_track_id == -1: + return False + else: + return True + + @property + def bbox(self) -> torch.Tensor: + """The bounding box coordinates of the instance in the original frame + + Returns: + A (1,4) tensor containing the bounding box coordinates. + """ + return self._bbox + + @bbox.setter + def bbox(self, bbox: ArrayLike) -> None: + """Function to set the instance bounding box. + + Args: + bbox: an arraylike object containing the bounding box coordinates. + """ + if bbox is None or len(bbox) == 0: + self._bbox = torch.empty((0, 4)) + else: + if not isinstance(bbox, torch.Tensor): + self._bbox = torch.tensor(bbox) + else: + self._bbox = bbox + + if self._bbox.shape[0] and len(self._bbox.shape) == 1: + self._bbox = self._bbox.unsqueeze(0) + + def has_bbox(self) -> bool: + """Function for determining if the instance has a bbox. + + Returns: + True if the instance has a bounding box, false otherwise. + """ + if self._bbox.shape[0] == 0: + return False + else: + return True + + @property + def crop(self) -> torch.Tensor: + """The crop of the instance. + + Returns: + A (1, c, h , w) tensor containing the cropped image centered around the instance. + """ + return self._crop + + @crop.setter + def crop(self, crop: ArrayLike) -> None: + """Function to set the crop of the instance. + + Args: + an arraylike object containing the cropped image of the centered instance. 
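+
+        Note (shape handling, per the reshaping below): a 2D (h, w) crop is stored
+        as shape (1, 1, h, w) and a 3D (c, h, w) crop as shape (1, c, h, w).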
+ """ + if crop is None or len(crop) == 0: + self._crop = torch.tensor([]) + else: + if not isinstance(crop, torch.Tensor): + self._crop = torch.tensor(crop) + else: + self._crop = crop + + if len(self._crop.shape) == 2: + self._crop = self._crop.unsqueeze(0).unsqueeze(0) + elif len(self._crop.shape) == 3: + self._crop = self._crop.unsqueeze(0) + + def has_crop(self) -> bool: + """Function to determine if the instance has a crop. + + Returns: + True if the instance has an image otherwise False. + """ + if self._crop.shape[0] == 0: + return False + else: + return True + + @property + def features(self) -> torch.Tensor: + """ReID feature vector from backbone model to be used as input to transformer. + + Returns: + a (1, d) tensor containing the reid feature vector. + """ + return self._features + + @features.setter + def features(self, features: ArrayLike) -> None: + """Function to set the reid feature vector of the instance. + + Args: + features: a (1,d) array like object containing the reid features for the instance. + """ + if features is None or len(features) == 0: + self._features = torch.tensor([]) + + if not isinstance(features, torch.Tensor): + self._features = torch.tensor(features) + else: + self._features = features + + if self._features.shape[0] and len(self._features.shape) == 1: + self._features = self._features.unsqueeze(0) + + def has_features(self) -> bool: + """Function for determining if the instance has computed reid features. + + Returns: + True if the instance has reid features, False otherwise. + """ + if self._features.shape[0] == 0: + return False + else: + return True + + +class Frame: + """Data structure containing metadata for a single frame of a video.""" + def __init__( + self, + video_id: int, + frame_id: int, + img_shape: ArrayLike = [0, 0, 0], + instances: List[Instance] = [], + asso_output: ArrayLike = None, + matches: tuple = None, + traj_score: Union[ArrayLike, dict] = None, + device=None, + ): + """Initialize Frame. + + Args: + video_id: The video index in the dataset. + frame_id: The index of the frame in a video. + img_shape: The shape of the original frame (not the crop). + instances: A list of Instance objects that appear in the frame. + asso_output: The association matrix between instances + output directly from the transformer. + matches: matches from LSA algorithm between the instances and + available trajectories during tracking. + traj_score: Either a dict containing the association matrix + between instances and trajectories along postprocessing pipeline + or a single association matrix. + device: The device the frame should be moved to. + """ + self._video_id = torch.tensor([video_id]) + self._frame_id = torch.tensor([frame_id]) + + if isinstance(img_shape, torch.Tensor): + self._img_shape = img_shape + else: + self._img_shape = torch.tensor([img_shape]) + + self._instances = instances + + self._asso_output = asso_output + self._matches = matches + + if isinstance(traj_score, dict): + self._traj_score = traj_score + else: + self._traj_score = {"initial": traj_score} + + self._device = device + self.to(device) + + def __repr__(self) -> str: + """String representation of the Frame. + + Returns: + The string representation of the frame. 
+ """ + return ( + "Frame(" + f"video_id={self._video_id.item()}, " + f"frame_id={self._frame_id.item()}, " + f"img_shape={self._img_shape}, " + f"num_detected={self.num_detected}, " + f"asso_output={self._asso_output}, " + f"traj_score={self._traj_score}, " + f"matches={self._matches}, " + f"instances={self._instances}, " + f"device={self._device}" + ")" + ) + + def to(self, map_location: str): + """Function for moving frame to different device or dtype (See `torch.to` for more info). + + Args: + map_location: A string representing the device to move to. + Returns: + The frame moved to a different device/dtype. + """ + self._video_id = self._video_id.to(map_location) + self._frame_id = self._frame_id.to(map_location) + self._img_shape = self._img_shape.to(map_location) + + if isinstance(self._asso_output, torch.Tensor): + self._asso_output = asso_output.to(map_location) + + if isinstance(self._matches, torch.Tensor): + self._matches = matches.to(map_location) + + for key, val in self._traj_score.items(): + if isinstance(val, torch.Tensor): + self._traj_score[key] = val.to(map_location) + + for instance in self._instances: + instance = instance.to(map_location) + + self._device = map_location + return self + + @property + def device(self) -> str: + """The device the frame is on. + + Returns: + The string representation of the device the frame is on. + """ + return self._device + + @device.setter + def device(self, device: str) -> None: + """Function to set the device. + + Note: Do not set `frame.device = device` normally. Use `frame.to(device)` instead. + + Args: + device: the device the function should be on. + """ + self._device = device + + + @property + def video_id(self) -> torch.Tensor: + """ + The index of the video the frame comes from. + + Returns: + A tensor containing the video index. + """ + return self._video_id + + @video_id.setter + def video_id(self, video_id: int) -> None: + """Function for setting the video index. + + Note: Generally the video_id should be immutable after initialization. + + Args: + video_id: an int representing the index of the video that the frame came from. + """ + self._video_id = torch.tensor([video_id]) + + @property + def frame_id(self) -> torch.Tensor: + """The index of the frame in a full video. + + Returns: + A torch tensor containing the index of the frame in the video. + """ + return self._frame_id + + @frame_id.setter + def frame_id(self, frame_id: int) -> None: + """Function for setting the frame index of the frame. + + Note: The frame_id should generally be immutable after initialization. + + Args: + frame_id: The int index of the frame in the full video. + """ + self._frame_id = torch.tensor([frame_id]) + + @property + def img_shape(self) -> torch.Tensor: + """The shape of the pre-cropped frame. + + Returns: + A torch tensor containing the shape of the frame. Should generally be (c, h, w) + """ + return self._img_shape + + @img_shape.setter + def img_shape(self, img_shape: ArrayLike) -> None: + """Function for setting the shape of the frame image + + Note: the img_shape should generally be immutable after initialization. + + Args: + img_shape: an ArrayLike object containing the shape of the frame image. + """ + if isinstance(img_shape, torch.Tensor): + self._img_shape = img_shape + else: + self._img_shape = torch.tensor([img_shape]) + + @property + def instances(self) -> List[Instance]: + """A list of instances in the frame. + + Returns: + The list of instances that appear in the frame. 
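+
+        Example (hypothetical values, for illustration only):
+            >>> inst = Instance(gt_track_id=0, bbox=[10.0, 20.0, 60.0, 80.0])
+            >>> frame = Frame(video_id=0, frame_id=0, img_shape=[1, 1080, 1920], instances=[inst])
+            >>> len(frame.instances)
+            1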
+ """ + return self._instances + + @instances.setter + def instances(self, instances: List[Instance]) -> None: + """Function for setting the frame's instance + + Args: + instances: A list of Instances that appear in the frame. + """ + self._instances = instances + + def has_instances(self) -> bool: + """Function for determining whether there are instances in the frame. + + Returns: + True if there are instances in the frame, otherwise False. + """ + if self.num_detected == 0: + return False + return True + + @property + def num_detected(self) -> int: + """The number of instances in the frame. + + Returns: + the number of instances in the frame. + """ + return len(self.instances) + + @property + def asso_output(self) -> ArrayLike: + """The association matrix between instances outputed directly by transformer. + + Returns: + An arraylike (n_query, n_nonquery) association matrix between instances. + """ + return self._asso_output + + def has_asso_output(self) -> bool: + """Function for determining whether the frame has an association matrix computed. + + Returns: + True if the frame has an association matrix otherwise, False. + """ + if self._asso_output is None or len(self._asso_output) == 0: + return False + return True + + @asso_output.setter + def asso_output(self, asso_output: ArrayLike) -> None: + """Function for setting the association matrix of a frame. + + Args: + asso_output: An arraylike (n_query, n_nonquery) association matrix between instances.""" + self._asso_output = asso_output + + @property + def matches(self) -> tuple: + """Matches between frame instances and availabel trajectories. + + Returns: + A tuple containing the instance idx and trajectory idx for the matched instance. + """ + return self._matches + + @matches.setter + def matches(self, matches: tuple) -> None: + """Function for setting the frame matches + + Args: + matches: A tuple containing the instance idx and trajectory idx for the matched instance. + """ + self._matches = matches + + def has_matches(self) -> bool: + """Function for whether or not matches have been computed for frame. + + Returns: + True if frame contains matches otherwise False. + """ + if self._matches is not None and len(self._matches) > 0: + return True + return False + + def get_traj_score(self, key=None) -> Union[dict, ArrayLike, None]: + """Dictionary containing association matrix between instances and + trajectories along postprocessing pipeline. + + Args: + key: The key of the trajectory score to be accessed. + Can be one of {None, 'initial', 'decay_time', 'max_center_dist', 'iou', 'final'} + Returns: + - dictionary containing all trajectory scores if key is None + - trajectory score associated with key + - None if the key is not found + """ + if key is None: + return self._traj_score + else: + try: + return self._traj_score[key] + except KeyError as e: + print("Could not access {key} traj_score due to {e}") + return None + + def add_traj_score(self, key, traj_score: ArrayLike) -> None: + """Function for adding trajectory score to dictionary + + Args: + key: key associated with traj score to be used in dictionary + traj_score: association matrix between instances and trajectories + """ + self._traj_score[key] = traj_score + + def has_traj_score(self) -> bool: + """Function for checking if any trajectory association matrix has been saved + + Returns: + True there is at least one association matrix otherwise, false. 
+ """ + if len(self._traj_score) == 0: + return False + return True + + def has_gt_track_ids(self) -> bool: + """Function to check if any of frames instances has a gt track id + + Returns: + True if at least 1 instance has a gt track id otherwise False. + """ + if self.has_instances(): + return any([instance.has_gt_track_id() for instance in self.instances]) + return False + + def get_gt_track_ids(self) -> torch.Tensor: + """Function to get the gt track ids of all instances in the frame + + Returns: + an (N,) shaped tensor with the gt track ids of each instance in the frame. + """ + if not self.has_instances(): + return torch.tensor([]) + return torch.cat([instance.gt_track_id for instance in self.instances]) + + def has_pred_track_ids(self) -> bool: + """Function to check if any of frames instances has a pred track id + + Returns: + True if at least 1 instance has a pred track id otherwise False. + """ + if self.has_instances(): + return any([instance.has_pred_track_id() for instance in self.instances]) + return False + + def get_pred_track_ids(self) -> torch.Tensor: + """Function to get the pred track ids of all instances in the frame + + Returns: + an (N,) shaped tensor with the pred track ids of each instance in the frame. + """ + if not self.has_instances(): + return torch.tensor([]) + return torch.cat([instance.pred_track_id for instance in self.instances]) + + def has_bboxes(self) -> bool: + """Function to check if any of frames instances has a bounding box + + Returns: + True if at least 1 instance has a bounding box otherwise False. + """ + if self.has_instances(): + return any([instance.has_bboxes() for instance in self.instances]) + return False + + def get_bboxes(self) -> torch.Tensor: + """Function to get the bounding boxes of all instances in the frame + + Returns: + an (N,4) shaped tensor with bounding boxes of each instance in the frame. + """ + if not self.has_instances(): + return torch.empty(0,4) + return torch.cat([instance.bbox for instance in self.instances], dim=0) + + def has_crops(self) -> bool: + """Function to check if any of frames instances has a crop + + Returns: + True if at least 1 instance has a crop otherwise False. + """ + if self.has_instances(): + return any([instance.has_crop() for instance in self.instances]) + return False + + def get_crops(self) -> torch.Tensor: + """Function to get the crops of all instances in the frame + + Returns: + an (N, C, H, W) shaped tensor with crops of each instance in the frame. + """ + if not self.has_instances(): + return torch.tensor([]) + return torch.cat([instance.crop for instance in self.instances], dim=0) + + def has_features(self): + """Function to check if any of frames instances has reid features already computed + + Returns: + True if at least 1 instance have reid features otherwise False. + """ + if self.has_instances(): + return any([instance.has_features() for instance in self.instances]) + return False + + def get_features(self): + """Function to get the reid feature vectors of all instances in the frame + + Returns: + an (N, D) shaped tensor with reid feature vectors of each instance in the frame. 
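+
+        Example (hypothetical shapes, e.g. two instances with 512-dim features):
+            >>> frame.get_features().shape
+            torch.Size([2, 512])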
+ """ + return torch.cat([instance.features for instance in self.instances], dim=0) diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index b63dab74..5fd4e432 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -1,5 +1,6 @@ """Module containing logic for loading datasets.""" from biogtr.datasets import data_utils +from biogtr.data_structures import Frame, Instance from torch.utils.data import Dataset from typing import List, Union import numpy as np @@ -61,7 +62,7 @@ def __init__( self.labels = None self.gt_list = None - def create_chunks(self): + def create_chunks(self) -> None: """Get indexing for data. Creates both indexes for selecting dataset (label_idx) and frame in @@ -98,7 +99,7 @@ def create_chunks(self): self.chunked_frame_idx = self.frame_idx self.label_idx = [i for i in range(len(self.labels))] - def __len__(self): + def __len__(self) -> int: """Get the size of the dataset. Returns: @@ -106,7 +107,7 @@ def __len__(self): """ return len(self.chunked_frame_idx) - def no_batching_fn(self, batch): + def no_batching_fn(self, batch) -> List[Frame]: """Collate function used to overwrite dataloader batching function. Args: @@ -117,7 +118,7 @@ def no_batching_fn(self, batch): """ return batch - def __getitem__(self, idx: int) -> List[dict]: + def __getitem__(self, idx: int) -> List[Frame]: """Get an element of the dataset. Args: @@ -125,10 +126,7 @@ def __getitem__(self, idx: int) -> List[dict]: or the frame. Returns: - A list of dicts where each dict corresponds a frame in the chunk and - each value is a `torch.Tensor`. Dict elements can be seen in - subclasses - + A list of `Frame`s in the chunk containing the metadata + instance features. """ label_idx, frame_idx = self.get_indices(idx) @@ -148,7 +146,7 @@ def get_indices(self, idx: int): raise NotImplementedError("Must be implemented in subclass") def get_instances(self, label_idx: List[int], frame_idx: List[int]): - """Builds instances dict given label and frame indices. + """Builds chunk of frames. This method should be implemented in any subclass of the BaseDataset. diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 00d0db85..605dcfeb 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -2,6 +2,7 @@ from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset +from biogtr.data_structures import Instance, Frame from scipy.ndimage import measurements from torch.utils.data import Dataset from torchvision.transforms import functional as tvf @@ -111,7 +112,7 @@ def get_indices(self, idx): """ return self.label_idx[idx], self.chunked_frame_idx[idx] - def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict]: + def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Frame]: """Get an element of the dataset. Args: @@ -119,34 +120,16 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict frame_idx: index of the frames Returns: - a list of dicts where each dict corresponds a frame in the chunk - and each value is a `torch.Tensor`. 
- - Dict Elements: - { - "video_id": The video being passed through the transformer, - "img_shape": the shape of each frame, - "frame_id": the specific frame in the entire video being used, - "num_detected": The number of objects in the frame, - "gt_track_ids": The ground truth labels, - "bboxes": The bounding boxes of each object, - "crops": The raw pixel crops, - "features": The feature vectors for each crop outputed by the - CNN encoder, - "pred_track_ids": The predicted trajectory labels from the - tracker, - "asso_output": the association matrix preprocessing, - "matches": the true positives from the model, - "traj_score": the association matrix post processing, - } + a list of Frame objects containing frame metadata and Instance Objects. + See `biogtr.data_structures` for more info. """ image = self.videos[label_idx] gt = self.labels[label_idx] - instances = [] + frames = [] for i in frame_idx: - gt_track_ids, centroids, bboxes, crops = [], [], [], [] + instances, gt_track_ids, centroids, bboxes = [], [], [], [] i = int(i) @@ -201,25 +184,17 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict img = torch.Tensor(img).unsqueeze(0) - for bbox in bboxes: - crop = data_utils.crop_bbox(img, bbox) - crops.append(crop) - - instances.append( - { - "video_id": torch.tensor([label_idx]), - "img_shape": torch.tensor([img.shape]), - "frame_id": torch.tensor([i]), - "num_detected": torch.tensor([len(bboxes)]), - "gt_track_ids": torch.tensor(gt_track_ids).type(torch.int64), - "bboxes": torch.stack(bboxes), - "crops": torch.stack(crops), - "features": torch.tensor([]), - "pred_track_ids": torch.tensor([-1 for _ in range(len(bboxes))]), - "asso_output": torch.tensor([]), - "matches": torch.tensor([]), - "traj_score": torch.tensor([]), - } - ) + for i in range(len(gt_track_ids)): + crop = data_utils.crop_bbox(img, bboxes[i]) + + instances.append(Instance(gt_track_id=gt_track_ids[i], + pred_track_id=-1, + bbox=bboxes[i], + crop=crop)) + + frames.append(Frame(video_id=label_idx, + frame_id=i, + img_shape=img.shape, + instances=instances)) - return instances + return frames diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index 870ec014..e021e9e6 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -1,15 +1,23 @@ "Module containing wrapper for merging gt and pred datasets for evaluation" import torch from torch.utils.data import Dataset +from biogtr.data_structures import Frame, Instance +from typing import List class EvalDataset(Dataset): - def __init__(self, gt_dataset: Dataset, pred_dataset: Dataset): + def __init__(self, gt_dataset: Dataset, pred_dataset: Dataset) -> None: + """Initialize EvalDataset + + Args: + gt_dataset: A Dataset object containing ground truth track ids + pred_dataset: A dataset object containing predicted track ids + """ self.gt_dataset = gt_dataset self.pred_dataset = pred_dataset - def __len__(self): + def __len__(self) -> int: """Get the size of the dataset. Returns: @@ -17,7 +25,7 @@ def __len__(self): """ return len(self.gt_dataset) - def __getitem__(self, idx: int): + def __getitem__(self, idx: int) -> List[Frame]: """Get an element of the dataset. Args: @@ -25,15 +33,21 @@ def __getitem__(self, idx: int): or the frame. Returns: - A list of dicts where each dict corresponds a frame in the chunk and - each value is a `torch.Tensor`. Dict elements are the video id, frame id, and gt/pred track ids - + A list of Frames where frames contain instances w gt and pred track ids + bboxes. 
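+            Each returned Instance takes its gt_track_id from `gt_dataset` and its
+            pred_track_id and bbox from the matching instance in `pred_dataset`.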
""" - labels = [{"video_id": gt_frame['video_id'], - "frame_id": gt_frame['video_id'], - "gt_track_ids": gt_frame['gt_track_ids'], - "pred_track_ids": pred_frame['gt_track_ids'], - "bboxes": pred_frame["bboxes"] - } for gt_frame, pred_frame in zip(self.gt_dataset[idx], self.pred_dataset[idx])] + gt_batch = self.gt_dataset[i] + pred_batch = self.pred_dataset[i] + + eval_frames = [] + for gt_frame, pred_frame in zip(gt_batch, pred_batch): + eval_instances = [] + for gt_instance, pred_instance in zip(gt_frame.instances, pred_frame.instances): + eval_instances.append(Instance(gt_track_id=gt_instance.gt_track_id, + pred_track_id=pred_instance.pred_track_id, + bbox=pred_instance.bbox)) + eval_frames.append(Frame(video_id=gt_frame.video_id, + frame_id=gt_frame.frame_id, + img_shape=gt_frame.img_shape, + instances=eval_instances)) - return labels \ No newline at end of file + return eval_frames \ No newline at end of file diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 0ec3d023..062b6bd7 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -2,6 +2,7 @@ from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset +from biogtr.data_structures import Frame, Instance from torch.utils.data import Dataset from torchvision.transforms import functional as tvf from typing import Union @@ -112,7 +113,7 @@ def get_indices(self, idx): """ return self.label_idx[idx], self.chunked_frame_idx[idx] - def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[dict]: + def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]: """Get an element of the dataset. Args: @@ -120,26 +121,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[dict frame_idx: index of the frames Returns: - a list of dicts where each dict corresponds a frame in the chunk - and each value is a `torch.Tensor`. 
- - Dict Elements: - { - "video_id": The video being passed through the transformer, - "img_shape": the shape of each frame, - "frame_id": the specific frame in the entire video being used, - "num_detected": The number of objects in the frame, - "gt_track_ids": The ground truth labels, - "bboxes": The bounding boxes of each object, - "crops": The raw pixel crops, - "features": The feature vectors for each crop outputed by the - CNN encoder, - "pred_track_ids": The predicted trajectory labels from the - tracker, - "asso_output": the association matrix preprocessing, - "matches": the true positives from the model, - "traj_score": the association matrix post processing, - } + A list of Frames containing Instances to be tracked (See `biogtr.data_structures for more info`) """ labels = self.labels[label_idx] labels = labels.dropna(how="all") @@ -149,10 +131,10 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[dict if type(video) != list: video = data_utils.LazyTiffStack(self.videos[label_idx]) - instances = [] - + + frames = [] for i in frame_idx: - gt_track_ids, centroids, bboxes, crops = [], [], [], [] + instances, gt_track_ids, centroids, bboxes, crops = [], [], [], [], [] img = ( video.get_section(i) @@ -191,31 +173,22 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[dict if img.shape[2] == 3: img = img.T # todo: check for edge cases - for c in centroids: + for i in range(len(gt_track_ids)): + c = centroids[i] bbox = data_utils.pad_bbox( data_utils.get_bbox([int(c[0]), int(c[1])], self.crop_size), padding=self.padding, ) crop = data_utils.crop_bbox(img, bbox) - bboxes.append(bbox) - crops.append(crop) - - instances.append( - { - "video_id": torch.tensor([label_idx]), - "img_shape": torch.tensor([img.shape]), - "frame_id": torch.tensor([i]), - "num_detected": torch.tensor([len(bboxes)]), - "gt_track_ids": torch.tensor(gt_track_ids).type(torch.int64), - "bboxes": torch.stack(bboxes), - "crops": torch.stack(crops), - "features": torch.tensor([]), - "pred_track_ids": torch.tensor([-1 for _ in range(len(bboxes))]), - "asso_output": torch.tensor([]), - "matches": torch.tensor([]), - "traj_score": torch.tensor([]), - } - ) - - return instances + instances.append(Instance(gt_track_id=gt_track_ids[i], + pred_track_id=-1, + bbox=bbox, + crop=crop)) + + frames.append(Frame(video_id=label_idx, + frame_id=i, + img_shape=img.shape, + instances=instances)) + + return frames diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 0a2cb33d..9474c5a6 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -5,6 +5,8 @@ import numpy as np import sleap_io as sio import random +import warnings +from biogtr.data_structures import Frame, Instance from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset from torchvision.transforms import functional as tvf @@ -27,6 +29,7 @@ def __init__( augmentations: dict = None, n_chunks: Union[int, float] = 1.0, seed: int = None, + verbose: bool = False ): """Initialize SleapDataset. @@ -51,6 +54,7 @@ def __init__( n_chunks: Number of chunks to subsample from. 
Can either a fraction of the dataset (ie (0,1.0]) or number of chunks seed: set a seed for reproducibility + verbose: boolean representing whether to print """ super().__init__( slp_files + video_files, @@ -73,7 +77,8 @@ def __init__( self.mode = mode self.n_chunks = n_chunks self.seed = seed - self.anchor = anchor + self.anchor = anchor.lower() + self.verbose=verbose # if self.seed is not None: # np.random.seed(self.seed) @@ -137,32 +142,23 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict img = vid_reader.get_data(0) crop_shape = (img.shape[-1], *(self.crop_size + 2 * self.padding,) * 2) - instances = [] - for i, frame in enumerate(frame_idx): - gt_track_ids, bboxes, crops, poses, shown_poses = [], [], [], [], [] + frames = [] + for i, frame_ind in enumerate(frame_idx): + instances, gt_track_ids, bboxes, crops, shown_poses = [], [], [], [], [] - frame = int(frame) + frame_ind = int(frame_ind) - lf = video[frame] + lf = video[frame_ind] try: - img = vid_reader.get_data(frame) + img = vid_reader.get_data(frame_ind) except IndexError as e: - print(f"Could not read frame {frame} from {video_name}") + print(f"Could not read frame {frame_ind} from {video_name}") continue for instance in lf: gt_track_ids.append(video.tracks.index(instance.track)) - poses.append( - dict( - zip( - [n.name for n in instance.skeleton.nodes], - np.array(instance.numpy()).tolist(), - ) - ) - ) - shown_poses.append( dict( zip( @@ -172,7 +168,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict ) ) - shown_poses = [{key: val for key, val in instance.items() + shown_poses = [{key.lower(): val for key, val in instance.items() if not np.isnan(val).any() } for instance in shown_poses] # augmentations @@ -209,15 +205,15 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict img = tvf.to_tensor(img) - for pose in shown_poses: + for i in range(len(gt_track_ids)): + + pose = shown_poses[i] + """Check for anchor""" if self.anchor in pose: anchor = self.anchor - elif self.anchor.lower() in pose: - anchor = self.anchor.lower() - elif self.anchor.upper() in pose: - anchor = self.anchor.upper() else: + if self.verbose: warnings.warn(f"{self.anchor} not in {[key for key in pose.keys()]}! 
Defaulting to midpoint") anchor = "midpoint" if anchor != "midpoint": @@ -248,31 +244,21 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict padding=self.padding, ) - crop = data_utils.crop_bbox(img, bbox) - - bboxes.append(bbox) - crops.append(crop) - - stacked_crops = ( - torch.stack(crops) if crops else torch.empty((0, *crop_shape)) - ) - - instances.append( - { - "video_id": torch.tensor([label_idx]), - "img_shape": torch.tensor([img.shape]), - "frame_id": torch.tensor([frame]), - "num_detected": torch.tensor([len(bboxes)]), - "gt_track_ids": torch.tensor(gt_track_ids), - "bboxes": torch.stack(bboxes) if bboxes else torch.empty((0, 4)), - "crops": stacked_crops, - "features": torch.tensor([]), - "pred_track_ids": torch.tensor([-1 for _ in range(len(bboxes))]), - "asso_output": torch.tensor([]), - "matches": torch.tensor([]), - "traj_score": torch.tensor([]), - } - ) - - return instances + + instance = Instance(gt_track_id=gt_track_ids[i], + pred_track_id=-1, + crop=crop, + bbox=bbox + ) + + instances.append(instance) + + frame = Frame(video_id=label_idx, + frame_id=frame_ind, + img_shape=img.shape, + instances=instances + ) + frames.append(frame) + + return frames diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index d8d6386d..1a5403a2 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -1,17 +1,19 @@ """Helper functions for calculating mot metrics.""" import numpy as np import motmetrics as mm +import torch +from biogtr.data_structures import Frame from biogtr.inference.post_processing import _pairwise_iou from biogtr.inference.boxes import Boxes from typing import Union, Iterable -def get_matches(instances: list[dict]) -> tuple[dict, list, int]: +def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: """Get comparison between predicted and gt trajectory labels. Args: - instances: a list of dicts where each dict corresponds to a frame and - contains the video_id, frame_id, gt labels and predicted labels + instances: a list of Frames containing the video_id, frame_id, + gt labels and predicted labels Returns: matches: a dict containing predicted and gt trajectory labels @@ -21,17 +23,15 @@ def get_matches(instances: list[dict]) -> tuple[dict, list, int]: matches = {} indices = [] - video_id = instances[0]["video_id"].item() + video_id = frames[0].video_id.item() - for idx, instance in enumerate(instances): - indices.append(instance["frame_id"].item()) - for i, gt_track_id in enumerate(instance["gt_track_ids"]): - gt_track_id = instance["gt_track_ids"][i] - pred_track_id = instance["pred_track_ids"][i] + for idx, frame in enumerate(frames): + indices.append(frame.frame_id.item()) + for gt_track_id, pred_track_id in zip(frame.get_gt_track_ids(), frame.get_pred_track_ids()): match = f"{gt_track_id} -> {pred_track_id}" if match not in matches: - matches[match] = np.full(len(instances), 0) + matches[match] = np.full(len(frames), 0) matches[match][idx] = 1 return matches, indices, video_id @@ -92,36 +92,15 @@ def get_switch_count(switches: dict) -> int: return sw_cnt -def to_track_eval(instances: list[dict]) -> dict: - """Reformats instances, the output from `sliding_inference` to be used by `TrackEval.` +def to_track_eval(frames: list[Frame]) -> dict: + """Reformats frames the output from `sliding_inference` to be used by `TrackEval.` Args: - instances: A list of dictionaries. One for each frame. An example is provided below. + instances: A list of Frames. 
`See biogtr.data_structures for more info.` Returns: data: A dictionary. Example provided below. - # ------------------------- An example of instances ------------------------ # - - D: embedding dimension. - N_i: number of detected instances in i-th frame of window. - - instances = [ - { - # Each dictionary is a frame. - - "frame_id": frame index int, - "num_detected": N_i, - "gt_track_ids": (N_i,), - "poses": (N_i, 13, 2), # 13 keypoints for the pose (x, y) coords. - "bboxes": (N_i, 4), # in pascal_voc unrounded unnormalized - "features": (N_i, D), # Features are deleted but can optionally be kept if need be. - "pred_track_ids": (N_i,), - }, - {}, # Frame 2. - ... - ] - # --------------------------- An example of data --------------------------- # *: number of ids for gt at every frame of the video @@ -147,30 +126,30 @@ def to_track_eval(instances: list[dict]) -> dict: similarity_scores = [] data = {} - #cos_sim = torch.nn.CosineSimilarity() + cos_sim = torch.nn.CosineSimilarity() - for fidx, instance in enumerate(instances): - gt_track_ids = instance["gt_track_ids"].cpu().numpy().tolist() - pred_track_ids = instance["pred_track_ids"].cpu().numpy().tolist() - boxes = Boxes(instance['bboxes'].cpu()) + for fidx, frame in enumerate(frames): + gt_track_ids = frame.get_gt_track_ids().cpu().numpy().tolist() + pred_track_ids = frame.get_pred_track_ids().cpu().numpy().tolist() + boxes = Boxes(frame.get_bboxes().cpu()) gt_ids.append(np.array(gt_track_ids)) track_ids.append(np.array(pred_track_ids)) - num_tracker_dets += len(instance["pred_track_ids"]) + num_tracker_dets += len(pred_track_ids) num_gt_dets += len(gt_track_ids) if not set(gt_track_ids).issubset(set(unique_gt_ids)): unique_gt_ids.extend(list(set(gt_track_ids).difference(set(unique_gt_ids)))) - eval_matrix = _pairwise_iou(boxes, boxes) -# eval_matrix = np.full((len(gt_track_ids), len(pred_track_ids)), np.nan) + #eval_matrix = _pairwise_iou(boxes, boxes) + eval_matrix = np.full((len(gt_track_ids), len(pred_track_ids)), np.nan) -# for i, feature_i in enumerate(features): -# for j, feature_j in enumerate(features): -# eval_matrix[i][j] = cos_sim( -# feature_i.unsqueeze(0), feature_j.unsqueeze(0) -# ) + for i, feature_i in enumerate(frame.get_features()): + for j, feature_j in enumerate(features): + eval_matrix[i][j] = cos_sim( + feature_i.unsqueeze(0), feature_j.unsqueeze(0) + ) # eval_matrix # pred_track_ids @@ -212,7 +191,7 @@ def to_track_eval(instances: list[dict]) -> dict: raise(e) data["tracker_ids"] = track_ids data["similarity_scores"] = similarity_scores - data["num_timesteps"] = len(instances) + data["num_timesteps"] = len(frames) return data diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py index 6c051e62..8ff07cf1 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -2,6 +2,7 @@ from biogtr.config import Config from biogtr.models.gtr_runner import GTRRunner +from biogtr.data_structures import Frame from biogtr.datasets.tracking_dataset import TrackingDataset from omegaconf import DictConfig from pprint import pprint @@ -17,20 +18,21 @@ torch.set_default_device(device) -def export_trajectories(instances_pred: list[dict], save_path: str = None): +def export_trajectories(frames_pred: list[Frame], save_path: str = None): save_dict = {} frame_ids = [] X, Y = [], [] pred_track_ids = [] - for frame in instances_pred: - for i in range(frame["num_detected"]): - frame_ids.append(frame["frame_id"].item()) - bbox = frame["bboxes"][i] + for frame in frames_pred: + for i, instance in 
range(frame.instances): + frame_ids.append(frame.frame_id.item()) + bbox = instance.bbox y = (bbox[2] + bbox[0]) / 2 x = (bbox[3] + bbox[1]) / 2 X.append(x.item()) Y.append(y.item()) - pred_track_ids.append(frame["pred_track_ids"][i].item()) + pred_track_ids.append(instance.pred_track_id.item()) + save_dict["Frame"] = frame_ids save_dict["X"] = X save_dict["Y"] = Y @@ -60,7 +62,7 @@ def inference( for batch in preds: for frame in batch: - vid_trajectories[frame["video_id"]].append(frame) + vid_trajectories[frame.video_id].append(frame) saved = [] @@ -72,16 +74,15 @@ def inference( X, Y = [], [] pred_track_ids = [] for frame in video: - for i in range(frame["num_detected"]): - video_ids.append(frame["video_id"].item()) - frame_ids.append(frame["frame_id"].item()) - bbox = frame["bboxes"][i] - + for i, instance in frame.instances: + video_ids.append(frame.video_id.item()) + frame_ids.append(frame.frame_id.item()) + bbox = instance.bbox y = (bbox[2] + bbox[0]) / 2 x = (bbox[3] + bbox[1]) / 2 X.append(x.item()) Y.append(y.item()) - pred_track_ids.append(frame["pred_track_ids"][i].item()) + pred_track_ids.append(instance.pred_track_id.item()) save_dict["Video"] = video_ids save_dict["Frame"] = frame_ids save_dict["X"] = X diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index eaefbe26..68f16715 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -2,6 +2,7 @@ import torch import pandas as pd import warnings +from biogtr.data_structures import Frame from biogtr.models import model_utils from biogtr.inference import post_processing from biogtr.inference.boxes import Boxes @@ -59,42 +60,39 @@ def __init__( self.id_count = 0 - def __call__(self, model: GlobalTrackingTransformer, instances: list[dict], all_instances: list = None): + def __call__(self, model: GlobalTrackingTransformer, frames: list[Frame]): """Wrapper around `track` to enable `tracker()` instead of `tracker.track()`. Args: model: the pretrained GlobalTrackingTransformer to be used for inference - instances: data dict to run inference on - all_instances: list of instances from previous chunks - to stitch together full trajectory + frames: list of Frames to run inference on Returns: - instances dict populated with pred track ids and association matrix scores + List of frames containing association matrix scores and instances populated with pred track ids. """ - return self.track(model, instances, all_instances) + return self.track(model, frames) - def track(self, model: GlobalTrackingTransformer, instances: list[dict], all_instances: list = None): + def track(self, model: GlobalTrackingTransformer, frames: list[dict]): """Run tracker and get predicted trajectories. Args: model: the pretrained GlobalTrackingTransformer to be used for inference - instances: data dict to run inference on - all_instances: list of instances from previous chunks to stitch together full trajectory + frames: data dict to run inference on Returns: - instances dict populated with pred track ids and association matrix scores + List of Frames populated with pred track ids and association matrix scores """ # Extract feature representations with pre-trained encoder. 
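+        # For each frame with instances: if visual features are disabled, give every
+        # instance a zero feature vector; otherwise, unless features were already
+        # computed, run the visual encoder on the frame's crops and assign each
+        # instance its embedding.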
_ = model.eval() - for frame in instances: - if (frame["num_detected"] > 0).item(): + for frame in frames: + if frame.has_instances(): if not self.use_vis_feats: - num_frame_instances = frame["crops"].shape[0] - frame["features"] = torch.zeros( - num_frame_instances, model.d_model - ) + for instance in frame.instances: + instance.features = torch.zeros( + 1, model.d_model + ) # frame["features"] = torch.randn( # num_frame_instances, self.model.d_model # ) @@ -102,10 +100,13 @@ def track(self, model: GlobalTrackingTransformer, instances: list[dict], all_ins # comment out to turn encoder off # Assuming the encoder is already trained or train encoder jointly. - elif 'features' not in frame or frame['features'] == None or len(frame['features']) == 0: + elif not frame.has_features(): with torch.no_grad(): - z = model.visual_encoder(frame["crops"]) - frame["features"] = z + crops = frame.get_crops() + z = model.visual_encoder(crops) + + for i, z_i in enumerate(z): + frame.instances[i].features = z_i # I feel like this chunk is unnecessary: # reid_features = torch.cat( @@ -116,7 +117,7 @@ def track(self, model: GlobalTrackingTransformer, instances: list[dict], all_ins # instances, reid_features # ) instances_pred = self.sliding_inference( - model, instances, window_size=self.window_size, all_instances=all_instances + model, frames, window_size=self.window_size ) if not self.persistent_tracking: @@ -126,35 +127,16 @@ def track(self, model: GlobalTrackingTransformer, instances: list[dict], all_ins return instances_pred - def sliding_inference(self, model: GlobalTrackingTransformer, instances, window_size, all_instances=None): + def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame], window_size: int): """Performs sliding inference on the input video (instances) with a given window size. Args: model: the pretrained GlobalTrackingTransformer to be used for inference - instances: A list of dictionaries, one dictionary for each frame. An example - is provided below. + frame: A list of Frames (See `biogtr.data_structures.Frame` for more info). window_size: An integer. Returns: - instances: A list of dictionaries, one dictionary for each frame. An example - is provided below. - # ------------------------- An example of instances ------------------------ # - D: embedding dimension. - N_i: number of detected instances in i-th frame of window. - instances = [ - { - # Each dictionary is a frame. - "frame_id": frame index int, - "num_detected": N_i, - "gt_track_ids": (N_i,), - "poses": (N_i, 13, 2), # 13 keypoints for the pose (x, y) coords. - "bboxes": (N_i, 4), # in pascal_voc unrounded unnormalized - "features": (N_i, D), # Features are deleted but can optionally be kept if need be. - "pred_track_ids": (N_i,), # Filled out after sliding_inference. - }, - {}, # Frame 2. - ... - ] + Frames: A list of Frames populated with pred_track_ids and asso_matrices """ # B: batch size. # D: embedding dimension. @@ -162,7 +144,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, instances, window_ # H: height. # W: width. 
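+        # Sliding-window bookkeeping: `track_queue` holds the most recently tracked
+        # frames. Frames with no detections are skipped and `curr_gap` is incremented;
+        # once `curr_gap` reaches `max_gap` the queue is cleared so the next detections
+        # start fresh tracks, and tracks are (re)initialized whenever the queue is empty.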
- video_len = len(instances) + video_len = len(frames) id_count = self.id_count for batch_idx in range(video_len): @@ -170,59 +152,58 @@ def sliding_inference(self, model: GlobalTrackingTransformer, instances, window_ if self.verbose: warnings.warn(f"Current number of tracks is {id_count}") - if (self.persistent_tracking and instances[batch_idx]['frame_id'] == 0): #check for new video and clear queue + if (self.persistent_tracking and frames[batch_idx].frame_id == 0): #check for new video and clear queue self.track_queue.clear() self.id_count = 0 ''' Initialize tracks on first frame of video or first instance of detections. ''' - if len(self.track_queue) == 0 or sum([len(frame["pred_track_ids"]) for frame in self.track_queue]) == 0: + if len(self.track_queue) == 0 or sum([len(frame.get_pred_track_ids()) for frame in self.track_queue]) == 0: - if self.verbose: warnings.warn(f'Initializing track on batch {batch_idx} frame {instances[batch_idx]["frame_id"].item()}') + if self.verbose: warnings.warn(f'Initializing track on batch {batch_idx} frame {frames[batch_idx].frame_id.item()}') - instances[batch_idx]["pred_track_ids"] = torch.arange( - 0, len(instances[batch_idx]["bboxes"]) - ) + for i, instance in enumerate(frames[batch_idx].instances): + instance.pred_track_id = i - id_count = len(instances[batch_idx]["bboxes"]) + id_count = frames[batch_idx].num_detected - if self.verbose: warnings.warn(f'Initial tracks are {instances[batch_idx]["pred_track_ids"].cpu().tolist()}') + if self.verbose: warnings.warn(f'Initial tracks are {frames[batch_idx].get_pred_track_ids().cpu().tolist()}') - if instances[batch_idx]['num_detected'] > 0: + if frames[batch_idx].has_instances(): - self.track_queue.append(instances[batch_idx]) + self.track_queue.append(frames[batch_idx]) self.curr_gap = 0 else: self.curr_gap += 1 - if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {instances[batch_idx]['frame_id'].item()}. Skipping frame in queue. Current gap size: {self.curr_gap}") + if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {frames[batch_idx].frame_id.item()}. Skipping frame in queue. Current gap size: {self.curr_gap}") else: - if instances[batch_idx]['num_detected'] == 0: #Check if there are detections. If there are skip and increment gap count + if not frames[batch_idx].has_instances(): #Check if there are detections. If there are skip and increment gap count - instances[batch_idx]["pred_track_ids"] = torch.arange( - 0, len(instances[batch_idx]["bboxes"]) - ) + for i,instance in enumerate(frames[batch_idx].instances): + instance.pred_track_id = i + self.curr_gap += 1 - if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {instances[batch_idx]['frame_id'].item()}. Skipping frame in queue. Current gap size: {self.curr_gap}") + if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {frames[batch_idx].frame_id.item()}. Skipping frame in queue. Current gap size: {self.curr_gap}") else: #detections found. 
Track and reset gap counter self.curr_gap = 0 - instances_to_track = (list(self.track_queue) + [instances[batch_idx]])[-window_size:] + instances_to_track = (list(self.track_queue) + [frames[batch_idx]])[-window_size:] if len(self.track_queue) == self.track_queue.maxlen: tracked_frame = self.track_queue.pop() tracked_frame["tracked"] = True - self.track_queue.append(instances[batch_idx]) + self.track_queue.append(frames[batch_idx]) query_ind = min(window_size - 1, len(instances_to_track) - 1) - instances[batch_idx], id_count = self._run_global_tracker( + frames[batch_idx], id_count = self._run_global_tracker( model, instances_to_track, query_frame=query_ind, @@ -264,17 +245,16 @@ def sliding_inference(self, model: GlobalTrackingTransformer, instances, window_ # if "tracked" in frame.keys(): # frame['features'] = frame['features'].cpu() self.id_count = id_count - return instances + return frames - def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query_frame, id_count, overlap_thresh, mult_thresh): + def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_frame, id_count, overlap_thresh, mult_thresh): """Run_global_tracker performs the actual tracking. Uses Hungarian algorithm to do track assigning. Args: model: the pretrained GlobalTrackingTransformer to be used for inference - instances: A list of dictionaries, one dictionary for each frame. An example - is provided below. + frames: A list of Frames containing reid features. See `biogtr.data_structures` for more info. query_frame: An integer for the query frame within the window of instances. id_count: The count of total identities so far. overlap_thresh: A float number between 0 and 1 specifying how much @@ -283,32 +263,11 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query This is not functional as of now. Returns: - instances: The exact list of dictionaries as before but with assigned track ids + frames: The exact list of frames as before but with assigned track ids and new track ids for the query frame. Refer to the example for the structure. id_count: An integer for the updated identity count so far. - # ------------------------- An example of instances ------------------------ # - NOTE: This instances variable is the window subset of the instances variable in sliding_inference. - *: each item in instances is a frame in the window. So it follows - that each frame in the window has * detected instances. - D: embedding dimension. - N_i: number of detected instances in i-th frame of window. - window_size: length of window. - The features in instances can be of shape (2 to window_size, *, D) when stacked together. - instances = [ - { - # Each dictionary is a frame. - "frame_id": frame index int, - "num_detected": N_i, - "gt_track_ids": (N_i,), - "poses": (N_i, 13, 2), # 13 keypoints for the pose (x, y) coords. - "bboxes": (N_i, 4), # in pascal_voc unrounded unnormalized - "features": (N_i, D), - "pred_track_ids": (N_i,), # Before assignnment, these are all -1. - }, - ... - ] """ - # *: each item in instances is a frame in the window. So it follows + # *: each item in frames is a frame in the window. So it follows # that each frame in the window has * detected instances. # D: embedding dimension. # total_instances: number of instances in the window. @@ -322,24 +281,20 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query # Number of instances in each frame of the window. 
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window. - # print([frame['frame_id'].item() for frame in instances]) - # print([frame['frame_id'].item() for frame in instances]) - # print([frame['pred_track_ids'] for frame in instances]) + _ = model.eval() - instances_per_frame = [frame["num_detected"] for frame in instances] + instances_per_frame = [frame.num_detected for frame in frames] total_instances, window_size = sum(instances_per_frame), len(instances_per_frame) # Number of instances in window; length of window. - reid_features = torch.cat([frame["features"] for frame in instances], dim=0)[ + reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[ None ] # (1, total_instances, D=512) # (L=1, n_query, total_instances) with torch.no_grad(): - if model.transformer.return_embedding: - asso_output, embed = model(instances, query_frame=query_frame) - instances[query_frame]["embeddings"] = embed - else: - asso_output = model(instances, query_frame=query_frame) + asso_output, embed = model(frames, query_frame=query_frame) + # if model.transformer.return_embedding: + # frames[query_frame].embeddings = embed TODO add embedding to Instance Object # if query_frame == 1: # print(asso_output) asso_output = asso_output[-1].split(instances_per_frame, dim=1) # (window_size, n_query, N_i) @@ -347,11 +302,9 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query asso_output = torch.cat(asso_output, dim=1).cpu() # (n_query, total_instances) try: - n_query = instances[query_frame][ - "num_detected" - ] # Number of instances in the current/query frame. + n_query = frames[query_frame].num_detected # Number of instances in the current/query frame. except Exception as e: - print(len(instances), query_frame, instances[-1]) + print(len(frames), query_frame, frames[-1]) raise(e) n_nonquery = ( @@ -360,19 +313,19 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query try: instance_ids = torch.cat( - [x["pred_track_ids"] for batch_idx, x in enumerate(instances) if batch_idx != query_frame], dim=0 + [x.get_pred_track_ids() for batch_idx, x in enumerate(frames) if batch_idx != query_frame], dim=0 ).view( n_nonquery ) # (n_nonquery,) except Exception as e: - print(instances) + print(frames) raise(e) query_inds = [x for x in range(sum(instances_per_frame[:query_frame]), sum(instances_per_frame[: query_frame + 1]))] nonquery_inds = [i for i in range(total_instances) if i not in query_inds] asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery) - pred_boxes, _ = model_utils.get_boxes_times(instances) + pred_boxes, _ = model_utils.get_boxes_times(frames) query_boxes = pred_boxes[query_inds] # n_k x 4 nonquery_boxes = pred_boxes[nonquery_inds] #n_nonquery x 4 # TODO: Insert postprocessing. 
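To make the index bookkeeping above concrete, here is a tiny numeric illustration of the query/non-query split, reusing the example window from the comment (instances_per_frame = [4, 5, 6, 7]); plain Python, no biogtr imports required:

instances_per_frame = [4, 5, 6, 7]  # example window: 4 frames with 4-7 detections each
query_frame = 3                     # the query is the last frame of the window
total_instances = sum(instances_per_frame)  # 22 instances flattened across the window

query_inds = list(
    range(
        sum(instances_per_frame[:query_frame]),
        sum(instances_per_frame[: query_frame + 1]),
    )
)
nonquery_inds = [i for i in range(total_instances) if i not in query_inds]

assert query_inds == [15, 16, 17, 18, 19, 20, 21]               # the 7 query-frame instances
assert len(nonquery_inds) == total_instances - len(query_inds)  # 15 non-query instances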
@@ -392,11 +345,14 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj) - instances[query_frame]["decay_time_traj_score"] = pd.DataFrame( + decay_time_traj_score = pd.DataFrame( deepcopy((traj_score).numpy()), columns=unique_ids.cpu().numpy() ) - instances[query_frame]["decay_time_traj_score"].index.name = "Current Frame Instances" - instances[query_frame]["decay_time_traj_score"].columns.name = "Unique IDs" + + decay_time_traj_score.index.name = "Current Frame Instances" + decay_time_traj_score.columns.name = "Unique IDs" + + frames[query_frame].add_traj_score("decay_time", decay_time_traj_score) ################################################################################ # with iou -> combining with location in tracker, they set to True @@ -408,7 +364,7 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query # n_nonquery, device=id_inds.device)[:, None]).max(dim=0)[1] # n_traj last_inds = ( - id_inds * torch.arange(n_nonquery[0], device=id_inds.device)[:, None] + id_inds * torch.arange(n_nonquery, device=id_inds.device)[:, None] ).max(dim=0)[ 1 ] # M @@ -451,14 +407,18 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, instances, query track_ids[i] = id_count id_count += 1 - instances[query_frame]["matches"] = (match_i, match_j) - instances[query_frame]["pred_track_ids"] = track_ids - instances[query_frame]["final_traj_score"] = pd.DataFrame( + frames[query_frame].matches = (match_i, match_j) + + for instance, track_id in zip(frames[query_frame].instances, track_ids): + instance.pred_track_id = track_id + + final_traj_score = pd.DataFrame( deepcopy((traj_score).numpy()), columns=unique_ids.cpu().numpy() ) - instances[query_frame]["final_traj_score"].index.name = "Current Frame Instances" - instances[query_frame]["final_traj_score"].columns.name = "Unique IDs" + final_traj_score.index.name = "Current Frame Instances" + final_traj_score.columns.name = "Unique IDs" - self.track_queue.append(instances[query_frame]) + frames[query_frame].add_traj_score("final", final_traj_score) + self.track_queue.append(frames[query_frame]) - return instances[query_frame], id_count + return frames[query_frame], id_count diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index 373d368b..b4bb1b2e 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -1,11 +1,11 @@ """Module containing GTR model used for training.""" from biogtr.models.transformer import Transformer from biogtr.models.visual_encoder import VisualEncoder +from biogtr.data_structures import Frame from torch import nn # todo: do we want to handle params with configs already here? - class GlobalTrackingTransformer(nn.Module): """Modular GTR model composed of visual encoder + transformer used for tracking.""" @@ -99,31 +99,28 @@ def __init__( def forward( self, - instances: list[dict], - all_instances: list[dict] = None, - query_frame: int = None, + frames: list[Frame], + query_frame: int = None ): """Forward pass of GTR Model to get asso matrix. Args: - instances: List of dicts from chunk containing crops of objects + gt label info - all_instances: List of dicts containing crops of objects + gt label info. Used for stitching together full trajectory + frames: List of Frames from chunk containing crops of objects + gt label info query_frame: Frame index used as query for self attention. 
Only used in sliding inference where query frame is the last frame in the window. Returns: An N_T x N association matrix """ # Extract feature representations with pre-trained encoder. - for frame in instances: - if (frame["num_detected"] > 0).item(): - if "features" not in frame.keys() or frame['features'] == None or len(frame["features"]) == 0: - z = self.visual_encoder(frame["crops"]) - frame["features"] = z + for frame in frames: + if frame.has_instances(): + if not frame.has_features(): + crops = frame.get_crops() + z = self.visual_encoder(crops) - # Extract association matrix with transformer. - if self.transformer.return_embedding: - asso_preds, emb = self.transformer(instances, query_frame=query_frame) - else: - asso_preds = self.transformer(instances, query_frame=query_frame) + for i, z_i in enumerate(z): + frame.instances[i].features = z_i - return (asso_preds, emb) if self.transformer.return_embedding else asso_preds + asso_preds, emb = self.transformer(frames, query_frame=query_frame) + + return asso_preds, emb diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index b4457acd..0a71a4da 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -2,10 +2,11 @@ from copy import deepcopy from typing import Dict, List, Tuple, Iterable from pytorch_lightning import loggers +from biogtr.data_structures import Frame import torch -def get_boxes_times(instances: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor]: +def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: """Extracts the bounding boxes and frame indices from the input list of instances. Args: @@ -17,10 +18,10 @@ def get_boxes_times(instances: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor]: indices, respectively. """ boxes, times = [], [] - _, h, w = instances[0]["img_shape"].flatten() + _, h, w = frames[0].img_shape.flatten() - for fidx, instance in enumerate(instances): - bbox = deepcopy(instance["bboxes"]) + for fidx, frame in enumerate(frames): + bbox = deepcopy(frame.get_bboxes()) bbox[:, [0, 2]] /= w bbox[:, [1, 3]] /= h diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 91274ddb..32b8493a 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -11,7 +11,7 @@ * added fixed embeddings over boxes """ - +from biogtr.data_structures import Frame from biogtr.models.attention_head import ATTWeightHead from biogtr.models.embedding import Embedding from biogtr.models.model_utils import get_boxes_times @@ -163,11 +163,11 @@ def _reset_parameters(self): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, instances, query_frame=None): + def forward(self, frames: list[Frame], query_frame: int=None): """A forward pass through the transformer and attention head. Args: - instances: A list of dictionaries, one dictionary for each frame + frames: A list of Frames (See `biogtr.data_structures.Frame for more info.) query_frame: An integer (k) specifying the frame within the window to be queried. Returns: @@ -175,34 +175,22 @@ def forward(self, instances, query_frame=None): L: number of decoder blocks n_query: number of instances in current query/frame total_instances: number of instances in window - - # ------------------------- An example of instances ------------------------ # - instances = [ - { - # Each dictionary is a frame. 
- "frame_id": frame index int, - "num_detected": N_i, # num of detected instances in i-th frame - "bboxes": (N_i, 4), # in pascal_voc unrounded unnormalized - "features": (N_i, embed_dim), # embed_dim = embedding dimension - ... - }, - ... - ] """ reid_features = torch.cat( - [frame["features"] for frame in instances], dim=0 + [frame.get_features() for frame in frames], dim=0 ).unsqueeze(0) - - window_length = len(instances) - instances_per_frame = [frame["num_detected"] for frame in instances] + + window_length = len(frames) + instances_per_frame = [frame.num_detected for frame in frames] total_instances = sum(instances_per_frame) embed_dim = reid_features.shape[-1] - + + #print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}') if self.embedding_meta: kwargs = self.embedding_meta.get("kwargs", {}) - pred_box, pred_time = get_boxes_times(instances) # total_instances x 4 - + pred_box, pred_time = get_boxes_times(frames) # total_instances x 4 + embedding_type = self.embedding_meta["embedding_type"] if "temp" in embedding_type: @@ -269,7 +257,7 @@ def forward(self, instances, query_frame=None): asso_output.append(self.attn_head(x, memory).view(n_query, total_instances)) # (L=1, n_query, total_instances) - return (asso_output, pos_emb) if self.return_embedding else asso_output + return (asso_output, pos_emb) if self.return_embedding else (asso_output, None) class TransformerEncoder(nn.Module): diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 844efd75..7ebb83f1 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -54,8 +54,8 @@ def test_sleap_dataset(two_flies): instances = next(iter(train_ds)) assert len(instances) == clip_length - assert len(instances[0]["gt_track_ids"]) == 2 - assert len(instances[0]["gt_track_ids"]) == instances[0]["num_detected"].item() + assert len(instances[0].get_gt_track_ids()) == 2 + assert len(instances[0].get_gt_track_ids()) == instances[0].num_detected chunk_frac = 0.5 @@ -129,8 +129,8 @@ def test_icy_dataset(ten_icy_particles): instances = next(iter(train_ds)) assert len(instances) == clip_length - assert len(instances[0]["gt_track_ids"]) == 10 - assert len(instances[0]["gt_track_ids"]) == instances[0]["num_detected"].item() + assert len(instances[0].get_gt_track_ids()) == 10 + assert len(instances[0].get_gt_track_ids()) == instances[0].num_detected def test_trackmate_dataset(trackmate_lysosomes): @@ -153,8 +153,8 @@ def test_trackmate_dataset(trackmate_lysosomes): instances = next(iter(train_ds)) assert len(instances) == clip_length - assert len(instances[0]["gt_track_ids"]) == 26 - assert len(instances[0]["gt_track_ids"]) == instances[0]["num_detected"].item() + assert len(instances[0].get_gt_track_ids()) == 26 + assert len(instances[0].get_gt_track_ids()) == instances[0].num_detected def test_isbi_dataset(isbi_microtubules, isbi_receptors): @@ -182,8 +182,8 @@ def test_isbi_dataset(isbi_microtubules, isbi_receptors): instances = next(iter(train_ds)) assert len(instances) == clip_length - assert len(instances[0]["gt_track_ids"]) == num_objects - assert len(instances[0]["gt_track_ids"]) == instances[0]["num_detected"].item() + assert len(instances[0].get_gt_track_ids()) == num_objects + assert len(instances[0].get_gt_track_ids()) == instances[0].num_detected def test_cell_tracking_dataset(cell_tracking): @@ -206,11 +206,11 @@ def test_cell_tracking_dataset(cell_tracking): instances = next(iter(train_ds)) - gt_track_ids_1 = instances[0]["gt_track_ids"] + gt_track_ids_1 = 
instances[0].get_gt_track_ids() assert len(instances) == clip_length assert len(gt_track_ids_1) == 30 - assert len(gt_track_ids_1) == instances[0]["num_detected"].item() + assert len(gt_track_ids_1) == instances[0].num_detected # fall back to using np.unique when gt_list not available train_ds = CellTrackingDataset( @@ -223,11 +223,11 @@ def test_cell_tracking_dataset(cell_tracking): instances = next(iter(train_ds)) - gt_track_ids_2 = instances[0]["gt_track_ids"] + gt_track_ids_2 = instances[0].get_gt_track_ids() assert len(instances) == clip_length assert len(gt_track_ids_2) == 30 - assert len(gt_track_ids_2) == instances[0]["num_detected"].item() + assert len(gt_track_ids_2) == instances[0].num_detected assert gt_track_ids_1.all() == gt_track_ids_2.all() @@ -386,8 +386,8 @@ def test_augmentations(two_flies, ten_icy_particles): augs_instances = next(iter(augs_ds)) - a = no_augs_instances[0]["crops"] - b = augs_instances[0]["crops"] + a = no_augs_instances[0].get_crops() + b = augs_instances[0].get_crops() assert not torch.all(a.eq(b)) @@ -433,7 +433,7 @@ def test_augmentations(two_flies, ten_icy_particles): augs_instances = next(iter(augs_ds)) - a = no_augs_instances[0]["crops"] - b = augs_instances[0]["crops"] + a = no_augs_instances[0].get_crops() + b = augs_instances[0].get_crops() assert not torch.all(a.eq(b)) \ No newline at end of file diff --git a/tests/test_inference.py b/tests/test_inference.py index 93f57438..288f1792 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -2,6 +2,7 @@ import torch import pytest import numpy as np +from biogtr.data_structures import Frame, Instance from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from biogtr.inference.tracker import Tracker from biogtr.inference import post_processing @@ -18,20 +19,19 @@ def test_tracker(): num_detected = 2 img_shape = (1, 128, 128) test_frame = 1 - instances = [] + frames = [] for i in range(num_frames): - instances.append( - { - "frame_id": torch.tensor(i), - "img_shape": torch.tensor(img_shape), - "num_detected": torch.tensor([num_detected]), - "crops": torch.rand(size=(num_detected, 1, 64, 64)), - "bboxes": torch.rand(size=(num_detected, 4)), - "gt_track_ids": torch.arange(num_detected), - "pred_track_ids": torch.tensor([-1] * num_detected), - } - ) + instances = [] + for j in range(num_detected): + instances.append(Instance(gt_track_id=j, + pred_track_id=-1, + bbox=torch.rand(size=(1, 4)), + crop = torch.rand(size=(1, 1, 64, 64)))) + frames.append(Frame(video_id=0, + frame_id=i, + img_shape=img_shape, + instances=instances)) embedding_meta = { "embedding_type": "fixed_pos", @@ -59,15 +59,16 @@ def test_tracker(): tracker = Tracker(**tracking_cfg) - instances_pred = tracker(tracking_transformer, instances) + frames_pred = tracker(tracking_transformer, frames) + print(frames_pred[test_frame]) asso_equals = ( - instances_pred[test_frame]["decay_time_traj_score"].to_numpy() - == instances_pred[test_frame]["final_traj_score"].to_numpy() + frames_pred[test_frame].get_traj_score("decay_time").to_numpy() + == frames_pred[test_frame].get_traj_score("final").to_numpy() ).all() assert asso_equals - assert len(instances_pred[test_frame]["pred_track_ids"] == num_detected) + assert (len(frames_pred[test_frame].get_pred_track_ids()) == num_detected) #@pytest.mark.parametrize("set_default_device", ["cpu"], indirect=True) @@ -152,45 +153,33 @@ def test_metrics(): num_frames = 3 num_detected = 3 n_batches = 1 - instances_pred = [] + batches = [] for i in range(n_batches): + 
frames_pred = [] for j in range(num_frames): - bboxes = torch.tensor(np.random.uniform(size=(num_detected, 4))) - bboxes[:, -2:] += 1 - instances_pred.append( - - { - "video_id": torch.tensor(0), - "frame_id": torch.tensor(j), - "num_detected": torch.tensor([num_detected]), - "bboxes": bboxes, - "gt_track_ids": torch.arange(num_detected), - "pred_track_ids": torch.arange(num_detected), - } - ) - instances_mm = metrics.to_track_eval(instances_pred) - clear_mot = metrics.get_pymotmetrics(instances_mm) - - matches, indices, _ = metrics.get_matches(instances_pred) - - switches = metrics.get_switches(matches, indices) - - sw_cnt = metrics.get_switch_count(switches) - - assert sw_cnt == clear_mot["num_switches"] == 0, (sw_cnt, clear_mot["num_switches"]) - - instances_pred[1]['pred_track_ids'] = torch.tensor([1,2,0]) - instances_pred[2]['pred_track_ids'] = torch.tensor([2,0,1]) - - instances_mm = metrics.to_track_eval(instances_pred) - clear_mot = metrics.get_pymotmetrics(instances_mm) - - matches, indices, _ = metrics.get_matches(instances_pred) - - switches = metrics.get_switches(matches, indices) - - sw_cnt = metrics.get_switch_count(switches) - - assert sw_cnt == clear_mot["num_switches"] == 6, (instances_pred[1]['gt_track_ids'],instances_pred[1]['pred_track_ids'], sw_cnt, clear_mot["num_switches"]) - + instances_pred = [] + for k in range(num_detected): + bboxes = torch.tensor(np.random.uniform(size=(num_detected, 4))) + bboxes[:, -2:] += 1 + instances_pred.append(Instance(gt_track_id=k, + pred_track_id=k, + bbox=torch.randn((1,4)) + )) + frames_pred.append(Frame(video_id=0, + frame_id=j, + instances=instances_pred)) + batches.append(frames_pred) + + for batch in batches: + instances_mm = metrics.to_track_eval(batch) + clear_mot = metrics.get_pymotmetrics(instances_mm) + + matches, indices, _ = metrics.get_matches(batch) + + switches = metrics.get_switches(matches, indices) + + sw_cnt = metrics.get_switch_count(switches) + + assert sw_cnt == clear_mot["num_switches"] == 0, (sw_cnt, clear_mot["num_switches"]) + \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index f85fdfb0..425f00df 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,6 +2,7 @@ import pytest import torch import numpy as np +from biogtr.data_structures import Frame, Instance from biogtr.models.attention_head import MLP, ATTWeightHead from biogtr.models.embedding import Embedding from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer @@ -207,20 +208,18 @@ def test_transformer_basic(): feature_dim_attn_head=feats, ) - instances = [] + frames = [] for i in range(num_frames): - instances.append( - { - "frame_id": torch.tensor(i), - "img_shape": torch.tensor(img_shape), - "num_detected": torch.tensor([num_detected]), - "bboxes": torch.rand(size=(num_detected, 4)), - "features": torch.rand(size=(num_detected, feats)), - } - ) + instances = [] + for j in range(num_detected): + instances.append(Instance(bbox=torch.rand(size=(1, 4)), + features=torch.rand(size=(1, feats)))) + frames.append(Frame(video_id = 0, frame_id=i, + instances=instances)) + - asso_preds = transformer(instances) + asso_preds,_ = transformer(frames) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 @@ -270,18 +269,15 @@ def test_transformer_embedding(): num_detected = 10 img_shape = (1, 50, 50) - instances = [] + frames = [] for i in range(num_frames): - instances.append( - { - "frame_id": torch.tensor(i), - "img_shape": torch.tensor(img_shape), - "num_detected": 
torch.tensor([num_detected]), - "bboxes": torch.rand(size=(num_detected, 4)), - "features": torch.rand(size=(num_detected, feats)), - } - ) + instances = [] + for j in range(num_detected): + instances.append(Instance(bbox=torch.rand(size=(1, 4)), + features=torch.rand(size=(1, feats)))) + frames.append(Frame(video_id = 0, frame_id=i, + instances=instances)) embedding_meta = { "embedding_type": "learned_pos_temp", @@ -302,7 +298,7 @@ def test_transformer_embedding(): return_embedding=True, ) - asso_preds, embedding = transformer(instances) + asso_preds, embedding = transformer(frames) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 assert embedding.size() == (num_detected * num_frames, 1, feats) @@ -315,18 +311,18 @@ def test_tracking_transformer(): num_detected = 20 img_shape = (1, 128, 128) - instances = [] + frames = [] for i in range(num_frames): - instances.append( - { - "frame_id": torch.tensor(i), - "img_shape": torch.tensor(img_shape), - "num_detected": torch.tensor([num_detected]), - "crops": torch.rand(size=(num_detected, 1, 64, 64)), - "bboxes": torch.rand(size=(num_detected, 4)), - } - ) + instances = [] + for j in range(num_detected): + instances.append(Instance(bbox=torch.rand(size=(1, 4)), + crop=torch.rand(size=(1, 1, 64, 64)) + )) + frames.append(Frame(video_id=0, + frame_id=i, + img_shape=img_shape, + instances=instances)) embedding_meta = { "embedding_type": "fixed_pos", @@ -347,7 +343,7 @@ def test_tracking_transformer(): return_embedding=True, ) - asso_preds, embedding = tracking_transformer(instances) + asso_preds, embedding = tracking_transformer(frames) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 assert embedding.size() == (num_detected * num_frames, 1, feats) From a63dbf5782f72ed52b623e859f470b8ba43c7492 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 17:39:06 -0800 Subject: [PATCH 05/40] implement track_local_queues --- biogtr/inference/track_queue.py | 296 ++++++++++++++++++++++++++++++++ biogtr/inference/tracker.py | 271 +++++++++++++---------------- 2 files changed, 417 insertions(+), 150 deletions(-) create mode 100644 biogtr/inference/track_queue.py diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py new file mode 100644 index 00000000..d8eb9a9c --- /dev/null +++ b/biogtr/inference/track_queue.py @@ -0,0 +1,296 @@ +"""Module handling sliding window tracking.""" + +import warnings +from biogtr.data_structures import Frame +from collections import deque + + +class TrackQueue: + """Class handling track local queue system for sliding window. + Each trajectory has its own deque based queue of size `window_size - 1`. + Elements of the queue are Instance objects that have already been tracked + and will be compared against later frames for assignment. + """ + + def __init__(self, window_size: int, max_gap: int = 1, verbose: bool = False): + """Initialize track queue. + + Args: + window_size: The number of instances per trajectory allowed in the + queue to be compared against. + max_gap: The number of consecutive frames a trajectory can fail to + appear in before terminating the track. + verbose: Whether to print info during operations. + """ + self._window_size = window_size + self._queues = {} + self._max_gap = max_gap + self._curr_gap = {} + if self._max_gap >= 0 and self._max_gap <= self._window_size: + self._max_gap = self._window_size + self._curr_track = -1 + self._verbose = verbose + + def __len__(self): + """Length of the queue. 
+ + Returns: + The total number of instances in every sub-queue. + """ + return sum([len(queue) for queue in self._queues.values()]) + + def __repr__(self): + """The string representation of the TrackQueue. + + Returns: + The string representation of the current state of the queue. + """ + return ( + "TrackQueue(" + f"window_size={self.window_size}, " + f"max_gap={self.max_gap}, " + f"n_tracks={self.n_tracks}, " + f"curr_track={self.curr_track}, " + f"queues={[(key,len(queue)) for key, queue in self._queues.items()]}, " + f"curr_gap:{self._curr_gap}" + ")" + ) + + @property + def window_size(self) -> int: + """The maximum number of instances allowed in a sub-queue to be + compared against. + + Returns: + An int representing The maximum number of instances allowed in a + sub-queue to be compared against. + """ + return self._window_size + + @window_size.setter + def window_size(self, window_size: int) -> None: + """Function for setting the window size of the queue. + + Args: + window_size: An int representing The maximum number of instances + allowed in a sub-queue to be compared against. + """ + self._window_size = window_size + + @property + def max_gap(self) -> int: + """The maximum number of consecutive frames an trajectory can fail to + appear before termination. + + Returns: + An int representing the maximum number of consecutive frames an trajectory can fail to + appear before termination. + """ + return self._max_gap + + @max_gap.setter + def max_gap(self, max_gap: int) -> None: + """Function for setting the max consecutive frame gap allowed for a trajectory. + + Args: + max_gap: An int representing the maximum number of consecutive frames an trajectory can fail to + appear before termination. + """ + self._max_gap = max_gap + + @property + def curr_track(self) -> int: + """The newest *created* trajectory in the queue. + + Returns: + The latest *created* trajectory in the queue. + """ + return self._curr_track + + @curr_track.setter + def curr_track(self, curr_track: int) -> None: + """Function for setting the newest *created* trajectory in the queue. + + Args: + curr_track: The latest *created* trajectory in the queue. + """ + self._curr_track = curr_track + + @property + def n_tracks(self) -> int: + """The current number of trajectories in the queue. + + Returns: + An int representing the current number of trajectories in the queue. + """ + return len(self._queues) + + @property + def tracks(self) -> list: + """A list of the track ids currently in the queue. + + Returns: + A list containing the track ids currently in the queue. + """ + return list(self._queues.keys()) + + @property + def verbose(self) -> bool: + """Whether or not to print outputs along operations. + Mostly used for debugging. + + Returns: + A boolean representing whether or not printing is turned on. + """ + return self._verbose + + @verbose.setter + def verbose(self, verbose: bool) -> None: + """Function for turning on/off printing. + + Args: + verbose: A boolean representing whether printing should be on or off. + """ + self._verbose = verbose + + def end_tracks(self, track_id=None): + """Function for terminating tracks and removing them from the queue. + + Args: + track_id: The index of the trajectory to be ended and removed. + If `None` then then every trajectory is removed and the track queue is reset. + Returns: + True if the track is successively removed, otherwise False. + (ie if the track doesn't exist in the queue.) 
+ """ + if track_id is None: + self._queues = {} + self._curr_gap = {} + self.curr_track = -1 + else: + try: + self._queues.pop(track_id) + self._curr_gap.pop(track_id) + except Exception as e: + print(f"Unable to end track due to {e}") + return False + return True + + def add_frame(self, frame: Frame) -> None: + """Function for adding frames to the queue. + + Each instance from the frame is added to the queue according to its pred_track_id. + If the corresponding trajectory is not already in the queue then create a new queue for the track. + + Args: + frame: A Frame object containing instances that have already been tracked. + """ + if frame.num_detected == 0: # only add frames with instances. + return + vid_id = frame.video_id.item() + frame_id = frame.frame_id.item() + img_shape = frame.img_shape + frame_meta = (vid_id, frame_id, img_shape.cpu().tolist()) + + pred_tracks = [] + for instance in frame.instances: + pred_track_id = instance.pred_track_id.item() + pred_tracks.append(pred_track_id) + + if pred_track_id not in self._queues.keys(): + self._queues[pred_track_id] = deque( + [(*frame_meta, instance)], maxlen=self.window_size - 1 + ) # dumb work around to retain `img_shape` + self.curr_track = pred_track_id + + if self.verbose: + warnings.warn( + f"New track = {pred_track_id} on frame {frame_id}! Current number of tracks = {self.n_tracks}" + ) + + else: + self._queues[pred_track_id].append((*frame_meta, instance)) + self.increment_gaps( + pred_tracks + ) # should this be done in the tracker or the queue? + + def collate_tracks( + self, track_ids: list[int] = None, device: str = None + ) -> list[Frame]: + """Merge queues into a single list of Frames containing corresponding instances. + + Args: + track_ids: A list of trajectorys to merge. If None, then merge all + queues, otherwise filter queues by track_ids then merge. + device: A str representation of the device the frames should be on after merging + since all instances in the queue are kept on the cpu. + Returns: + A sorted list of Frame objects from which each instance came from, + containing the corresponding instances. + """ + if len(self._queues) == 0: + return [] + + frames = {} + + tracks_to_convert = ( + {track: queue for track, queue in self._queues if track in track_ids} + if track_ids is not None + else self._queues + ) + for track, instances in tracks_to_convert.items(): + for video_id, frame_id, img_shape, instance in instances: + if (video_id, frame_id) not in frames.keys(): + frame = Frame( + video_id, frame_id, img_shape=img_shape, instances=[instance] + ) + frames[(video_id, frame_id)] = frame + else: + frames[(video_id, frame_id)].instances.append(instance) + return [frames[frame].to(device) for frame in sorted(frames.keys())] + + def increment_gaps(self, pred_track_ids: list[int]) -> dict[int, bool]: + """Function for keeping track of number of consecutive frames each + trajectory has been missing from the queue. + + If a trajectory has exceeded the `max_gap` then terminate the track and remove it from the queue. + + Args: + pred_track_ids: A list of track_ids to be matched against the trajectories in the queue. + If a trajectory is in `pred_track_ids` then its gap counter is reset, + otherwise its incremented by 1. + Returns: + A dictionary containing the trajectory id and a boolean value representing + whether or not it has exceeded the max allowed gap and been + terminated. 
+ """ + exceeded_gap = {} + + for track in pred_track_ids: + if track not in self._curr_gap: + self._curr_gap[track] = 0 + + for track in self._curr_gap: + if track not in pred_track_ids: + self._curr_gap[track] += 1 + if self.verbose: + warnings.warn( + f"Track {track} has not been seen for {self._curr_gap[track]} frames." + ) + else: + self._curr_gap[track] = 0 + if self._curr_gap[track] >= self.max_gap: + exceeded_gap[track] = True + else: + exceeded_gap[track] = False + + for track, gap_exceeded in exceeded_gap.items(): + if gap_exceeded: + if self.verbose: + warnings.warn( + f"Track {track} has not been seen for {self._curr_gap[track]} frames! Terminating Track...Current number of tracks = {self.n_tracks}." + ) + self._queues.pop(track) + self._curr_gap.pop(track) + + return exceeded_gap diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 68f16715..e7a2668a 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -3,13 +3,13 @@ import pandas as pd import warnings from biogtr.data_structures import Frame +from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from biogtr.models import model_utils +from biogtr.inference.track_queue import TrackQueue from biogtr.inference import post_processing from biogtr.inference.boxes import Boxes -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from scipy.optimize import linear_sum_assignment from copy import deepcopy -from collections import deque class Tracker: @@ -26,7 +26,7 @@ def __init__( max_center_dist: float = None, persistent_tracking: bool = False, max_gap: int = -1, - verbose = False + verbose=False, ): """Initialize a tracker to run inference. @@ -41,9 +41,9 @@ def __init__( max_center_dist: distance threshold for filtering trajectory score matrix persistent_tracking: whether to keep a buffer across chunks or not """ - - self.window_size = window_size - self.track_queue = deque(maxlen=self.window_size) + self.track_queue = TrackQueue( + window_size=window_size, max_gap=max_gap, verbose=verbose + ) self.use_vis_feats = use_vis_feats self.overlap_thresh = overlap_thresh self.mult_thresh = mult_thresh @@ -52,13 +52,6 @@ def __init__( self.max_center_dist = max_center_dist self.persistent_tracking = persistent_tracking self.verbose = verbose - - self.max_gap = max_gap - self.curr_gap = 0 - if self.max_gap >=0 and self.max_gap <= self.window_size: - self.max_gap = self.window_size - - self.id_count = 0 def __call__(self, model: GlobalTrackingTransformer, frames: list[Frame]): """Wrapper around `track` to enable `tracker()` instead of `tracker.track()`. @@ -68,7 +61,7 @@ def __call__(self, model: GlobalTrackingTransformer, frames: list[Frame]): frames: list of Frames to run inference on Returns: - List of frames containing association matrix scores and instances populated with pred track ids. + List of frames containing association matrix scores and instances populated with pred track ids. """ return self.track(model, frames) @@ -82,7 +75,7 @@ def track(self, model: GlobalTrackingTransformer, frames: list[dict]): Returns: List of Frames populated with pred track ids and association matrix scores """ -# Extract feature representations with pre-trained encoder. + # Extract feature representations with pre-trained encoder. 
_ = model.eval() @@ -90,9 +83,7 @@ def track(self, model: GlobalTrackingTransformer, frames: list[dict]): if frame.has_instances(): if not self.use_vis_feats: for instance in frame.instances: - instance.features = torch.zeros( - 1, model.d_model - ) + instance.features = torch.zeros(1, model.d_model) # frame["features"] = torch.randn( # num_frame_instances, self.model.d_model # ) @@ -119,15 +110,17 @@ def track(self, model: GlobalTrackingTransformer, frames: list[dict]): instances_pred = self.sliding_inference( model, frames, window_size=self.window_size ) - + if not self.persistent_tracking: - if self.verbose: warnings.warn(f'Clearing Queue after tracking') - self.track_queue.clear() - self.id_count = 0 - + if self.verbose: + warnings.warn(f"Clearing Queue after tracking") + self.track_queue.end_tracks() + return instances_pred - def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame], window_size: int): + def sliding_inference( + self, model: GlobalTrackingTransformer, frames: list[Frame], window_size: int + ): """Performs sliding inference on the input video (instances) with a given window size. Args: @@ -144,110 +137,60 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame # H: height. # W: width. - video_len = len(frames) - id_count = self.id_count - - for batch_idx in range(video_len): - - if self.verbose: - warnings.warn(f"Current number of tracks is {id_count}") - - if (self.persistent_tracking and frames[batch_idx].frame_id == 0): #check for new video and clear queue - self.track_queue.clear() - self.id_count = 0 - - ''' + for batch_idx, frame_to_track in frames: + tracked_frames = self.track_queue.collate_tracks() + if self.verbose: + warnings.warn( + f"Current number of tracks is {self.track_queue.n_tracks}" + ) + + if ( + self.persistent_tracking and frame_to_track.frame_id == 0 + ): # check for new video and clear queue + if self.verbose: + warnings.warn("New Video! Resetting Track Queue.") + self.track_queue.end_tracks() + + """ Initialize tracks on first frame of video or first instance of detections. - ''' - if len(self.track_queue) == 0 or sum([len(frame.get_pred_track_ids()) for frame in self.track_queue]) == 0: - - if self.verbose: warnings.warn(f'Initializing track on batch {batch_idx} frame {frames[batch_idx].frame_id.item()}') - - for i, instance in enumerate(frames[batch_idx].instances): - instance.pred_track_id = i - - id_count = frames[batch_idx].num_detected - - if self.verbose: warnings.warn(f'Initial tracks are {frames[batch_idx].get_pred_track_ids().cpu().tolist()}') - - if frames[batch_idx].has_instances(): - - self.track_queue.append(frames[batch_idx]) - self.curr_gap = 0 - else: - self.curr_gap += 1 - if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {frames[batch_idx].frame_id.item()}. Skipping frame in queue. Current gap size: {self.curr_gap}") + """ + if len(self.track_queue) == 0: + if frame_to_track.has_instances(): + if self.verbose: + warnings.warn( + f"Initializing track on clip ind {batch_idx} frame {frame_to_track.frame_id.item()}" + ) - else: - - if not frames[batch_idx].has_instances(): #Check if there are detections. If there are skip and increment gap count - - for i,instance in enumerate(frames[batch_idx].instances): + for i, instance in enumerate(frames[batch_idx].instances): instance.pred_track_id = i - self.curr_gap += 1 - - if self.verbose: warnings.warn(f"No detections in frame {batch_idx}, {frames[batch_idx].frame_id.item()}. Skipping frame in queue. 
Current gap size: {self.curr_gap}") - - - else: #detections found. Track and reset gap counter - self.curr_gap = 0 - - instances_to_track = (list(self.track_queue) + [frames[batch_idx]])[-window_size:] - - if len(self.track_queue) == self.track_queue.maxlen: - tracked_frame = self.track_queue.pop() - tracked_frame["tracked"] = True - - self.track_queue.append(frames[batch_idx]) - - query_ind = min(window_size - 1, len(instances_to_track) - 1) - - frames[batch_idx], id_count = self._run_global_tracker( + else: + if ( + frame_to_track.has_instances() + ): # Check if there are detections. If there are skip and increment gap count + frames_to_track = tracked_frames + [ + frame_to_track + ] # better var name? + + query_ind = len(frames_to_track) - 1 + + frame_to_track = self._run_global_tracker( model, - instances_to_track, - query_frame=query_ind, - id_count=id_count, - overlap_thresh=self.overlap_thresh, - mult_thresh=self.mult_thresh, + frames_to_track, + query_ind=query_ind, ) - - if self.curr_gap == self.max_gap: #Check if we've reached the max gap size and reset tracks. - - if self.verbose: warnings.warn(f"Number of consecutive frames with missing detections has exceeded threshold of {self.max_gap}!") - - self.track_queue.clear() - self.curr_gap = 0 - """ - # If first frame. - if frame_id == 0: - instances[0]["pred_track_ids"] = torch.arange( - 0, len(instances[0]["bboxes"])) - id_count = len(instances[0]["bboxes"]) + if frame_to_track.has_instances(): + self.track_queue.add_frame(frame_to_track) else: - win_st = max(0, frame_id + 1 - window_size) - win_ed = frame_id + 1 - instances[win_st: win_ed], id_count = self._run_global_tracker( - instances[win_st: win_ed], - query_frame=min(window_size - 1, frame_id), - id_count=id_count, - overlap_thresh=self.overlap_thresh, - mult_thresh=self.mult_thresh) - """ + self.track_queue.increment_gaps([]) - # If features are out of window, set to none. - # if frame_id - window_size >= 0: - # instances[frame_id - window_size]["features"] = None - - # TODO: Insert postprocessing. - # for frame in instances: - # if "tracked" in frame.keys(): - # frame['features'] = frame['features'].cpu() - self.id_count = id_count + frames[batch_idx] = frame_to_track return frames - def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_frame, id_count, overlap_thresh, mult_thresh): + def _run_global_tracker( + self, model: GlobalTrackingTransformer, frames: list[Frame], query_ind: int + ) -> Frame: """Run_global_tracker performs the actual tracking. Uses Hungarian algorithm to do track assigning. @@ -255,17 +198,10 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_fr Args: model: the pretrained GlobalTrackingTransformer to be used for inference frames: A list of Frames containing reid features. See `biogtr.data_structures` for more info. - query_frame: An integer for the query frame within the window of instances. - id_count: The count of total identities so far. - overlap_thresh: A float number between 0 and 1 specifying how much - overlap is necessary for assigning a new instance to an existing identity. - mult_thresh: A boolean for whether or not multiple thresholds should be used. - This is not functional as of now. + query_ind: An integer for the query frame within the window of instances. Returns: - frames: The exact list of frames as before but with assigned track ids - and new track ids for the query frame. Refer to the example for the structure. - id_count: An integer for the updated identity count so far. 
+ query_frame: The query frame now populated with the pred_track_ids. """ # *: each item in frames is a frame in the window. So it follows # that each frame in the window has * detected instances. @@ -281,11 +217,19 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_fr # Number of instances in each frame of the window. # E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window. - + _ = model.eval() + query_frame = frames[query_ind] instances_per_frame = [frame.num_detected for frame in frames] - total_instances, window_size = sum(instances_per_frame), len(instances_per_frame) # Number of instances in window; length of window. + total_instances, window_size = sum(instances_per_frame), len( + instances_per_frame + ) # Number of instances in window; length of window. + + overlap_thresh = self.overlap_thresh + mult_thresh = self.mult_thresh + n_traj = self.track_queue.n_tracks + reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[ None ] # (1, total_instances, D=512) @@ -294,45 +238,73 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_fr with torch.no_grad(): asso_output, embed = model(frames, query_frame=query_frame) # if model.transformer.return_embedding: - # frames[query_frame].embeddings = embed TODO add embedding to Instance Object + # query_frame.embeddings = embed TODO add embedding to Instance Object # if query_frame == 1: # print(asso_output) - asso_output = asso_output[-1].split(instances_per_frame, dim=1) # (window_size, n_query, N_i) - asso_output = model_utils.softmax_asso(asso_output) # (window_size, n_query, N_i) + asso_output = asso_output[-1].split( + instances_per_frame, dim=1 + ) # (window_size, n_query, N_i) + asso_output = model_utils.softmax_asso( + asso_output + ) # (window_size, n_query, N_i) asso_output = torch.cat(asso_output, dim=1).cpu() # (n_query, total_instances) try: - n_query = frames[query_frame].num_detected # Number of instances in the current/query frame. + n_query = ( + query_frame.num_detected + ) # Number of instances in the current/query frame. except Exception as e: print(len(frames), query_frame, frames[-1]) - raise(e) + raise (e) n_nonquery = ( total_instances - n_query ) # Number of instances in the window not including the current/query frame. - + try: instance_ids = torch.cat( - [x.get_pred_track_ids() for batch_idx, x in enumerate(frames) if batch_idx != query_frame], dim=0 + [ + x.get_pred_track_ids() + for batch_idx, x in enumerate(frames) + if batch_idx != query_frame + ], + dim=0, ).view( n_nonquery ) # (n_nonquery,) except Exception as e: - print(frames) - raise(e) + print( + [ + [instance.pred_track_id.device for instance in frame.instances] + for frame in frames + ] + ) + raise (e) - query_inds = [x for x in range(sum(instances_per_frame[:query_frame]), sum(instances_per_frame[: query_frame + 1]))] + query_inds = [ + x + for x in range( + sum(instances_per_frame[:query_frame]), + sum(instances_per_frame[: query_frame + 1]), + ) + ] nonquery_inds = [i for i in range(total_instances) if i not in query_inds] asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery) pred_boxes, _ = model_utils.get_boxes_times(frames) query_boxes = pred_boxes[query_inds] # n_k x 4 - nonquery_boxes = pred_boxes[nonquery_inds] #n_nonquery x 4 + nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 # TODO: Insert postprocessing. 
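The docstring above notes that assignment uses the Hungarian algorithm, and the tracker imports scipy.optimize.linear_sum_assignment at the top of this file. A toy-sized sketch of how a (n_query x n_traj) trajectory-score matrix is turned into matches (the scores here are made up, not model outputs):

import numpy as np
from scipy.optimize import linear_sum_assignment

# rows: instances in the query frame, columns: existing trajectories
traj_score = np.array(
    [
        [0.9, 0.1, 0.0],
        [0.2, 0.7, 0.1],
        [0.1, 0.1, 0.8],
    ]
)

# linear_sum_assignment minimizes cost, so negate the scores to maximize them
match_i, match_j = linear_sum_assignment(-traj_score)
assignments = dict(zip(match_i.tolist(), match_j.tolist()))
print(assignments)  # {0: 0, 1: 1, 2: 2}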
- unique_ids = torch.unique(instance_ids) # (n_nonquery,) - n_traj = len(unique_ids) # Number of existing tracks. - id_inds = (unique_ids[None, :] == instance_ids[:, None]).float() # (n_nonquery, n_traj) + unique_ids = torch.tensor( + [self.track_queue.tracks], device=instance_ids.device + ).view( + n_traj + ) # (n_nonquery,) + + id_inds = ( + unique_ids[None, :] == instance_ids[:, None] + ).float() # (n_nonquery, n_traj) ################################################################################ @@ -352,7 +324,7 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_fr decay_time_traj_score.index.name = "Current Frame Instances" decay_time_traj_score.columns.name = "Unique IDs" - frames[query_frame].add_traj_score("decay_time", decay_time_traj_score) + query_frame.add_traj_score("decay_time", decay_time_traj_score) ################################################################################ # with iou -> combining with location in tracker, they set to True @@ -404,12 +376,12 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_fr for i in range(n_query): if track_ids[i] < 0: - track_ids[i] = id_count - id_count += 1 + track_ids[i] = n_traj + n_traj += 1 - frames[query_frame].matches = (match_i, match_j) + query_frame.matches = (match_i, match_j) - for instance, track_id in zip(frames[query_frame].instances, track_ids): + for instance, track_id in zip(query_frame.instances, track_ids): instance.pred_track_id = track_id final_traj_score = pd.DataFrame( @@ -418,7 +390,6 @@ def _run_global_tracker(self, model: GlobalTrackingTransformer, frames, query_fr final_traj_score.index.name = "Current Frame Instances" final_traj_score.columns.name = "Unique IDs" - frames[query_frame].add_traj_score("final", final_traj_score) - self.track_queue.append(frames[query_frame]) + query_frame.add_traj_score("final", final_traj_score) - return frames[query_frame], id_count + return query_frame From 02051d365e39cf0edfe452f447cf23db0c904427 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 19:43:07 -0800 Subject: [PATCH 06/40] implement track_local_queues for tracking inference --- biogtr/inference/track_queue.py | 29 +++++++++++++---------------- biogtr/inference/tracker.py | 6 +++--- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index d8eb9a9c..78576479 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -7,6 +7,7 @@ class TrackQueue: """Class handling track local queue system for sliding window. + Each trajectory has its own deque based queue of size `window_size - 1`. Elements of the queue are Instance objects that have already been tracked and will be compared against later frames for assignment. @@ -32,7 +33,7 @@ def __init__(self, window_size: int, max_gap: int = 1, verbose: bool = False): self._verbose = verbose def __len__(self): - """Length of the queue. + """Get length of the queue. Returns: The total number of instances in every sub-queue. @@ -40,7 +41,7 @@ def __len__(self): return sum([len(queue) for queue in self._queues.values()]) def __repr__(self): - """The string representation of the TrackQueue. + """Return the string representation of the TrackQueue. Returns: The string representation of the current state of the queue. @@ -58,8 +59,7 @@ def __repr__(self): @property def window_size(self) -> int: - """The maximum number of instances allowed in a sub-queue to be - compared against. 
+ """The maximum number of instances allowed in a sub-queue to be compared against. Returns: An int representing The maximum number of instances allowed in a @@ -69,7 +69,7 @@ def window_size(self) -> int: @window_size.setter def window_size(self, window_size: int) -> None: - """Function for setting the window size of the queue. + """Set the window size of the queue. Args: window_size: An int representing The maximum number of instances @@ -79,8 +79,7 @@ def window_size(self, window_size: int) -> None: @property def max_gap(self) -> int: - """The maximum number of consecutive frames an trajectory can fail to - appear before termination. + """The maximum number of consecutive frames an trajectory can fail to appear before termination. Returns: An int representing the maximum number of consecutive frames an trajectory can fail to @@ -90,7 +89,7 @@ def max_gap(self) -> int: @max_gap.setter def max_gap(self, max_gap: int) -> None: - """Function for setting the max consecutive frame gap allowed for a trajectory. + """Set the max consecutive frame gap allowed for a trajectory. Args: max_gap: An int representing the maximum number of consecutive frames an trajectory can fail to @@ -109,7 +108,7 @@ def curr_track(self) -> int: @curr_track.setter def curr_track(self, curr_track: int) -> None: - """Function for setting the newest *created* trajectory in the queue. + """Set the newest *created* trajectory in the queue. Args: curr_track: The latest *created* trajectory in the queue. @@ -136,8 +135,7 @@ def tracks(self) -> list: @property def verbose(self) -> bool: - """Whether or not to print outputs along operations. - Mostly used for debugging. + """Indicate whether or not to print outputs along operations. Mostly used for debugging. Returns: A boolean representing whether or not printing is turned on. @@ -146,7 +144,7 @@ def verbose(self) -> bool: @verbose.setter def verbose(self, verbose: bool) -> None: - """Function for turning on/off printing. + """Turn on/off printing. Args: verbose: A boolean representing whether printing should be on or off. @@ -154,7 +152,7 @@ def verbose(self, verbose: bool) -> None: self._verbose = verbose def end_tracks(self, track_id=None): - """Function for terminating tracks and removing them from the queue. + """Terminate tracks and removing them from the queue. Args: track_id: The index of the trajectory to be ended and removed. @@ -177,7 +175,7 @@ def end_tracks(self, track_id=None): return True def add_frame(self, frame: Frame) -> None: - """Function for adding frames to the queue. + """Add frames to the queue. Each instance from the frame is added to the queue according to its pred_track_id. If the corresponding trajectory is not already in the queue then create a new queue for the track. @@ -250,8 +248,7 @@ def collate_tracks( return [frames[frame].to(device) for frame in sorted(frames.keys())] def increment_gaps(self, pred_track_ids: list[int]) -> dict[int, bool]: - """Function for keeping track of number of consecutive frames each - trajectory has been missing from the queue. + """Keep track of number of consecutive frames each trajectory has been missing from the queue. If a trajectory has exceeded the `max_gap` then terminate the track and remove it from the queue. 
diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index e7a2668a..2cfdf31b 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -54,7 +54,7 @@ def __init__( self.verbose = verbose def __call__(self, model: GlobalTrackingTransformer, frames: list[Frame]): - """Wrapper around `track` to enable `tracker()` instead of `tracker.track()`. + """Wrap around `track` to enable `tracker()` instead of `tracker.track()`. Args: model: the pretrained GlobalTrackingTransformer to be used for inference @@ -121,7 +121,7 @@ def track(self, model: GlobalTrackingTransformer, frames: list[dict]): def sliding_inference( self, model: GlobalTrackingTransformer, frames: list[Frame], window_size: int ): - """Performs sliding inference on the input video (instances) with a given window size. + """Perform sliding inference on the input video (instances) with a given window size. Args: model: the pretrained GlobalTrackingTransformer to be used for inference @@ -191,7 +191,7 @@ def sliding_inference( def _run_global_tracker( self, model: GlobalTrackingTransformer, frames: list[Frame], query_ind: int ) -> Frame: - """Run_global_tracker performs the actual tracking. + """Run global tracker performs the actual tracking. Uses Hungarian algorithm to do track assigning. From d19bdf40b62a03e651eece4e4a4554a7ced681b5 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 19:43:35 -0800 Subject: [PATCH 07/40] add tests for data_structures fix some bugs with data structures --- biogtr/data_structures.py | 139 +++++++++++------------ tests/test_data_structures.py | 204 ++++++++++++++++++++++++++++++++++ 2 files changed, 274 insertions(+), 69 deletions(-) create mode 100644 tests/test_data_structures.py diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index c104b71e..b397770f 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -3,6 +3,7 @@ from numpy.typing import ArrayLike from typing import Union, List + class Instance: """Class representing a single instance to be tracked.""" @@ -65,7 +66,7 @@ def __init__( self.to(self._device) def __repr__(self) -> str: - """String representation of the Instance.""" + """Return string representation of the Instance.""" return ( "Instance(" f"gt_track_id={self._gt_track_id.item()}, " @@ -78,7 +79,7 @@ def __repr__(self) -> str: ) def to(self, map_location): - """Move instance to different device or change dtype. (See `torch.to` for more info) + """Move instance to different device or change dtype. (See `torch.to` for more info). Args: map_location: Either the device or dtype for the instance to be moved. @@ -104,7 +105,7 @@ def device(self) -> str: @device.setter def device(self, device) -> None: - """Setter for the device property. + """Set for the device property. Args: device: The str representation of the device. @@ -122,7 +123,7 @@ def gt_track_id(self) -> torch.Tensor: @gt_track_id.setter def gt_track_id(self, track: int): - """Function to set the instance ground-truth track id. + """Set the instance ground-truth track id. Args: track: An int representing the ground-truth track id. @@ -133,7 +134,7 @@ def gt_track_id(self, track: int): self._gt_track_id = torch.tensor([]) def has_gt_track_id(self) -> bool: - """Function for determining if instance has a gt track assignment. + """Determine if instance has a gt track assignment. Returns: True if the gt track id is set, otherwise False. 
@@ -154,8 +155,8 @@ def pred_track_id(self) -> torch.Tensor: @pred_track_id.setter def pred_track_id(self, track: int) -> None: - """Function to set predicted track id. - + """Set predicted track id. + Args: track: an int representing the predicted track id. """ @@ -165,22 +166,20 @@ def pred_track_id(self, track: int) -> None: self._pred_track_id = torch.tensor([]) def has_pred_track_id(self) -> bool: - """Function to determine whether instance has predicted track id - + """Determine whether instance has predicted track id. + Returns: True if instance has a pred track id, False otherwise. - Note that `-1` represents no assigned pred_track_id while - `[]` represents assigned track id of empty instance. """ - if self._pred_track_id == -1: + if self._pred_track_id.item() == -1 or self._pred_track_id.shape[0] == 0: return False else: return True @property def bbox(self) -> torch.Tensor: - """The bounding box coordinates of the instance in the original frame - + """The bounding box coordinates of the instance in the original frame. + Returns: A (1,4) tensor containing the bounding box coordinates. """ @@ -188,7 +187,7 @@ def bbox(self) -> torch.Tensor: @bbox.setter def bbox(self, bbox: ArrayLike) -> None: - """Function to set the instance bounding box. + """Set the instance bounding box. Args: bbox: an arraylike object containing the bounding box coordinates. @@ -205,8 +204,8 @@ def bbox(self, bbox: ArrayLike) -> None: self._bbox = self._bbox.unsqueeze(0) def has_bbox(self) -> bool: - """Function for determining if the instance has a bbox. - + """Determine if the instance has a bbox. + Returns: True if the instance has a bounding box, false otherwise. """ @@ -226,7 +225,7 @@ def crop(self) -> torch.Tensor: @crop.setter def crop(self, crop: ArrayLike) -> None: - """Function to set the crop of the instance. + """Set the crop of the instance. Args: an arraylike object containing the cropped image of the centered instance. @@ -245,7 +244,7 @@ def crop(self, crop: ArrayLike) -> None: self._crop = self._crop.unsqueeze(0) def has_crop(self) -> bool: - """Function to determine if the instance has a crop. + """Determine if the instance has a crop. Returns: True if the instance has an image otherwise False. @@ -257,8 +256,8 @@ def has_crop(self) -> bool: @property def features(self) -> torch.Tensor: - """ReID feature vector from backbone model to be used as input to transformer. - + """Re-ID feature vector from backbone model to be used as input to transformer. + Returns: a (1, d) tensor containing the reid feature vector. """ @@ -266,7 +265,7 @@ def features(self) -> torch.Tensor: @features.setter def features(self, features: ArrayLike) -> None: - """Function to set the reid feature vector of the instance. + """Set the reid feature vector of the instance. Args: features: a (1,d) array like object containing the reid features for the instance. @@ -274,7 +273,7 @@ def features(self, features: ArrayLike) -> None: if features is None or len(features) == 0: self._features = torch.tensor([]) - if not isinstance(features, torch.Tensor): + elif not isinstance(features, torch.Tensor): self._features = torch.tensor(features) else: self._features = features @@ -283,8 +282,8 @@ def features(self, features: ArrayLike) -> None: self._features = self._features.unsqueeze(0) def has_features(self) -> bool: - """Function for determining if the instance has computed reid features. - + """Determine if the instance has computed reid features. + Returns: True if the instance has reid features, False otherwise. 
""" @@ -296,6 +295,7 @@ def has_features(self) -> bool: class Frame: """Data structure containing metadata for a single frame of a video.""" + def __init__( self, video_id: int, @@ -308,7 +308,7 @@ def __init__( device=None, ): """Initialize Frame. - + Args: video_id: The video index in the dataset. frame_id: The index of the frame in a video. @@ -336,7 +336,9 @@ def __init__( self._asso_output = asso_output self._matches = matches - if isinstance(traj_score, dict): + if traj_score is None: + self._traj_score = {} + elif isinstance(traj_score, dict): self._traj_score = traj_score else: self._traj_score = {"initial": traj_score} @@ -345,7 +347,7 @@ def __init__( self.to(device) def __repr__(self) -> str: - """String representation of the Frame. + """Return String representation of the Frame. Returns: The string representation of the frame. @@ -365,7 +367,7 @@ def __repr__(self) -> str: ) def to(self, map_location: str): - """Function for moving frame to different device or dtype (See `torch.to` for more info). + """Move frame to different device or dtype (See `torch.to` for more info). Args: map_location: A string representing the device to move to. @@ -377,10 +379,10 @@ def to(self, map_location: str): self._img_shape = self._img_shape.to(map_location) if isinstance(self._asso_output, torch.Tensor): - self._asso_output = asso_output.to(map_location) + self._asso_output = self._asso_output.to(map_location) if isinstance(self._matches, torch.Tensor): - self._matches = matches.to(map_location) + self._matches = self._matches.to(map_location) for key, val in self._traj_score.items(): if isinstance(val, torch.Tensor): @@ -403,7 +405,7 @@ def device(self) -> str: @device.setter def device(self, device: str) -> None: - """Function to set the device. + """Set the device. Note: Do not set `frame.device = device` normally. Use `frame.to(device)` instead. @@ -412,7 +414,6 @@ def device(self, device: str) -> None: """ self._device = device - @property def video_id(self) -> torch.Tensor: """ @@ -425,8 +426,8 @@ def video_id(self) -> torch.Tensor: @video_id.setter def video_id(self, video_id: int) -> None: - """Function for setting the video index. - + """Set the video index. + Note: Generally the video_id should be immutable after initialization. Args: @@ -437,7 +438,7 @@ def video_id(self, video_id: int) -> None: @property def frame_id(self) -> torch.Tensor: """The index of the frame in a full video. - + Returns: A torch tensor containing the index of the frame in the video. """ @@ -445,8 +446,8 @@ def frame_id(self) -> torch.Tensor: @frame_id.setter def frame_id(self, frame_id: int) -> None: - """Function for setting the frame index of the frame. - + """Set the frame index of the frame. + Note: The frame_id should generally be immutable after initialization. Args: @@ -457,7 +458,7 @@ def frame_id(self, frame_id: int) -> None: @property def img_shape(self) -> torch.Tensor: """The shape of the pre-cropped frame. - + Returns: A torch tensor containing the shape of the frame. Should generally be (c, h, w) """ @@ -465,7 +466,7 @@ def img_shape(self) -> torch.Tensor: @img_shape.setter def img_shape(self, img_shape: ArrayLike) -> None: - """Function for setting the shape of the frame image + """Set the shape of the frame image. Note: the img_shape should generally be immutable after initialization. @@ -488,7 +489,7 @@ def instances(self) -> List[Instance]: @instances.setter def instances(self, instances: List[Instance]) -> None: - """Function for setting the frame's instance + """Set the frame's instance. 
Args: instances: A list of Instances that appear in the frame. @@ -496,7 +497,7 @@ def instances(self, instances: List[Instance]) -> None: self._instances = instances def has_instances(self) -> bool: - """Function for determining whether there are instances in the frame. + """Determine whether there are instances in the frame. Returns: True if there are instances in the frame, otherwise False. @@ -517,14 +518,14 @@ def num_detected(self) -> int: @property def asso_output(self) -> ArrayLike: """The association matrix between instances outputed directly by transformer. - - Returns: + + Returns: An arraylike (n_query, n_nonquery) association matrix between instances. """ return self._asso_output def has_asso_output(self) -> bool: - """Function for determining whether the frame has an association matrix computed. + """Determine whether the frame has an association matrix computed. Returns: True if the frame has an association matrix otherwise, False. @@ -535,10 +536,11 @@ def has_asso_output(self) -> bool: @asso_output.setter def asso_output(self, asso_output: ArrayLike) -> None: - """Function for setting the association matrix of a frame. + """Set the association matrix of a frame. Args: - asso_output: An arraylike (n_query, n_nonquery) association matrix between instances.""" + asso_output: An arraylike (n_query, n_nonquery) association matrix between instances. + """ self._asso_output = asso_output @property @@ -552,7 +554,7 @@ def matches(self) -> tuple: @matches.setter def matches(self, matches: tuple) -> None: - """Function for setting the frame matches + """Set the frame matches. Args: matches: A tuple containing the instance idx and trajectory idx for the matched instance. @@ -560,7 +562,7 @@ def matches(self, matches: tuple) -> None: self._matches = matches def has_matches(self) -> bool: - """Function for whether or not matches have been computed for frame. + """Check whether or not matches have been computed for frame. Returns: True if frame contains matches otherwise False. @@ -570,11 +572,10 @@ def has_matches(self) -> bool: return False def get_traj_score(self, key=None) -> Union[dict, ArrayLike, None]: - """Dictionary containing association matrix between instances and - trajectories along postprocessing pipeline. + """Get dictionary containing association matrix between instances and trajectories along postprocessing pipeline. Args: - key: The key of the trajectory score to be accessed. + key: The key of the trajectory score to be accessed. Can be one of {None, 'initial', 'decay_time', 'max_center_dist', 'iou', 'final'} Returns: - dictionary containing all trajectory scores if key is None @@ -587,12 +588,12 @@ def get_traj_score(self, key=None) -> Union[dict, ArrayLike, None]: try: return self._traj_score[key] except KeyError as e: - print("Could not access {key} traj_score due to {e}") + print(f"Could not access {key} traj_score due to {e}") return None def add_traj_score(self, key, traj_score: ArrayLike) -> None: - """Function for adding trajectory score to dictionary - + """Add trajectory score to dictionary. + Args: key: key associated with traj score to be used in dictionary traj_score: association matrix between instances and trajectories @@ -600,8 +601,8 @@ def add_traj_score(self, key, traj_score: ArrayLike) -> None: self._traj_score[key] = traj_score def has_traj_score(self) -> bool: - """Function for checking if any trajectory association matrix has been saved - + """Check if any trajectory association matrix has been saved. 
+ Returns: True there is at least one association matrix otherwise, false. """ @@ -610,8 +611,8 @@ def has_traj_score(self) -> bool: return True def has_gt_track_ids(self) -> bool: - """Function to check if any of frames instances has a gt track id - + """Check if any of frames instances has a gt track id. + Returns: True if at least 1 instance has a gt track id otherwise False. """ @@ -620,7 +621,7 @@ def has_gt_track_ids(self) -> bool: return False def get_gt_track_ids(self) -> torch.Tensor: - """Function to get the gt track ids of all instances in the frame + """Get the gt track ids of all instances in the frame. Returns: an (N,) shaped tensor with the gt track ids of each instance in the frame. @@ -630,7 +631,7 @@ def get_gt_track_ids(self) -> torch.Tensor: return torch.cat([instance.gt_track_id for instance in self.instances]) def has_pred_track_ids(self) -> bool: - """Function to check if any of frames instances has a pred track id + """Check if any of frames instances has a pred track id. Returns: True if at least 1 instance has a pred track id otherwise False. @@ -640,7 +641,7 @@ def has_pred_track_ids(self) -> bool: return False def get_pred_track_ids(self) -> torch.Tensor: - """Function to get the pred track ids of all instances in the frame + """Get the pred track ids of all instances in the frame. Returns: an (N,) shaped tensor with the pred track ids of each instance in the frame. @@ -650,7 +651,7 @@ def get_pred_track_ids(self) -> torch.Tensor: return torch.cat([instance.pred_track_id for instance in self.instances]) def has_bboxes(self) -> bool: - """Function to check if any of frames instances has a bounding box + """Check if any of frames instances has a bounding box. Returns: True if at least 1 instance has a bounding box otherwise False. @@ -660,17 +661,17 @@ def has_bboxes(self) -> bool: return False def get_bboxes(self) -> torch.Tensor: - """Function to get the bounding boxes of all instances in the frame + """Get the bounding boxes of all instances in the frame. Returns: an (N,4) shaped tensor with bounding boxes of each instance in the frame. """ if not self.has_instances(): - return torch.empty(0,4) + return torch.empty(0, 4) return torch.cat([instance.bbox for instance in self.instances], dim=0) def has_crops(self) -> bool: - """Function to check if any of frames instances has a crop + """Check if any of frames instances has a crop. Returns: True if at least 1 instance has a crop otherwise False. @@ -680,7 +681,7 @@ def has_crops(self) -> bool: return False def get_crops(self) -> torch.Tensor: - """Function to get the crops of all instances in the frame + """Get the crops of all instances in the frame. Returns: an (N, C, H, W) shaped tensor with crops of each instance in the frame. @@ -690,7 +691,7 @@ def get_crops(self) -> torch.Tensor: return torch.cat([instance.crop for instance in self.instances], dim=0) def has_features(self): - """Function to check if any of frames instances has reid features already computed + """Check if any of frames instances has reid features already computed. Returns: True if at least 1 instance have reid features otherwise False. @@ -700,7 +701,7 @@ def has_features(self): return False def get_features(self): - """Function to get the reid feature vectors of all instances in the frame + """Get the reid feature vectors of all instances in the frame. Returns: an (N, D) shaped tensor with reid feature vectors of each instance in the frame. 
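A minimal sketch of the Frame container covered by the hunks above, mirroring the tests added below:

import torch
from biogtr.data_structures import Frame, Instance

instances = [
    Instance(gt_track_id=i, pred_track_id=i, bbox=torch.randn(1, 4)) for i in range(2)
]
frame = Frame(video_id=0, frame_id=0, img_shape=(3, 1024, 1024), instances=instances)

frame.get_gt_track_ids()   # (N,) tensor of ground-truth ids
frame.get_bboxes()         # (N, 4) tensor of boxes
frame.add_traj_score("initial", torch.randn(2, 3))  # store a score matrix under a key
frame.get_traj_score("initial")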
diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py new file mode 100644 index 00000000..d44bf9db --- /dev/null +++ b/tests/test_data_structures.py @@ -0,0 +1,204 @@ +"""Tests for Instance, Frame, and TrackQueue Object""" +from biogtr.data_structures import Instance, Frame +from biogtr.inference.track_queue import TrackQueue +import torch + + +def test_instance(): + """Test Instance object logic.""" + + gt_track_id = 0 + pred_track_id = 0 + bbox = torch.randn((1, 4)) + crop = torch.randn((1, 3, 128, 128)) + features = torch.randn((1, 64)) + + instance = Instance( + gt_track_id=gt_track_id, + pred_track_id=pred_track_id, + bbox=bbox, + crop=crop, + features=features, + ) + + assert instance.has_gt_track_id() + assert instance.gt_track_id.item() == gt_track_id + assert instance.has_pred_track_id() + assert instance.pred_track_id.item() == pred_track_id + assert instance.has_bbox() + assert torch.equal(instance.bbox, bbox) + assert instance.has_features() + assert torch.equal(instance.features, features) + + instance.gt_track_id = 1 + instance.pred_track_id = 1 + instance.bbox = torch.randn((1, 4)) + instance.crop = torch.randn((1, 3, 128, 128)) + instance.features = torch.randn((1, 64)) + + assert instance.has_gt_track_id() + assert instance.gt_track_id.item() != gt_track_id + assert instance.has_pred_track_id() + assert instance.pred_track_id.item() != pred_track_id + assert instance.has_bbox() + assert not torch.equal(instance.bbox, bbox) + assert instance.has_features() + assert not torch.equal(instance.features, features) + + instance.gt_track_id = None + instance.pred_track_id = -1 + instance.bbox = None + instance.crop = None + instance.features = None + + assert not instance.has_gt_track_id() + assert instance.gt_track_id.shape[0] == 0 + assert not instance.has_pred_track_id() + assert instance.pred_track_id.item() != pred_track_id + assert not instance.has_bbox() + assert not torch.equal(instance.bbox, bbox) + assert not instance.has_features() + assert not torch.equal(instance.features, features) + + +def test_frame(): + n_detected = 2 + n_traj = 3 + video_id = 0 + frame_id = 0 + img_shape = torch.tensor([3, 1024, 1024]) + asso_output = torch.randn(n_detected, 16) + traj_score = torch.randn(n_detected, n_traj) + matches = ([0, 1], [0, 1]) + + instances = [] + for i in range(n_detected): + instances.append( + Instance( + gt_track_id=i, + pred_track_id=i, + bbox=torch.randn(1, 4), + crop=torch.randn(1, 3, 64, 64), + features=torch.randn(1, 64), + ) + ) + frame = Frame( + video_id=video_id, frame_id=frame_id, img_shape=img_shape, instances=instances + ) + + assert frame.video_id.item() == video_id + assert frame.frame_id.item() == frame_id + assert torch.equal(frame.img_shape, img_shape) + assert frame.num_detected == n_detected + assert frame.has_instances() + assert len(frame.instances) == n_detected + assert frame.has_gt_track_ids() + assert len(frame.get_gt_track_ids()) == n_detected + assert frame.has_pred_track_ids() + assert len(frame.get_pred_track_ids()) == n_detected + assert not frame.has_matches() + assert not frame.has_asso_output() + assert not frame.has_traj_score() + + frame.asso_output = asso_output + frame.add_traj_score("initial", traj_score) + frame.matches = matches + + assert frame.has_matches() + assert frame.matches == matches + assert frame.has_asso_output() + assert torch.equal(frame.asso_output, asso_output) + assert frame.has_traj_score() + assert torch.equal(frame.get_traj_score("initial"), traj_score) + + frame.instances = [] + + 
assert frame.video_id.item() == video_id + assert frame.num_detected == 0 + assert not frame.has_instances() + assert len(frame.instances) == 0 + assert not frame.has_gt_track_ids() + assert not len(frame.get_gt_track_ids()) + assert not frame.has_pred_track_ids() + assert len(frame.get_pred_track_ids()) == 0 + assert frame.has_matches() + assert frame.has_asso_output() + assert frame.has_traj_score() + + +def test_track_queue(): + window_size = 8 + max_gap = 10 + img_shape = (3, 1024, 1024) + n_instances_per_frame = [2] * window_size + + frames = [] + instances_per_frame = [] + + tq = TrackQueue(window_size, max_gap) + for i in range(window_size): + instances = [] + for j in range(n_instances_per_frame[i]): + instances.append(Instance(gt_track_id=j, pred_track_id=j)) + instances_per_frame.append(instances) + frame = Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) + frames.append(frame) + + tq.add_frame(frame) + + assert len(tq) == sum(n_instances_per_frame[1:]) + assert tq.n_tracks == max(n_instances_per_frame) + assert tq.tracks == [i for i in range(max(n_instances_per_frame))] + assert len(tq.collate_tracks()) == window_size - 1 + assert all([gap == 0 for gap in tq._curr_gap.values()]) + assert tq.curr_track == max(n_instances_per_frame) - 1 + + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + 1, + img_shape=img_shape, + instances=[Instance(gt_track_id=0, pred_track_id=0)], + ) + ) + + assert len(tq._queues[0]) == window_size - 1 + assert tq._curr_gap[0] == 0 + assert tq._curr_gap[max(n_instances_per_frame) - 1] == 1 + + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + 1, + img_shape=img_shape, + instances=[ + Instance(gt_track_id=1, pred_track_id=1), + Instance( + gt_track_id=max(n_instances_per_frame), + pred_track_id=max(n_instances_per_frame), + ), + ], + ) + ) + + assert len(tq._queues[max(n_instances_per_frame)]) == 1 + assert tq._curr_gap[1] == 0 + assert tq._curr_gap[0] == 1 + + for i in range(max_gap): + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + i + 1, + img_shape=img_shape, + instances=[Instance(gt_track_id=0, pred_track_id=0)], + ) + ) + + assert tq.n_tracks == 1 + assert tq.curr_track == max(n_instances_per_frame) + assert 0 in tq._queues.keys() + + tq.end_tracks() + + assert len(tq) == 0 From 25dd3bfa7275d280ec570f574e071c991d27b071 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 19:44:29 -0800 Subject: [PATCH 08/40] fix formatting + doc strings --- biogtr/config.py | 31 ++++---- biogtr/datasets/base_dataset.py | 21 ++--- biogtr/datasets/cell_tracking_dataset.py | 36 +++++---- biogtr/datasets/data_utils.py | 41 +++++----- biogtr/datasets/eval_dataset.py | 50 +++++++----- biogtr/datasets/microscopy_dataset.py | 43 +++++----- biogtr/datasets/sleap_dataset.py | 82 +++++++++++--------- biogtr/datasets/tracking_dataset.py | 13 ++-- biogtr/inference/__init__.py | 1 + biogtr/inference/boxes.py | 4 +- biogtr/inference/metrics.py | 77 ++++++++++++------ biogtr/inference/post_processing.py | 11 ++- biogtr/inference/track.py | 15 +++- biogtr/models/attention_head.py | 4 +- biogtr/models/embedding.py | 22 ++++-- biogtr/models/global_tracking_transformer.py | 13 ++-- biogtr/models/gtr_runner.py | 56 ++++++++----- biogtr/models/model_utils.py | 6 +- biogtr/models/transformer.py | 46 +++++++---- biogtr/training/train.py | 11 +-- biogtr/visualize.py | 20 ++--- tests/fixtures/configs.py | 5 +- tests/fixtures/torch.py | 4 +- tests/test_datasets.py | 14 ++-- 24 files changed, 364 insertions(+), 262 
deletions(-) create mode 100644 biogtr/inference/__init__.py diff --git a/biogtr/config.py b/biogtr/config.py index df32350c..f19d153f 100644 --- a/biogtr/config.py +++ b/biogtr/config.py @@ -43,7 +43,7 @@ def __repr__(self): return f"Config({self.cfg})" def __str__(self): - """String representation of config class.""" + """Return a string representation of config class.""" return f"Config({self.cfg})" def set_hparams(self, hparams: dict) -> bool: @@ -92,20 +92,21 @@ def get_tracker_cfg(self) -> dict: def get_gtr_runner(self): """Get lightning module for training, validation, and inference.""" - tracker_params = self.cfg.tracker optimizer_params = self.cfg.optimizer scheduler_params = self.cfg.scheduler loss_params = self.cfg.loss gtr_runner_params = self.cfg.runner - + if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "": - model = GTRRunner.load_from_checkpoint(self.cfg.model.ckpt_path, - tracker_cfg = tracker_params, - train_metrics=self.cfg.runner.train_metrics, - val_metrics=self.cfg.runner.val_metrics, - test_metrics=self.cfg.runner.test_metrics) - + model = GTRRunner.load_from_checkpoint( + self.cfg.model.ckpt_path, + tracker_cfg=tracker_params, + train_metrics=self.cfg.runner.train_metrics, + val_metrics=self.cfg.runner.val_metrics, + test_metrics=self.cfg.runner.test_metrics, + ) + else: model_params = self.cfg.model model = GTRRunner( @@ -186,13 +187,13 @@ def get_dataloader( torch.multiprocessing.set_sharing_strategy("file_system") else: pin_memory = False - + return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, pin_memory=pin_memory, collate_fn=dataset.no_batching_fn, - **dataloader_params + **dataloader_params, ) def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer: @@ -282,7 +283,7 @@ def get_trainer( callbacks: list[pl.callbacks.Callback], logger: pl.loggers.WandbLogger, devices: int = 1, - accelerator: str = None + accelerator: str = None, ) -> pl.Trainer: """Getter for the lightning trainer. @@ -297,12 +298,12 @@ def get_trainer( A lightning Trainer with specified params """ if "accelerator" not in self.cfg.trainer: - self.set_hparams({'trainer.accelerator': accelerator}) + self.set_hparams({"trainer.accelerator": accelerator}) if "devices" not in self.cfg.trainer: - self.set_hparams({'trainer.devices': devices}) + self.set_hparams({"trainer.devices": devices}) trainer_params = self.cfg.trainer - + return pl.Trainer( callbacks=callbacks, logger=logger, diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index 5fd4e432..61b3e00b 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -1,6 +1,6 @@ """Module containing logic for loading datasets.""" from biogtr.datasets import data_utils -from biogtr.data_structures import Frame, Instance +from biogtr.data_structures import Frame from torch.utils.data import Dataset from typing import List, Union import numpy as np @@ -21,7 +21,7 @@ def __init__( augmentations: dict = None, n_chunks: Union[int, float] = 1.0, seed: int = None, - gt_list: str = None + gt_list: str = None, ): """Initialize Dataset. @@ -72,27 +72,28 @@ def create_chunks(self) -> None: efficiency and data shuffling. 
To be called by subclass __init__() """ if self.chunk: - self.chunked_frame_idx, self.label_idx = [], [] for i, frame_idx in enumerate(self.frame_idx): frame_idx_split = torch.split(frame_idx, self.clip_length) self.chunked_frame_idx.extend(frame_idx_split) self.label_idx.extend(len(frame_idx_split) * [i]) - + if self.n_chunks > 0 and self.n_chunks <= 1.0: n_chunks = int(self.n_chunks * len(self.chunked_frame_idx)) elif self.n_chunks <= len(self.chunked_frame_idx): n_chunks = int(self.n_chunks) - + else: n_chunks = len(self.chunked_frame_idx) if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx): - sample_idx = np.random.choice(np.arange(len(self.chunked_frame_idx)), n_chunks) + sample_idx = np.random.choice( + np.arange(len(self.chunked_frame_idx)), n_chunks + ) self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx] - + self.label_idx = [self.label_idx[i] for i in sample_idx] else: @@ -126,14 +127,14 @@ def __getitem__(self, idx: int) -> List[Frame]: or the frame. Returns: - A list of `Frame`s in the chunk containing the metadata + instance features. + A list of `Frame`s in the chunk containing the metadata + instance features. """ label_idx, frame_idx = self.get_indices(idx) return self.get_instances(label_idx, frame_idx) def get_indices(self, idx: int): - """Retrieves label and frame indices given batch index. + """Retrieve label and frame indices given batch index. This method should be implemented in any subclass of the BaseDataset. @@ -146,7 +147,7 @@ def get_indices(self, idx: int): raise NotImplementedError("Must be implemented in subclass") def get_instances(self, label_idx: List[int], frame_idx: List[int]): - """Builds chunk of frames. + """Build chunk of frames. This method should be implemented in any subclass of the BaseDataset. diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 605dcfeb..6a421e13 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -4,13 +4,9 @@ from biogtr.datasets.base_dataset import BaseDataset from biogtr.data_structures import Instance, Frame from scipy.ndimage import measurements -from torch.utils.data import Dataset -from torchvision.transforms import functional as tvf from typing import List, Optional, Union import albumentations as A -import glob import numpy as np -import os import pandas as pd import random import torch @@ -31,7 +27,7 @@ def __init__( augmentations: Optional[dict] = None, n_chunks: Union[int, float] = 1.0, seed: int = None, - gt_list: str = None + gt_list: str = None, ): """Initialize CellTrackingDataset. @@ -68,7 +64,7 @@ def __init__( augmentations, n_chunks, seed, - gt_list + gt_list, ) self.videos = raw_images @@ -105,7 +101,7 @@ def __init__( self.create_chunks() def get_indices(self, idx): - """Retrieves label and frame indices given batch index. + """Retrieve label and frame indices given batch index. Args: idx: the index of the batch. @@ -120,7 +116,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram frame_idx: index of the frames Returns: - a list of Frame objects containing frame metadata and Instance Objects. + a list of Frame objects containing frame metadata and Instance Objects. See `biogtr.data_structures` for more info. 
""" image = self.videos[label_idx] @@ -187,14 +183,22 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram for i in range(len(gt_track_ids)): crop = data_utils.crop_bbox(img, bboxes[i]) - instances.append(Instance(gt_track_id=gt_track_ids[i], - pred_track_id=-1, - bbox=bboxes[i], - crop=crop)) + instances.append( + Instance( + gt_track_id=gt_track_ids[i], + pred_track_id=-1, + bbox=bboxes[i], + crop=crop, + ) + ) - frames.append(Frame(video_id=label_idx, - frame_id=i, - img_shape=img.shape, - instances=instances)) + frames.append( + Frame( + video_id=label_idx, + frame_id=i, + img_shape=img.shape, + instances=instances, + ) + ) return frames diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py index 8ee2d19c..76b44972 100644 --- a/biogtr/datasets/data_utils.py +++ b/biogtr/datasets/data_utils.py @@ -67,8 +67,7 @@ def get_bbox(center: ArrayLike, size: Union[int, tuple[int]]) -> torch.Tensor: cx, cy = center[0], center[1] bbox = torch.Tensor( - [-size[-1] // 2 + cy, -size[0] // 2 + cx, - size[-1] // 2 + cy, size[0] // 2 + cx] + [-size[-1] // 2 + cy, -size[0] // 2 + cx, size[-1] // 2 + cy, size[0] // 2 + cx] ) return bbox @@ -107,9 +106,7 @@ def centroid_bbox(points: ArrayLike, anchors: list, crop_size: int) -> torch.Ten return bbox -def pose_bbox( - points: np.ndarray, bbox_size: Union[tuple[int], int] -) -> torch.Tensor: +def pose_bbox(points: np.ndarray, bbox_size: Union[tuple[int], int]) -> torch.Tensor: """Calculate bbox around instance pose. Args: @@ -122,18 +119,24 @@ def pose_bbox( if type(bbox_size) == int: bbox_size = (bbox_size, bbox_size) # print(points) - minx = np.nanmin(points[:,0], axis=-1) - miny = np.nanmin(points[:,-1], axis=-1) + minx = np.nanmin(points[:, 0], axis=-1) + miny = np.nanmin(points[:, -1], axis=-1) minpoints = np.array([minx, miny]).T - - maxx = np.nanmax(points[:,0], axis=-1) - maxy = np.nanmax(points[:,-1], axis=-1) + + maxx = np.nanmax(points[:, 0], axis=-1) + maxy = np.nanmax(points[:, -1], axis=-1) maxpoints = np.array([maxx, maxy]).T - - c = ((minpoints + maxpoints)/2) - - bbox = torch.Tensor([c[-1]-bbox_size[-1]/2, c[0] - bbox_size[0]/2, - c[-1] + bbox_size[-1]/2, c[0] + bbox_size[0]/2]) + + c = (minpoints + maxpoints) / 2 + + bbox = torch.Tensor( + [ + c[-1] - bbox_size[-1] / 2, + c[0] - bbox_size[0] / 2, + c[-1] + bbox_size[-1] / 2, + c[0] + bbox_size[0] / 2, + ] + ) return bbox @@ -210,7 +213,7 @@ def parse_trackmate(data_path: str) -> pd.DataFrame: and centroid x,y coordinates in pixels """ if data_path.endswith(".xml"): - root = et.fromstring(open(xml_path).read()) + root = et.fromstring(open(data_path).read()) objects = [] features = root.find("Model").find("FeatureDeclarations").find("SpotFeatures") @@ -444,7 +447,7 @@ def get_max_padding(height: int, width: int) -> tuple: def view_training_batch( instances: List[Dict[str, List[np.ndarray]]], num_frames: int = 1, cmap=None ) -> None: - """Displays a grid of images from a batch of training instances. + """Display a grid of images from a batch of training instances. 
Args: instances: A list of training instances, where each instance is a @@ -472,7 +475,9 @@ def view_training_batch( else (axes[i] if num_crops == 1 else axes[i, j]) ) - ax.imshow(data.T) if cmap is None else ax.imshow(data.T, cmap=cmap) + ax.imshow(data.T) if cmap is None else ax.imshow( + data.T, cmap=cmap + ) ax.axis("off") except Exception as e: diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index e021e9e6..65b655f8 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -1,22 +1,22 @@ -"Module containing wrapper for merging gt and pred datasets for evaluation" -import torch +"""Module containing wrapper for merging gt and pred datasets for evaluation.""" from torch.utils.data import Dataset from biogtr.data_structures import Frame, Instance from typing import List + class EvalDataset(Dataset): - + """Wrapper around gt and predicted dataset.""" + def __init__(self, gt_dataset: Dataset, pred_dataset: Dataset) -> None: - """Initialize EvalDataset - + """Initialize EvalDataset. + Args: gt_dataset: A Dataset object containing ground truth track ids pred_dataset: A dataset object containing predicted track ids """ - self.gt_dataset = gt_dataset self.pred_dataset = pred_dataset - + def __len__(self) -> int: """Get the size of the dataset. @@ -24,7 +24,7 @@ def __len__(self) -> int: the size or the number of chunks in the dataset """ return len(self.gt_dataset) - + def __getitem__(self, idx: int) -> List[Frame]: """Get an element of the dataset. @@ -35,19 +35,29 @@ def __getitem__(self, idx: int) -> List[Frame]: Returns: A list of Frames where frames contain instances w gt and pred track ids + bboxes. """ - gt_batch = self.gt_dataset[i] - pred_batch = self.pred_dataset[i] + gt_batch = self.gt_dataset[idx] + pred_batch = self.pred_dataset[idx] eval_frames = [] for gt_frame, pred_frame in zip(gt_batch, pred_batch): eval_instances = [] - for gt_instance, pred_instance in zip(gt_frame.instances, pred_frame.instances): - eval_instances.append(Instance(gt_track_id=gt_instance.gt_track_id, - pred_track_id=pred_instance.pred_track_id, - bbox=pred_instance.bbox)) - eval_frames.append(Frame(video_id=gt_frame.video_id, - frame_id=gt_frame.frame_id, - img_shape=gt_frame.img_shape, - instances=eval_instances)) - - return eval_frames \ No newline at end of file + for gt_instance, pred_instance in zip( + gt_frame.instances, pred_frame.instances + ): + eval_instances.append( + Instance( + gt_track_id=gt_instance.gt_track_id, + pred_track_id=pred_instance.pred_track_id, + bbox=pred_instance.bbox, + ) + ) + eval_frames.append( + Frame( + video_id=gt_frame.video_id, + frame_id=gt_frame.frame_id, + img_shape=gt_frame.img_shape, + instances=eval_instances, + ) + ) + + return eval_frames diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 062b6bd7..d063b29c 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -3,8 +3,6 @@ from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset from biogtr.data_structures import Frame, Instance -from torch.utils.data import Dataset -from torchvision.transforms import functional as tvf from typing import Union import albumentations as A import numpy as np @@ -27,7 +25,7 @@ def __init__( mode: str = "Train", augmentations: dict = None, n_chunks: Union[int, float] = 1.0, - seed: int = None + seed: int = None, ): """Initialize MicroscopyDataset.
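For context, the EvalDataset rewritten above simply zips equally-chunked ground-truth and predicted datasets; a hedged sketch where any indexable of Frame chunks stands in for the wrapped datasets:

import torch
from biogtr.data_structures import Frame, Instance
from biogtr.datasets.eval_dataset import EvalDataset

gt_chunk = [Frame(video_id=0, frame_id=0, img_shape=(3, 512, 512),
                  instances=[Instance(gt_track_id=0, pred_track_id=-1, bbox=torch.randn(1, 4))])]
pred_chunk = [Frame(video_id=0, frame_id=0, img_shape=(3, 512, 512),
                    instances=[Instance(gt_track_id=0, pred_track_id=0, bbox=torch.randn(1, 4))])]

eval_ds = EvalDataset(gt_dataset=[gt_chunk], pred_dataset=[pred_chunk])
eval_frames = eval_ds[0]  # Frames whose instances carry the gt id, pred id and pred bbox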
@@ -96,7 +94,7 @@ def __init__( self.frame_idx = [ torch.arange(Image.open(video).n_frames) - if type(video) == str + if isinstance(video, str) else torch.arange(len(video)) for video in self.videos ] @@ -106,7 +104,7 @@ def __init__( self.create_chunks() def get_indices(self, idx): - """Retrieves label and frame indices given batch index. + """Retrieve label and frame indices given batch index. Args: idx: the index of the batch. @@ -128,17 +126,16 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram video = self.videos[label_idx] - if type(video) != list: + if not isinstance(video, list): video = data_utils.LazyTiffStack(self.videos[label_idx]) - frames = [] for i in frame_idx: - instances, gt_track_ids, centroids, bboxes, crops = [], [], [], [], [] + instances, gt_track_ids, centroids = [], [], [] img = ( video.get_section(i) - if type(video) != list + if not isinstance(video, list) else np.array(Image.open(video[i])) ) @@ -181,14 +178,22 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram ) crop = data_utils.crop_bbox(img, bbox) - instances.append(Instance(gt_track_id=gt_track_ids[i], - pred_track_id=-1, - bbox=bbox, - crop=crop)) - - frames.append(Frame(video_id=label_idx, - frame_id=i, - img_shape=img.shape, - instances=instances)) - + instances.append( + Instance( + gt_track_id=gt_track_ids[i], + pred_track_id=-1, + bbox=bbox, + crop=crop, + ) + ) + + frames.append( + Frame( + video_id=label_idx, + frame_id=i, + img_shape=img.shape, + instances=instances, + ) + ) + return frames diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 9474c5a6..768e8213 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -29,7 +29,7 @@ def __init__( augmentations: dict = None, n_chunks: Union[int, float] = 1.0, seed: int = None, - verbose: bool = False + verbose: bool = False, ): """Initialize SleapDataset. @@ -38,7 +38,7 @@ def __init__( video_files: a list of paths to video files padding: amount of padding around object crops crop_size: the size of the object crops - anchor: the name of the anchor keypoint to be used as centroid for cropping. + anchor: the name of the anchor keypoint to be used as centroid for cropping. If unavailable then crop around the midpoint between all visible anchors. chunk: whether or not to chunk the dataset into batches clip_length: the number of frames in each chunk @@ -78,7 +78,7 @@ def __init__( self.n_chunks = n_chunks self.seed = seed self.anchor = anchor.lower() - self.verbose=verbose + self.verbose = verbose # if self.seed is not None: # np.random.seed(self.seed) @@ -100,7 +100,7 @@ def __init__( self.create_chunks() def get_indices(self, idx): - """Retrieves label and frame indices given batch index. + """Retrieve label and frame indices given batch index. Args: idx: the index of the batch.
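To make the chunking logic above concrete: a batch index resolves to a (video, frame window) pair, which get_instances turns into a list of Frames (dataset here is a placeholder for any of the dataset subclasses in this patch):

label_idx, frame_idx = dataset.get_indices(0)          # which video, which clip of frames
frames = dataset.get_instances(label_idx, frame_idx)   # list[Frame] with crops + gt ids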
@@ -140,22 +140,21 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict vid_reader = imageio.get_reader(video_name, "ffmpeg") img = vid_reader.get_data(0) - crop_shape = (img.shape[-1], *(self.crop_size + 2 * self.padding,) * 2) frames = [] for i, frame_ind in enumerate(frame_idx): - instances, gt_track_ids, bboxes, crops, shown_poses = [], [], [], [], [] + instances, gt_track_ids, shown_poses = [], [], [] frame_ind = int(frame_ind) - + lf = video[frame_ind] - + try: img = vid_reader.get_data(frame_ind) except IndexError as e: - print(f"Could not read frame {frame_ind} from {video_name}") + print(f"Could not read frame {frame_ind} from {video_name} due to {e}") continue - + for instance in lf: gt_track_ids.append(video.tracks.index(instance.track)) @@ -168,9 +167,14 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict ) ) - shown_poses = [{key.lower(): val for key, val in instance.items() - if not np.isnan(val).any() - } for instance in shown_poses] + shown_poses = [ + { + key.lower(): val + for key, val in instance.items() + if not np.isnan(val).any() + } + for instance in shown_poses + ] # augmentations if self.augmentations is not None: for transform in self.augmentations: @@ -201,34 +205,37 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict for aug_pose_arr, pose_dict in zip(aug_poses, shown_poses) ] - _ = [pose.update(aug_pose) for pose, aug_pose in zip(shown_poses, aug_poses)] + _ = [ + pose.update(aug_pose) + for pose, aug_pose in zip(shown_poses, aug_poses) + ] img = tvf.to_tensor(img) for i in range(len(gt_track_ids)): - pose = shown_poses[i] """Check for anchor""" if self.anchor in pose: anchor = self.anchor else: - if self.verbose: warnings.warn(f"{self.anchor} not in {[key for key in pose.keys()]}! Defaulting to midpoint") + if self.verbose: + warnings.warn( + f"{self.anchor} not in {[key for key in pose.keys()]}! Defaulting to midpoint" + ) anchor = "midpoint" - + if anchor != "midpoint": centroid = pose[anchor] if not np.isnan(centroid).any(): bbox = data_utils.pad_bbox( - data_utils.get_bbox( - centroid, self.crop_size - ), - padding=self.padding, - ) - + data_utils.get_bbox(centroid, self.crop_size), + padding=self.padding, + ) + else: - #print(f'{self.anchor} contains NaN: {centroid}. Using midpoint') + # print(f'{self.anchor} contains NaN: {centroid}. Using midpoint') bbox = data_utils.pad_bbox( data_utils.pose_bbox( np.array(list(pose.values())), self.crop_size @@ -236,7 +243,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict padding=self.padding, ) else: - #print(f'{self.anchor} not an available option amongst {pose.keys()}. Using midpoint') + # print(f'{self.anchor} not an available option amongst {pose.keys()}. 
Using midpoint') bbox = data_utils.pad_bbox( data_utils.pose_bbox( np.array(list(pose.values())), self.crop_size @@ -245,20 +252,19 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict ) crop = data_utils.crop_bbox(img, bbox) - - instance = Instance(gt_track_id=gt_track_ids[i], - pred_track_id=-1, - crop=crop, - bbox=bbox - ) - + + instance = Instance( + gt_track_id=gt_track_ids[i], pred_track_id=-1, crop=crop, bbox=bbox + ) + instances.append(instance) - - frame = Frame(video_id=label_idx, - frame_id=frame_ind, - img_shape=img.shape, - instances=instances - ) + + frame = Frame( + video_id=label_idx, + frame_id=frame_ind, + img_shape=img.shape, + instances=instances, + ) frames.append(frame) return frames diff --git a/biogtr/datasets/tracking_dataset.py b/biogtr/datasets/tracking_dataset.py index f6459337..b80cd636 100644 --- a/biogtr/datasets/tracking_dataset.py +++ b/biogtr/datasets/tracking_dataset.py @@ -53,21 +53,20 @@ def __init__( self.test_dl = test_dl def setup(self, stage=None): - """Setup function needed for lightning dataset. + """Set up lightning dataset. UNUSED. """ pass def train_dataloader(self) -> DataLoader: - """Getter for train_dataloader. + """Get train_dataloader. Returns: The Training Dataloader. """ if self.train_dl is None and self.train_ds is None: return None elif self.train_dl is None: - return DataLoader( self.train_ds, batch_size=1, @@ -75,13 +74,15 @@ def train_dataloader(self) -> DataLoader: pin_memory=False, collate_fn=self.train_ds.no_batching_fn, num_workers=0, - generator=torch.Generator(device="cuda") if torch.cuda.is_available() else torch.Generator() + generator=torch.Generator(device="cuda") + if torch.cuda.is_available() + else torch.Generator(), ) else: return self.train_dl def val_dataloader(self) -> DataLoader: - """Getter for val dataloader. + """Get val dataloader. Returns: The validation dataloader. """ @@ -101,7 +102,7 @@ def val_dataloader(self) -> DataLoader: return self.val_dl def test_dataloader(self) -> DataLoader: - """Getter for test dataloader. + """Get. Returns: The test dataloader """ diff --git a/biogtr/inference/__init__.py b/biogtr/inference/__init__.py new file mode 100644 index 00000000..c1c53dce --- /dev/null +++ b/biogtr/inference/__init__.py @@ -0,0 +1 @@ +"""Tracking Inference using GTR Model.""" diff --git a/biogtr/inference/boxes.py b/biogtr/inference/boxes.py index 951529b1..e6ed794f 100644 --- a/biogtr/inference/boxes.py +++ b/biogtr/inference/boxes.py @@ -1,5 +1,5 @@ """Module containing Boxes class.""" -from typing import List, Tuple, Union +from typing import List, Tuple import torch @@ -56,7 +56,7 @@ def to(self, device: torch.device) -> "Boxes": return Boxes(self.tensor.to(device=device)) def area(self) -> torch.Tensor: - """Computes the area of all the boxes. + """Compute the area of all the boxes. Returns: torch.Tensor: a vector with areas of each box. diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index 1a5403a2..bff87310 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -3,10 +3,11 @@ import motmetrics as mm import torch from biogtr.data_structures import Frame -from biogtr.inference.post_processing import _pairwise_iou -from biogtr.inference.boxes import Boxes from typing import Union, Iterable +# from biogtr.inference.post_processing import _pairwise_iou +# from biogtr.inference.boxes import Boxes + def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: """Get comparison between predicted and gt trajectory labels. 
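A sketch of the evaluation path through the metrics module being cleaned up here, using to_track_eval and get_pymotmetrics from the hunks that follow (frames_pred is assumed to be tracker output with reid features populated):

from biogtr.inference import metrics

data = metrics.to_track_eval(frames_pred)  # ragged gt/tracker ids + similarity scores
summary = metrics.get_pymotmetrics(data, metrics=("motp", "num_switches"))
print(summary["num_switches"])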
@@ -27,7 +28,9 @@ def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: for idx, frame in enumerate(frames): indices.append(frame.frame_id.item()) - for gt_track_id, pred_track_id in zip(frame.get_gt_track_ids(), frame.get_pred_track_ids()): + for gt_track_id, pred_track_id in zip( + frame.get_gt_track_ids(), frame.get_pred_track_ids() + ): match = f"{gt_track_id} -> {pred_track_id}" if match not in matches: @@ -93,10 +96,10 @@ def get_switch_count(switches: dict) -> int: def to_track_eval(frames: list[Frame]) -> dict: - """Reformats frames the output from `sliding_inference` to be used by `TrackEval.` + """Reformats frames the output from `sliding_inference` to be used by `TrackEval`. Args: - instances: A list of Frames. `See biogtr.data_structures for more info.` + instances: A list of Frames. `See biogtr.data_structures for more info`. Returns: data: A dictionary. Example provided below. @@ -117,7 +120,6 @@ def to_track_eval(frames: list[Frame]) -> dict: "num_timsteps": L, } """ - unique_gt_ids = [] num_tracker_dets = 0 num_gt_dets = 0 @@ -131,7 +133,7 @@ def to_track_eval(frames: list[Frame]) -> dict: for fidx, frame in enumerate(frames): gt_track_ids = frame.get_gt_track_ids().cpu().numpy().tolist() pred_track_ids = frame.get_pred_track_ids().cpu().numpy().tolist() - boxes = Boxes(frame.get_bboxes().cpu()) + # boxes = Boxes(frame.get_bboxes().cpu()) gt_ids.append(np.array(gt_track_ids)) track_ids.append(np.array(pred_track_ids)) @@ -141,12 +143,12 @@ def to_track_eval(frames: list[Frame]) -> dict: if not set(gt_track_ids).issubset(set(unique_gt_ids)): unique_gt_ids.extend(list(set(gt_track_ids).difference(set(unique_gt_ids)))) - - #eval_matrix = _pairwise_iou(boxes, boxes) + + # eval_matrix = _pairwise_iou(boxes, boxes) eval_matrix = np.full((len(gt_track_ids), len(pred_track_ids)), np.nan) for i, feature_i in enumerate(frame.get_features()): - for j, feature_j in enumerate(features): + for j, feature_j in enumerate(frame.get_features()): eval_matrix[i][j] = cos_sim( feature_i.unsqueeze(0), feature_j.unsqueeze(0) ) @@ -185,10 +187,10 @@ def to_track_eval(frames: list[Frame]) -> dict: data["num_gt_dets"] = num_gt_dets try: data["gt_ids"] = gt_ids - #print(data['gt_ids']) + # print(data['gt_ids']) except Exception as e: print(gt_ids) - raise(e) + raise (e) data["tracker_ids"] = track_ids data["similarity_scores"] = similarity_scores data["num_timesteps"] = len(frames) @@ -197,6 +199,29 @@ def to_track_eval(frames: list[Frame]) -> dict: def get_track_evals(data: dict, metrics: dict) -> dict: + """Run track_eval and get mot metrics. + + Args: + data: A dictionary. Example provided below. + metrics: mot metrics to be computed + Returns: + A dictionary with key being the metric, and value being the metric value computed. 
+ # --------------------------- An example of data --------------------------- # + + *: number of ids for gt at every frame of the video + ^: number of ids for tracker at every frame of the video + L: length of video + + data = { + "num_gt_ids": total number of unique gt ids, + "num_tracker_dets": total number of detections by your detection algorithm, + "num_gt_dets": total number of gt detections, + "gt_ids": (L, *), # Ragged np.array + "tracker_ids": (L, ^), # Ragged np.array + "similarity_scores": (L, *, ^), # Ragged np.array + "num_timsteps": L, + } + """ results = {} for metric_name, metric in metrics.items(): result = metric.eval_sequence(data) @@ -204,7 +229,12 @@ def get_track_evals(data: dict, metrics: dict) -> dict: return results -def get_pymotmetrics(data: dict, metrics: Union[str, tuple] = "all", key: str = "tracker_ids", save: str = None): +def get_pymotmetrics( + data: dict, + metrics: Union[str, tuple] = "all", + key: str = "tracker_ids", + save: str = None, +): """Given data and a key, evaluate the predictions. Args: @@ -230,7 +260,10 @@ def get_pymotmetrics(data: dict, metrics: Union[str, tuple] = "all", key: str = } """ if not isinstance(metrics, str): - metrics = ["num_switches" if metric.lower() == "sw_cnt" else metric for metric in metrics] #backward compatibility + metrics = [ + "num_switches" if metric.lower() == "sw_cnt" else metric + for metric in metrics + ] # backward compatibility acc = mm.MOTAccumulator(auto_id=True) for i in range(len(data["gt_ids"])): @@ -246,22 +279,22 @@ def get_pymotmetrics(data: dict, metrics: Union[str, tuple] = "all", key: str = metric.split("|")[0] for metric in mh.list_metrics_markdown().split("\n")[2:-1] ] - if type(metrics) == str: + if isinstance(metrics, str): metrics_list = all_metrics - + elif isinstance(metrics, Iterable): metrics = [metric.lower() for metric in metrics] metrics_list = [metric for metric in all_metrics if metric.lower() in metrics] - + else: - raise TypeError(f"Metrics must either be an iterable of strings or `all` not: {type(metrics)}") - + raise TypeError( + f"Metrics must either be an iterable of strings or `all` not: {type(metrics)}" + ) + summary = mh.compute(acc, metrics=metrics_list, name="acc") summary = summary.transpose() if save is not None and save != "": summary.to_csv(save) - return summary['acc'] - - + return summary["acc"] diff --git a/biogtr/inference/post_processing.py b/biogtr/inference/post_processing.py index 1dcb8473..f15d28eb 100644 --- a/biogtr/inference/post_processing.py +++ b/biogtr/inference/post_processing.py @@ -142,16 +142,19 @@ def filter_max_center_dist( ), "Need `k_boxes`, `nonk_boxes`, and `id_ind` to filter by `max_center_dist`" k_ct = (k_boxes[:, :2] + k_boxes[:, 2:]) / 2 k_s = ((k_boxes[:, 2:] - k_boxes[:, :2]) ** 2).sum(dim=1) # n_k - + nonk_ct = (nonk_boxes[:, :2] + nonk_boxes[:, 2:]) / 2 dist = ((k_ct[:, None] - nonk_ct[None, :]) ** 2).sum(dim=2) # n_k x Np - + norm_dist = dist / (k_s[:, None] + 1e-8) # n_k x Np # id_inds # Np x M valid = norm_dist < max_center_dist # n_k x Np - + valid_assn = ( - torch.mm(valid.float(), id_inds.to(valid.device)).clamp_(max=1.0).long().bool() + torch.mm(valid.float(), id_inds.to(valid.device)) + .clamp_(max=1.0) + .long() + .bool() ) # n_k x M asso_output_filtered = deepcopy(asso_output) asso_output_filtered[~valid_assn] = 0 # n_k x M diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py index 8ff07cf1..6e5dbd9c 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -3,7 +3,6 @@ from biogtr.config import 
Config from biogtr.models.gtr_runner import GTRRunner from biogtr.data_structures import Frame -from biogtr.datasets.tracking_dataset import TrackingDataset from omegaconf import DictConfig from pprint import pprint from pathlib import Path @@ -18,7 +17,16 @@ torch.set_default_device(device) + def export_trajectories(frames_pred: list[Frame], save_path: str = None): + """Convert trajectories to data frame and save as .csv. + + Args: + frames_pred: A list of Frames with predicted track ids. + save_path: The path to save the predicted trajectories to. + Returns: + A dictionary containing the predicted track id and centroid coordinates for each instance in the video. + """ save_dict = {} frame_ids = [] X, Y = [], [] @@ -42,6 +50,7 @@ def export_trajectories(frames_pred: list[Frame], save_path: str = None): save_df.to_csv(save_path, index=False) return save_df + def inference( model: GTRRunner, dataloader: torch.utils.data.DataLoader ) -> list[pd.DataFrame]: @@ -96,9 +105,7 @@ def inference( @hydra.main(config_path="configs", config_name=None, version_base=None) def main(cfg: DictConfig): - """Main function for running inference. - - handles config parsing, batch deployment and saving results + """Run inference based on config file. Args: cfg: A dictconfig loaded from hydra containing checkpoint path and data diff --git a/biogtr/models/attention_head.py b/biogtr/models/attention_head.py index 3e8d6a88..d562b62c 100644 --- a/biogtr/models/attention_head.py +++ b/biogtr/models/attention_head.py @@ -72,7 +72,7 @@ def __init__( num_layers: int, dropout: float, ): - """Initializes an instance of ATTWeightHead. + """Initialize an instance of ATTWeightHead. Args: feature_dim: The dimensionality of input features. @@ -89,7 +89,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, ) -> torch.Tensor: - """Computes the attention weights of a query tensor using the key tensor. + """Compute the attention weights of a query tensor using the key tensor. Args: query: Input tensor of shape (batch_size, num_frame_instances, feature_dim). diff --git a/biogtr/models/embedding.py b/biogtr/models/embedding.py index ac3336cb..819ec533 100644 --- a/biogtr/models/embedding.py +++ b/biogtr/models/embedding.py @@ -22,7 +22,7 @@ def __init__(self): def _torch_int_div( self, tensor1: torch.Tensor, tensor2: torch.Tensor ) -> torch.Tensor: - """Performs integer division of two tensors. + """Perform integer division of two tensors. Args: tensor1: dividend tensor. @@ -42,7 +42,7 @@ def _sine_box_embedding( normalize: bool = False, **kwargs, ) -> torch.Tensor: - """Computes sine positional embeddings for boxes using given parameters. + """Compute sine positional embeddings for boxes using given parameters. Args: boxes: the input boxes. @@ -104,7 +104,7 @@ def _learned_pos_embedding( over_boxes: bool = True, **kwargs, ) -> torch.Tensor: - """Computes learned positional embeddings for boxes using given parameters. + """Compute learned positional embeddings for boxes using given parameters. Args: boxes: the input boxes. 
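To illustrate the data flow around these embedding methods: boxes and frame indices gathered from a window of Frames feed the positional embeddings (frames is a placeholder and default kwargs are assumed):

from biogtr.models.model_utils import get_boxes_times
from biogtr.models.embedding import Embedding

boxes, times = get_boxes_times(frames)    # (N, 4) normalized boxes, (N,) frame indices
emb = Embedding()
pos_emb = emb._sine_box_embedding(boxes)  # fixed sinusoidal embedding of each box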
@@ -147,9 +147,15 @@ def _learned_pos_embedding( self.learn_pos_emb_num, 4, f ) # T x 4 x (D * 4) - pos_le = pos_emb_table.gather(0, l[:, :, None].to(pos_emb_table.device).expand(N, 4, f)) # N x 4 x d - pos_re = pos_emb_table.gather(0, r[:, :, None].to(pos_emb_table.device).expand(N, 4, f)) # N x 4 x d - pos_emb = lw[:, :, None] * pos_re.to(lw.device) + rw[:, :, None] * pos_le.to(rw.device) + pos_le = pos_emb_table.gather( + 0, l[:, :, None].to(pos_emb_table.device).expand(N, 4, f) + ) # N x 4 x d + pos_re = pos_emb_table.gather( + 0, r[:, :, None].to(pos_emb_table.device).expand(N, 4, f) + ) # N x 4 x d + pos_emb = lw[:, :, None] * pos_re.to(lw.device) + rw[:, :, None] * pos_le.to( + rw.device + ) pos_emb = pos_emb.view(N, 4 * f) @@ -162,7 +168,7 @@ def _learned_temp_embedding( learn_temp_emb_num: int = 16, **kwargs, ) -> torch.Tensor: - """Computes learned temporal embeddings for times using given parameters. + """Compute learned temporal embeddings for times using given parameters. Args: times: the input times. @@ -197,7 +203,7 @@ def _learned_temp_embedding( def _compute_weights( self, data: torch.Tensor, learn_emb_num: int = 16 ) -> Tuple[torch.Tensor, ...]: - """Computes left and right learned embedding weights. + """Compute left and right learned embedding weights. Args: data: the input data (e.g boxes or times). diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index b4bb1b2e..ebfd5e5d 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -6,6 +6,7 @@ # todo: do we want to handle params with configs already here? + class GlobalTrackingTransformer(nn.Module): """Modular GTR model composed of visual encoder + transformer used for tracking.""" @@ -97,12 +98,8 @@ def __init__( decoder_self_attn=decoder_self_attn, ) - def forward( - self, - frames: list[Frame], - query_frame: int = None - ): - """Forward pass of GTR Model to get asso matrix. + def forward(self, frames: list[Frame], query_frame: int = None): + """Execute forward pass of GTR Model to get asso matrix. Args: frames: List of Frames from chunk containing crops of objects + gt label info @@ -113,7 +110,7 @@ def forward( """ # Extract feature representations with pre-trained encoder. 
for frame in frames: - if frame.has_instances(): + if frame.has_instances(): if not frame.has_features(): crops = frame.get_crops() z = self.visual_encoder(crops) @@ -122,5 +119,5 @@ def forward( frame.instances[i].features = z_i asso_preds, emb = self.transformer(frames, query_frame=query_frame) - + return asso_preds, emb diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index 7fb32e39..a33e418d 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -1,7 +1,4 @@ """Module containing training, validation and inference logic.""" - -from typing import Any, Optional -from pytorch_lightning.utilities.types import STEP_OUTPUT import torch from biogtr.inference.tracker import Tracker from biogtr.inference import metrics @@ -24,8 +21,16 @@ def __init__( loss_cfg: dict = {}, optimizer_cfg: dict = None, scheduler_cfg: dict = None, - metrics: dict[str,list[str]] = {"train": ["num_switches"], "val": ["num_switches"], "test": ["num_switches"]}, - persistent_tracking: dict[str, bool] = {"train": False, "val": True, "test": True} + metrics: dict[str, list[str]] = { + "train": ["num_switches"], + "val": ["num_switches"], + "test": ["num_switches"], + }, + persistent_tracking: dict[str, bool] = { + "train": False, + "val": True, + "test": True, + }, ): """Initialize a lightning module for GTR. @@ -52,8 +57,9 @@ def __init__( self.metrics = metrics self.persistent_tracking = persistent_tracking + def forward(self, instances) -> torch.Tensor: - """The forward pass of the lightning module. + """Execute forward pass of the lightning module. Args: instances: a list of dicts where each dict is a frame with gt data @@ -61,14 +67,14 @@ def forward(self, instances) -> torch.Tensor: Returns: An association matrix between objects """ - if sum([frame['num_detected'] for frame in instances]) > 0: + if sum([frame["num_detected"] for frame in instances]) > 0: return self.model(instances) return None def training_step( self, train_batch: list[dict], batch_idx: int ) -> dict[str, float]: - """Method outlining the training procedure for model. + """Execute single training step for model. Args: train_batch: A single batch from the dataset which is a list of dicts @@ -80,13 +86,13 @@ def training_step( """ result = self._shared_eval_step(train_batch[0], mode="train") self.log_metrics(result, "train") - + return result def validation_step( self, val_batch: list[dict], batch_idx: int ) -> dict[str, float]: - """Method outlining the val procedure for model. + """Execute single val step for model. Args: val_batch: A single batch from the dataset which is a list of dicts @@ -96,13 +102,13 @@ def validation_step( Returns: A dict containing the val loss plus any other metrics specified """ - result = self._shared_eval_step(val_batch[0], mode = "val") + result = self._shared_eval_step(val_batch[0], mode="val") self.log_metrics(result, "val") - + return result def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]: - """Method outlining the test procedure for model. + """Execute single test step for model. Args: val_batch: A single batch from the dataset which is a list of dicts @@ -114,11 +120,11 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]: """ result = self._shared_eval_step(test_batch[0], mode="test") self.log_metrics(result, "test") - + return result def predict_step(self, batch: list[dict], batch_idx: int) -> dict: - """Method describing inference for model. + """Run inference for model. Computes association + assignment. 
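The predict path above boils down to running the tracker over a chunk of Frames with the trained model; a hedged sketch (gtr_model and frames are placeholders, and the constructor kwargs follow the earlier Tracker signature):

from biogtr.inference.tracker import Tracker

tracker = Tracker(window_size=8, persistent_tracking=True)
frames_pred = tracker(gtr_model, frames)  # Frames now carry pred_track_ids per instance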
@@ -135,7 +141,7 @@ def predict_step(self, batch: list[dict], batch_idx: int) -> dict: return instances_pred def _shared_eval_step(self, instances, mode): - """Helper function for running evaluation used by train, test, and val steps. + """Run evaluation used by train, test, and val steps. Args: instances: A list of dicts where each dict is a frame containing gt data @@ -165,8 +171,10 @@ def _shared_eval_step(self, instances, mode): clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics) return_metrics.update(clearmot.to_dict()) except Exception as e: - print(f'Failed on frame {instances[0]["frame_id"]} of video {instances[0]["video_id"]}') - raise(e) + print( + f'Failed on frame {instances[0]["frame_id"]} of video {instances[0]["video_id"]}' + ) + raise (e) return return_metrics def configure_optimizers(self) -> dict: @@ -199,8 +207,14 @@ def configure_optimizers(self) -> dict: "frequency": 10, }, } - - def log_metrics(self, result, mode): + + def log_metrics(self, result: dict, mode: str) -> None: + """Log metrics computed during evaluation. + + Args: + result: A dict containing metrics to be logged. + mode: One of {'train', 'test' or 'val'}. Used as prefix while logging. + """ if result: for metric, val in result.items(): - self.log(f"{mode}_{metric}", val, on_step = True, on_epoch=True) + self.log(f"{mode}_{metric}", val, on_step=True, on_epoch=True) diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 0a71a4da..95eb68cc 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -1,13 +1,13 @@ """Module containing model helper functions.""" from copy import deepcopy -from typing import Dict, List, Tuple, Iterable +from typing import List, Tuple, Iterable from pytorch_lightning import loggers from biogtr.data_structures import Frame import torch def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: - """Extracts the bounding boxes and frame indices from the input list of instances. + """Extract the bounding boxes and frame indices from the input list of instances. Args: instances (List[Dict]): List of instance dictionaries @@ -34,7 +34,7 @@ def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: def softmax_asso(asso_output: list[torch.Tensor]) -> list[torch.Tensor]: - """Applies the softmax activation function on asso_output. + """Apply the softmax activation function on asso_output. Args: asso_output: Raw logits output of the tracking transformer. A list of diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 32b8493a..867b138a 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -16,7 +16,6 @@ from biogtr.models.embedding import Embedding from biogtr.models.model_utils import get_boxes_times from torch import nn -from typing import Dict, List, Tuple import copy import torch import torch.nn.functional as F @@ -163,8 +162,8 @@ def _reset_parameters(self): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, frames: list[Frame], query_frame: int=None): - """A forward pass through the transformer and attention head. + def forward(self, frames: list[Frame], query_frame: int = None): + """Execute a forward pass through the transformer and attention head. Args: frames: A list of Frames (See `biogtr.data_structures.Frame for more info.) 
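The forward pass that follows flattens a window of Frames into per-instance tensors via `get_boxes_times`. A rough sketch of that contract as inferred from the call sites (shapes follow the `total_instances x 4` comment below); the real helper may additionally normalize box coordinates, so treat this as illustrative only.

import torch
from biogtr.data_structures import Frame

def boxes_and_times(frames: list[Frame]) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative stand-in: one bbox row and one frame index per instance."""
    boxes = torch.cat(
        [instance.bbox for frame in frames for instance in frame.instances], dim=0
    )  # (total_instances, 4)
    times = torch.cat(
        [torch.full((len(frame.instances),), t) for t, frame in enumerate(frames)]
    )  # (total_instances,)
    return boxes, times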
@@ -179,18 +178,18 @@ def forward(self, frames: list[Frame], query_frame: int=None): reid_features = torch.cat( [frame.get_features() for frame in frames], dim=0 ).unsqueeze(0) - + window_length = len(frames) instances_per_frame = [frame.num_detected for frame in frames] total_instances = sum(instances_per_frame) embed_dim = reid_features.shape[-1] - - #print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}') + + # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}') if self.embedding_meta: kwargs = self.embedding_meta.get("kwargs", {}) pred_box, pred_time = get_boxes_times(frames) # total_instances x 4 - + embedding_type = self.embedding_meta["embedding_type"] if "temp" in embedding_type: @@ -215,21 +214,32 @@ def forward(self, frames: list[Frame], query_frame: int=None): pos_emb = (pos_emb + temp_emb) / 2.0 pos_emb = pos_emb.view(1, total_instances, embed_dim) - pos_emb = pos_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim) + pos_emb = pos_emb.permute( + 1, 0, 2 + ) # (total_instances, batch_size, embed_dim) else: pos_emb = None query_inds = None n_query = total_instances if query_frame is not None: - - query_inds = [x for x in range(sum(instances_per_frame[:query_frame]), sum(instances_per_frame[: query_frame + 1]))] + query_inds = [ + x + for x in range( + sum(instances_per_frame[:query_frame]), + sum(instances_per_frame[: query_frame + 1]), + ) + ] n_query = len(query_inds) batch_size, total_instances, embed_dim = reid_features.shape - reid_features = reid_features.permute(1, 0, 2) # (total_instances x batch_size x embed_dim) + reid_features = reid_features.permute( + 1, 0, 2 + ) # (total_instances x batch_size x embed_dim) - memory = self.encoder(reid_features, pos_emb=pos_emb) # (total_instances, batch_size, embed_dim) + memory = self.encoder( + reid_features, pos_emb=pos_emb + ) # (total_instances, batch_size, embed_dim) if query_inds is not None: tgt = reid_features[query_inds] @@ -248,7 +258,9 @@ def forward(self, frames: list[Frame], query_frame: int=None): ) # (L, n_query, batch_size, embed_dim) feats = hs.transpose(1, 2) # # (L, batch_size, n_query, embed_dim) - memory = memory.permute(1, 0, 2).view(batch_size, total_instances, embed_dim) # (batch_size, total_instances, embed_dim) + memory = memory.permute(1, 0, 2).view( + batch_size, total_instances, embed_dim + ) # (batch_size, total_instances, embed_dim) asso_output = [] for x in feats: @@ -280,7 +292,7 @@ def __init__( self.norm = norm def forward(self, src: torch.Tensor, pos_emb: torch.Tensor = None) -> torch.Tensor: - """Forward pass of encoder layer. + """Execute a forward pass of encoder layer. Args: src: The input tensor of shape (n_query, batch_size, embed_dim). @@ -327,7 +339,7 @@ def __init__( def forward( self, tgt: torch.Tensor, memory: torch.Tensor, pos_emb=None, tgt_pos_emb=None ): - """Forward pass of the decoder block. + """Execute a forward pass of the decoder block. Args: tgt: Target sequence for decoder to generate (n_query, batch_size, embed_dim). @@ -402,7 +414,7 @@ def __init__( self.activation = _get_activation_fn(activation) def forward(self, src: torch.Tensor, pos: torch.Tensor = None): - """Forward pass of the encoder layer. + """Execute a forward pass of the encoder layer. Args: src: Input sequence for encoder (n_query, batch_size, embed_dim). 
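A quick worked example of the query-index arithmetic above: with per-frame instance counts of 3, 2 and 4 in a three-frame window and query_frame=1, the query instances occupy global rows 3 and 4 of the flattened (total_instances, batch_size, embed_dim) sequence.

instances_per_frame = [3, 2, 4]
query_frame = 1
query_inds = list(
    range(
        sum(instances_per_frame[:query_frame]),
        sum(instances_per_frame[: query_frame + 1]),
    )
)
assert query_inds == [3, 4]
assert len(query_inds) == instances_per_frame[query_frame]  # n_query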
@@ -476,7 +488,7 @@ def __init__( self.activation = _get_activation_fn(activation) def forward(self, tgt, memory, pos=None, tgt_pos=None): - """Forward pass of decoder layer. + """Execute forward pass of decoder layer. Args: tgt: Target sequence for decoder to generate (n_query, batch_size, embed_dim). diff --git a/biogtr/training/train.py b/biogtr/training/train.py index 2537f7c9..65169d27 100644 --- a/biogtr/training/train.py +++ b/biogtr/training/train.py @@ -15,7 +15,7 @@ import torch import torch.multiprocessing -#device = "cuda" if torch.cuda.is_available() else "cpu" +# device = "cuda" if torch.cuda.is_available() else "cpu" # useful for longer training runs, but not for single iteration debugging # finds optimal hardware algs which has upfront time increase for first @@ -24,12 +24,12 @@ # torch.backends.cudnn.benchmark = True # pytorch 2 logic - we set our device once here so we don't have to keep setting -#torch.set_default_device(device) +# torch.set_default_device(device) @hydra.main(config_path="configs", config_name=None, version_base=None) def main(cfg: DictConfig): - """Main function for training. + """Train model based on config. Handles all config parsing and initialization then calls `trainer.train()`. If `batch_config` is included then run will be assumed to be a batch job. @@ -41,11 +41,12 @@ def main(cfg: DictConfig): # update with parameters for batch train job if "batch_config" in cfg.keys(): - try: index = int(os.environ["POD_INDEX"]) except KeyError as e: - index = int(input("No pod index found, assuming single run!\nPlease input task index to run:")) + index = int( + input(f"{e}. Assuming single run!\nPlease input task index to run:") + ) hparams_df = pd.read_csv(cfg.batch_config) hparams = hparams_df.iloc[index].to_dict() diff --git a/biogtr/visualize.py b/biogtr/visualize.py index 1f497bc1..6d404d44 100644 --- a/biogtr/visualize.py +++ b/biogtr/visualize.py @@ -2,9 +2,7 @@ from scipy.interpolate import interp1d from copy import deepcopy from tqdm import tqdm -from matplotlib import pyplot as plt from omegaconf import DictConfig -from tqdm import tqdm import seaborn as sns import imageio @@ -12,14 +10,13 @@ import pandas as pd import numpy as np import cv2 -import imageio palette = sns.color_palette("tab10") def fill_missing(data: np.ndarray, kind: str = "linear") -> np.ndarray: - """Fills missing values independently along each dimension after the first. + """Fill missing values independently along each dimension after the first. Args: data: the array for which to fill missing value @@ -70,7 +67,7 @@ def annotate_video( centroids: bool = True, poses=False, save_path: str = "debug_animal", - fps: int = 30 + fps: int = 30, ) -> list: """Annotate video frames with labels. 
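For reference, a minimal `labels` DataFrame of the shape `annotate_video` iterates over: one row per instance per frame, with centroid coordinates and a column (passed as `key`) holding the track id used for coloring. Only the "Frame", "X" and "Y" names come from the loop below; the track-id column name is a placeholder.

import pandas as pd

labels = pd.DataFrame(
    {
        "Frame": [0, 0, 1, 1],
        "X": [12.0, 80.5, 13.1, 79.8],
        "Y": [40.0, 22.3, 41.2, 21.9],
        "Pred_track_id": [0, 1, 0, 1],  # placeholder name for the `key` column
    }
)
# annotate_video(video, labels, key="Pred_track_id", save_path="debug.mp4", fps=30)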
@@ -90,15 +87,13 @@ def annotate_video( Returns: A list of annotated video frames """ - writer = imageio.get_writer(save_path, fps=fps) color_palette = deepcopy(color_palette) - annotated_frames = [] if trails: track_trails = {} try: - for i in tqdm(sorted(labels["Frame"].unique()), desc = 'Frame', unit='Frame'): + for i in tqdm(sorted(labels["Frame"].unique()), desc="Frame", unit="Frame"): frame = video.get_data(i) if frame.shape[0] == 1 or frame.shape[-1] == 1: frame = cv2.cvtColor((frame * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB) @@ -228,12 +223,12 @@ def annotate_video( thickness=2, ) writer.append_data(frame) - + except Exception as e: writer.close() print(e) return False - + writer.close() return True @@ -284,10 +279,7 @@ def bold(val: float, thresh: float = 0.01) -> str: @hydra.main(config_path=None, config_name=None, version_base=None) def main(cfg: DictConfig): - """Main function for visualizations script. - - Takes in a path to a video + labels file, annotates a video and saves it to the specified path - """ + """Take in a path to a video + labels file, annotates a video and saves it to the specified path.""" labels = pd.read_csv(cfg.labels_path) video = imageio.get_reader(cfg.vid_path, "ffmpeg") annotated_frames = annotate_video(video, labels, **cfg.annotate) diff --git a/tests/fixtures/configs.py b/tests/fixtures/configs.py index 0f3a1c44..8d172d2a 100644 --- a/tests/fixtures/configs.py +++ b/tests/fixtures/configs.py @@ -1,18 +1,21 @@ +"""Test config paths.""" import os import pytest @pytest.fixture def config_dir(pytestconfig): - """Dir path to sleap data.""" + """Get the dir path to configs.""" return os.path.join(pytestconfig.rootdir, "tests/configs") @pytest.fixture def base_config(config_dir): + """Get the full path to base config.""" return os.path.join(config_dir, "base.yaml") @pytest.fixture def params_config(config_dir): + """Get the full path to the supplementary params config.""" return os.path.join(config_dir, "params.yaml") diff --git a/tests/fixtures/torch.py b/tests/fixtures/torch.py index 0ea1444d..9bd6d796 100644 --- a/tests/fixtures/torch.py +++ b/tests/fixtures/torch.py @@ -1,7 +1,9 @@ """ -Commenting this file out for now. +Commenting this file out for now. + For some reason it screws up `test_training` by causing a device error """ + # import pytest # import torch diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7ebb83f1..c9950681 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -65,10 +65,10 @@ def test_sleap_dataset(two_flies): crop_size=128, chunk=True, clip_length=clip_length, - n_chunks = chunk_frac + n_chunks=chunk_frac, ) - assert len(train_ds) == int(ds_length*chunk_frac) + assert len(train_ds) == int(ds_length * chunk_frac) n_chunks = 2 @@ -78,7 +78,7 @@ def test_sleap_dataset(two_flies): crop_size=128, chunk=True, clip_length=clip_length, - n_chunks = n_chunks + n_chunks=n_chunks, ) assert len(train_ds) == n_chunks @@ -90,7 +90,7 @@ def test_sleap_dataset(two_flies): crop_size=128, chunk=True, clip_length=clip_length, - n_chunks = 0 + n_chunks=0, ) assert len(train_ds) == ds_length @@ -101,14 +101,12 @@ def test_sleap_dataset(two_flies): crop_size=128, chunk=True, clip_length=clip_length, - n_chunks = ds_length + 10000 + n_chunks=ds_length + 10000, ) assert len(train_ds) == ds_length - - def test_icy_dataset(ten_icy_particles): """Test icy dataset logic. 
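The n_chunks assertions above pin down its semantics: fractional values appear to act as a fraction of the dataset, integers of at least one as an absolute chunk count, and zero (or anything larger than the dataset) falls back to using every chunk. A small helper mirroring those expected lengths, purely for illustration (it is not the dataset's actual chunking code):

def expected_len(ds_length: int, n_chunks: float) -> int:
    if 0 < n_chunks < 1:  # fraction of the available chunks
        return int(ds_length * n_chunks)
    if 1 <= n_chunks <= ds_length:  # absolute number of chunks
        return int(n_chunks)
    return ds_length  # 0 or out-of-range: keep everything

assert expected_len(100, 0.5) == 50
assert expected_len(100, 2) == 2
assert expected_len(100, 0) == 100
assert expected_len(100, 100 + 10000) == 100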
@@ -436,4 +434,4 @@ def test_augmentations(two_flies, ten_icy_particles): a = no_augs_instances[0].get_crops() b = augs_instances[0].get_crops() - assert not torch.all(a.eq(b)) \ No newline at end of file + assert not torch.all(a.eq(b)) From 8a2801b31904b6966b1dc097399eacfd185ca65a Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 22:22:59 -0800 Subject: [PATCH 09/40] fix typos --- biogtr/datasets/microscopy_dataset.py | 2 +- biogtr/inference/tracker.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index d063b29c..05d8aa73 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -135,7 +135,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram img = ( video.get_section(i) - if isinstance(video, list) + if not isinstance(video, list) else np.array(Image.open(video[i])) ) diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 2cfdf31b..18e9018d 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -107,9 +107,7 @@ def track(self, model: GlobalTrackingTransformer, frames: list[dict]): # asso_preds, pred_boxes, pred_time, embeddings = self.model( # instances, reid_features # ) - instances_pred = self.sliding_inference( - model, frames, window_size=self.window_size - ) + instances_pred = self.sliding_inference(model, frames) if not self.persistent_tracking: if self.verbose: @@ -118,15 +116,13 @@ def track(self, model: GlobalTrackingTransformer, frames: list[dict]): return instances_pred - def sliding_inference( - self, model: GlobalTrackingTransformer, frames: list[Frame], window_size: int - ): + def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame]): """Perform sliding inference on the input video (instances) with a given window size. Args: model: the pretrained GlobalTrackingTransformer to be used for inference frame: A list of Frames (See `biogtr.data_structures.Frame` for more info). - window_size: An integer. + Returns: Frames: A list of Frames populated with pred_track_ids and asso_matrices @@ -137,7 +133,7 @@ def sliding_inference( # H: height. # W: width. 
- for batch_idx, frame_to_track in frames: + for batch_idx, frame_to_track in enumerate(frames): tracked_frames = self.track_queue.collate_tracks() if self.verbose: warnings.warn( @@ -391,5 +387,4 @@ def _run_global_tracker( final_traj_score.columns.name = "Unique IDs" query_frame.add_traj_score("final", final_traj_score) - return query_frame From 055d6e1308226522555e8bef7217d99ff6f2d627 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 22:24:09 -0800 Subject: [PATCH 10/40] get training to work with Frames and Instances objects --- biogtr/models/gtr_runner.py | 14 ++++---- biogtr/training/losses.py | 11 ++++--- tests/test_inference.py | 59 ++++++++++++++++++---------------- tests/test_training.py | 64 ++++++++++++++++++++----------------- 4 files changed, 79 insertions(+), 69 deletions(-) diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index a33e418d..ebda47c2 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -67,8 +67,9 @@ def forward(self, instances) -> torch.Tensor: Returns: An association matrix between objects """ - if sum([frame["num_detected"] for frame in instances]) > 0: - return self.model(instances) + if sum([frame.num_detected for frame in instances]) > 0: + asso_preds, _ = self.model(instances) + return asso_preds return None def training_step( @@ -84,6 +85,7 @@ def training_step( Returns: A dict containing the train loss plus any other metrics specified """ + result = self._shared_eval_step(train_batch[0], mode="train") self.log_metrics(result, "train") @@ -153,10 +155,8 @@ def _shared_eval_step(self, instances, mode): try: eval_metrics = self.metrics[mode] persistent_tracking = self.persistent_tracking[mode] - if self.model.transformer.return_embedding: - logits, _ = self(instances) - else: - logits = self(instances) + + logits = self(instances) if not logits: return None @@ -172,7 +172,7 @@ def _shared_eval_step(self, instances, mode): return_metrics.update(clearmot.to_dict()) except Exception as e: print( - f'Failed on frame {instances[0]["frame_id"]} of video {instances[0]["video_id"]}' + f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}" ) raise (e) return return_metrics diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index 5990949a..e4d725d9 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -1,4 +1,5 @@ """Module containing different loss functions to be optimized.""" +from biogtr.data_structures import Frame from biogtr.models.model_utils import get_boxes_times from torch import nn from typing import List, Tuple @@ -33,23 +34,23 @@ def __init__( self.asso_weight = asso_weight def forward( - self, asso_preds: List[torch.Tensor], instances: List[dict] + self, asso_preds: List[torch.Tensor], frames: List[Frame] ) -> torch.Tensor: """Calculate association loss. Args: asso_preds: a list containing the association matrix at each frame - instances: a list of dictionaries for each frame containing gt labels. + frames: a list of Frames containing gt labels. 
Returns: the association loss between predicted association and actual """ # get number of detected objects and ground truth ids - n_t = [frame["num_detected"] for frame in instances] - target_inst_id = torch.cat([frame["gt_track_ids"] for frame in instances]) + n_t = [frame.num_detected for frame in frames] + target_inst_id = torch.cat([frame.get_gt_track_ids() for frame in frames]) # for now set equal since detections are fixed - pred_box, pred_time = get_boxes_times(instances) + pred_box, pred_time = get_boxes_times(frames) target_box, target_time = pred_box, pred_time # todo: we should maybe reconsider how we label gt instances. The second diff --git a/tests/test_inference.py b/tests/test_inference.py index 288f1792..30d4f99b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -24,14 +24,17 @@ def test_tracker(): for i in range(num_frames): instances = [] for j in range(num_detected): - instances.append(Instance(gt_track_id=j, - pred_track_id=-1, - bbox=torch.rand(size=(1, 4)), - crop = torch.rand(size=(1, 1, 64, 64)))) - frames.append(Frame(video_id=0, - frame_id=i, - img_shape=img_shape, - instances=instances)) + instances.append( + Instance( + gt_track_id=j, + pred_track_id=-1, + bbox=torch.rand(size=(1, 4)), + crop=torch.rand(size=(1, 1, 64, 64)), + ) + ) + frames.append( + Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) + ) embedding_meta = { "embedding_type": "fixed_pos", @@ -61,18 +64,18 @@ def test_tracker(): frames_pred = tracker(tracking_transformer, frames) - print(frames_pred[test_frame]) - asso_equals = ( - frames_pred[test_frame].get_traj_score("decay_time").to_numpy() - == frames_pred[test_frame].get_traj_score("final").to_numpy() - ).all() - assert asso_equals + # TODO: Debug saving asso matrices + # asso_equals = ( + # frames_pred[test_frame].get_traj_score("decay_time").to_numpy() + # == frames_pred[test_frame].get_traj_score("final").to_numpy() + # ).all() + # assert asso_equals - assert (len(frames_pred[test_frame].get_pred_track_ids()) == num_detected) + assert len(frames_pred[test_frame].get_pred_track_ids()) == num_detected -#@pytest.mark.parametrize("set_default_device", ["cpu"], indirect=True) -def test_post_processing(): #set_default_device +# @pytest.mark.parametrize("set_default_device", ["cpu"], indirect=True) +def test_post_processing(): # set_default_device """Test postprocessing methods. 
Tests each postprocessing method to ensure that @@ -148,13 +151,14 @@ def test_post_processing(): #set_default_device ) ).all() + def test_metrics(): """Test basic GTR Runner.""" num_frames = 3 num_detected = 3 n_batches = 1 batches = [] - + for i in range(n_batches): frames_pred = [] for j in range(num_frames): @@ -162,16 +166,13 @@ def test_metrics(): for k in range(num_detected): bboxes = torch.tensor(np.random.uniform(size=(num_detected, 4))) bboxes[:, -2:] += 1 - instances_pred.append(Instance(gt_track_id=k, - pred_track_id=k, - bbox=torch.randn((1,4)) - )) - frames_pred.append(Frame(video_id=0, - frame_id=j, - instances=instances_pred)) + instances_pred.append( + Instance(gt_track_id=k, pred_track_id=k, bbox=torch.randn((1, 4))) + ) + frames_pred.append(Frame(video_id=0, frame_id=j, instances=instances_pred)) batches.append(frames_pred) - for batch in batches: + for batch in batches: instances_mm = metrics.to_track_eval(batch) clear_mot = metrics.get_pymotmetrics(instances_mm) @@ -181,5 +182,7 @@ def test_metrics(): sw_cnt = metrics.get_switch_count(switches) - assert sw_cnt == clear_mot["num_switches"] == 0, (sw_cnt, clear_mot["num_switches"]) - \ No newline at end of file + assert sw_cnt == clear_mot["num_switches"] == 0, ( + sw_cnt, + clear_mot["num_switches"], + ) diff --git a/tests/test_training.py b/tests/test_training.py index 79a65a70..f12583fa 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -2,6 +2,7 @@ import os import pytest import torch +from biogtr.data_structures import Frame, Instance from biogtr.training.losses import AssoLoss from biogtr.models.gtr_runner import GTRRunner from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer @@ -18,23 +19,21 @@ def test_asso_loss(): num_detected = 20 img_shape = (1, 128, 128) - instances = [] + frames = [] for i in range(num_frames): - instances.append( - { - "img_shape": torch.tensor(img_shape), - "num_detected": torch.tensor([num_detected]), - "gt_track_ids": torch.arange(num_detected), - "bboxes": torch.rand(size=(num_detected, 4)), - } + instances = [] + for j in range(num_detected): + instances.append(Instance(gt_track_id=j, bbox=torch.rand(size=(1, 4)))) + frames.append( + Frame(video_id=0, frame_id=i, instances=instances, img_shape=img_shape) ) asso_loss = AssoLoss(neg_unmatched=True, asso_weight=10.0) asso_preds = torch.rand(size=(1, 100, 100)) - loss = asso_loss(asso_preds, instances) + loss = asso_loss(asso_preds, frames) assert len(loss.size()) == 0 assert type(loss.item()) == float @@ -47,25 +46,33 @@ def test_basic_gtr_runner(): num_detected = 3 img_shape = (1, 128, 128) n_batches = 2 - instances = [] train_ds = [] epochs = 2 - + frame_ind = 0 for i in range(n_batches): + frames = [] for j in range(num_frames): - instances.append( - { - "video_id": torch.tensor(0), - "frame_id": torch.tensor(j), - "img_shape": torch.tensor(img_shape), - "num_detected": torch.tensor([num_detected]), - "crops": torch.rand(size=(num_detected, 1, 64, 64)), - "bboxes": torch.rand(size=(num_detected, 4)), - "gt_track_ids": torch.arange(num_detected), - "pred_track_ids": torch.tensor([-1] * num_detected), - } + instances = [] + for k in range(num_detected): + instances.append( + Instance( + gt_track_id=k, + pred_track_id=-1, + bbox=torch.rand(size=(1, 4)), + crop=torch.randn(size=img_shape), + ), + ) + + frames.append( + Frame( + video_id=0, + frame_id=frame_ind, + instances=instances, + img_shape=img_shape, + ) ) - train_ds.append([instances]) + frame_ind += 1 + train_ds.append(frames) gtr_runner = 
GTRRunner() @@ -91,24 +98,23 @@ def test_basic_gtr_runner(): for epoch in range(epochs): for i, batch in enumerate(train_ds): + gtr_runner.train() assert gtr_runner.model.training - metrics = gtr_runner.training_step(batch, i) - assert "loss" in metrics and "num_switches" not in metrics + metrics = gtr_runner.training_step([batch], i) + assert "loss" in metrics and "num_switches" in metrics assert metrics["loss"].requires_grad for j, batch in enumerate(train_ds): gtr_runner.eval() with torch.no_grad(): - metrics = gtr_runner.validation_step(batch, j) + metrics = gtr_runner.validation_step([batch], j) assert "loss" in metrics and "num_switches" in metrics assert not metrics["loss"].requires_grad - gtr_runner.train() - for k, batch in enumerate(train_ds): gtr_runner.eval() with torch.no_grad(): - metrics = gtr_runner.test_step(batch, k) + metrics = gtr_runner.test_step([batch], k) assert "loss" in metrics and "num_switches" in metrics assert not metrics["loss"].requires_grad From 6f83300f7cff7c028d554faafc835bb422bcee9e Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 22:24:22 -0800 Subject: [PATCH 11/40] update test yaml --- tests/configs/base.yaml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/configs/base.yaml b/tests/configs/base.yaml index ad78b82d..f8cc8429 100644 --- a/tests/configs/base.yaml +++ b/tests/configs/base.yaml @@ -55,14 +55,16 @@ tracker: max_center_dist: null runner: - train_metrics: [""] - val_metrics: ["sw_cnt"] - test_metrics: ["sw_cnt"] + metrics: + train: [""] + val: ["sw_cnt"] + test: ["sw_cnt"] dataset: train_dataset: slp_files: ['tests/data/sleap/two_flies.slp'] video_files: ['tests/data/sleap/two_flies.mp4'] + anchor: "thorax" padding: 5 crop_size: 128 chunk: true @@ -71,6 +73,7 @@ dataset: val_dataset: slp_files: ['tests/data/sleap/two_flies.slp'] video_files: ['tests/data/sleap/two_flies.mp4'] + anchor: "thorax" padding: 5 crop_size: 128 chunk: True @@ -79,6 +82,7 @@ dataset: test_dataset: slp_files: ['tests/data/sleap/two_flies.slp'] video_files: ['tests/data/sleap/two_flies.mp4'] + anchor: "thorax" padding: 5 crop_size: 128 chunk: True From 908f67934f456ef0faeb4b39df50001fe2b522ff Mon Sep 17 00:00:00 2001 From: aaprasad Date: Thu, 9 Nov 2023 22:31:46 -0800 Subject: [PATCH 12/40] lint test_models --- tests/test_models.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 425f00df..ea62a835 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -213,13 +213,14 @@ def test_transformer_basic(): for i in range(num_frames): instances = [] for j in range(num_detected): - instances.append(Instance(bbox=torch.rand(size=(1, 4)), - features=torch.rand(size=(1, feats)))) - frames.append(Frame(video_id = 0, frame_id=i, - instances=instances)) - + instances.append( + Instance( + bbox=torch.rand(size=(1, 4)), features=torch.rand(size=(1, feats)) + ) + ) + frames.append(Frame(video_id=0, frame_id=i, instances=instances)) - asso_preds,_ = transformer(frames) + asso_preds, _ = transformer(frames) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 @@ -274,10 +275,12 @@ def test_transformer_embedding(): for i in range(num_frames): instances = [] for j in range(num_detected): - instances.append(Instance(bbox=torch.rand(size=(1, 4)), - features=torch.rand(size=(1, feats)))) - frames.append(Frame(video_id = 0, frame_id=i, - instances=instances)) + instances.append( + Instance( + bbox=torch.rand(size=(1, 4)), 
features=torch.rand(size=(1, feats)) + ) + ) + frames.append(Frame(video_id=0, frame_id=i, instances=instances)) embedding_meta = { "embedding_type": "learned_pos_temp", @@ -316,13 +319,14 @@ def test_tracking_transformer(): for i in range(num_frames): instances = [] for j in range(num_detected): - instances.append(Instance(bbox=torch.rand(size=(1, 4)), - crop=torch.rand(size=(1, 1, 64, 64)) - )) - frames.append(Frame(video_id=0, - frame_id=i, - img_shape=img_shape, - instances=instances)) + instances.append( + Instance( + bbox=torch.rand(size=(1, 4)), crop=torch.rand(size=(1, 1, 64, 64)) + ) + ) + frames.append( + Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) + ) embedding_meta = { "embedding_type": "fixed_pos", From 2d31a219daeb7a4a568d13cad0008b53c02c5f7d Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 4 Dec 2023 10:21:22 -0800 Subject: [PATCH 13/40] fix doc strings --- biogtr/data_structures.py | 10 +++++++--- biogtr/inference/track_queue.py | 3 +++ biogtr/models/gtr_runner.py | 6 ++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index b397770f..d0efac4e 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -83,6 +83,7 @@ def to(self, map_location): Args: map_location: Either the device or dtype for the instance to be moved. + Returns: self: reference to the instance moved to correct device/dtype. """ @@ -228,7 +229,7 @@ def crop(self, crop: ArrayLike) -> None: """Set the crop of the instance. Args: - an arraylike object containing the cropped image of the centered instance. + crop: an arraylike object containing the cropped image of the centered instance. """ if crop is None or len(crop) == 0: self._crop = torch.tensor([]) @@ -371,6 +372,7 @@ def to(self, map_location: str): Args: map_location: A string representing the device to move to. + Returns: The frame moved to a different device/dtype. """ @@ -416,8 +418,7 @@ def device(self, device: str) -> None: @property def video_id(self) -> torch.Tensor: - """ - The index of the video the frame comes from. + """The index of the video the frame comes from. Returns: A tensor containing the video index. @@ -577,6 +578,7 @@ def get_traj_score(self, key=None) -> Union[dict, ArrayLike, None]: Args: key: The key of the trajectory score to be accessed. Can be one of {None, 'initial', 'decay_time', 'max_center_dist', 'iou', 'final'} + Returns: - dictionary containing all trajectory scores if key is None - trajectory score associated with key @@ -706,4 +708,6 @@ def get_features(self): Returns: an (N, D) shaped tensor with reid feature vectors of each instance in the frame. """ + if not self.has_instances(): + return torch.tensor([]) return torch.cat([instance.features for instance in self.instances], dim=0) diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index 78576479..ee648bc5 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -157,6 +157,7 @@ def end_tracks(self, track_id=None): Args: track_id: The index of the trajectory to be ended and removed. If `None` then then every trajectory is removed and the track queue is reset. + Returns: True if the track is successively removed, otherwise False. (ie if the track doesn't exist in the queue.) @@ -222,6 +223,7 @@ def collate_tracks( queues, otherwise filter queues by track_ids then merge. 
device: A str representation of the device the frames should be on after merging since all instances in the queue are kept on the cpu. + Returns: A sorted list of Frame objects from which each instance came from, containing the corresponding instances. @@ -256,6 +258,7 @@ def increment_gaps(self, pred_track_ids: list[int]) -> dict[int, bool]: pred_track_ids: A list of track_ids to be matched against the trajectories in the queue. If a trajectory is in `pred_track_ids` then its gap counter is reset, otherwise its incremented by 1. + Returns: A dictionary containing the trajectory id and a boolean value representing whether or not it has exceeded the max allowed gap and been diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index ebda47c2..e5038f5e 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -41,9 +41,8 @@ def __init__( optimizer_cfg: hyper parameters used for optimizer. Only used to overwrite `configure_optimizer` scheduler_cfg: hyperparameters for lr_scheduler used to overwrite `configure_optimizer - train_metrics: a list of metrics to be calculated during training - val_metrics: a list of metrics to be calculated during validation - test_metrics: a list of metrics to be calculated at test time + metrics: a dict containing the metrics to be computed during train, val, and test. + persistent_tracking: a dict containing whether to use persistent tracking during train, val and test inference. """ super().__init__() self.save_hyperparameters() @@ -85,7 +84,6 @@ def training_step( Returns: A dict containing the train loss plus any other metrics specified """ - result = self._shared_eval_step(train_batch[0], mode="train") self.log_metrics(result, "train") From ba4913dbc6a6d22348e543932f070a00b01ae1ff Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 4 Dec 2023 10:21:50 -0800 Subject: [PATCH 14/40] fix typo, use gt track from predicted dataset not pred track --- biogtr/datasets/eval_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index 65b655f8..f3142b16 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -47,7 +47,7 @@ def __getitem__(self, idx: int) -> List[Frame]: eval_instances.append( Instance( gt_track_id=gt_instance.gt_track_id, - pred_track_id=pred_instance.pred_track_id, + pred_track_id=pred_instance.gt_track_id, bbox=pred_instance.bbox, ) ) From c0f9bcd3a57475d7e2b4ee13cf5785e3e2f9c957 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 4 Dec 2023 10:22:25 -0800 Subject: [PATCH 15/40] fix edgecase leading to error when no instances are detected --- biogtr/inference/metrics.py | 61 +++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index bff87310..a0ccebfe 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -3,6 +3,7 @@ import motmetrics as mm import torch from biogtr.data_structures import Frame +import warnings from typing import Union, Iterable # from biogtr.inference.post_processing import _pairwise_iou @@ -26,17 +27,20 @@ def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: video_id = frames[0].video_id.item() - for idx, frame in enumerate(frames): - indices.append(frame.frame_id.item()) - for gt_track_id, pred_track_id in zip( - frame.get_gt_track_ids(), frame.get_pred_track_ids() - ): - match = f"{gt_track_id} -> {pred_track_id}" + if 
any([frame.has_instances() for frame in frames]): + for idx, frame in enumerate(frames): + indices.append(frame.frame_id.item()) + for gt_track_id, pred_track_id in zip( + frame.get_gt_track_ids(), frame.get_pred_track_ids() + ): + match = f"{gt_track_id} -> {pred_track_id}" - if match not in matches: - matches[match] = np.full(len(frames), 0) + if match not in matches: + matches[match] = np.full(len(frames), 0) - matches[match][idx] = 1 + matches[match][idx] = 1 + else: + warnings.warn("No instances detected!") return matches, indices, video_id @@ -52,30 +56,32 @@ def get_switches(matches: dict, indices: list) -> dict: and the change in labels """ track, switches = {}, {} - # unique_gt_ids = np.unique([k.split(" ")[0] for k in list(matches.keys())]) - matches_key = np.array(list(matches.keys())) - matches = np.array(list(matches.values())) - num_frames = matches.shape[1] + if len(matches) > 0 and len(indices) > 0: + matches_key = np.array(list(matches.keys())) + matches = np.array(list(matches.values())) + num_frames = matches.shape[1] - assert num_frames == len(indices) + assert num_frames == len(indices) - for i, idx in zip(range(num_frames), indices): - switches[idx] = {} + for i, idx in zip(range(num_frames), indices): + switches[idx] = {} - col = matches[:, i] - indices = np.where(col == 1)[0] - match_i = [(m.split(" ")[0], m.split(" ")[-1]) for m in matches_key[indices]] + col = matches[:, i] + indices = np.where(col == 1)[0] + match_i = [ + (m.split(" ")[0], m.split(" ")[-1]) for m in matches_key[indices] + ] - for m in match_i: - gt, pred = m + for m in match_i: + gt, pred = m - if gt in track and track[gt] != pred: - switches[idx][gt] = { - "frames": (idx - 1, idx), - "pred tracks (from, to)": (track[gt], pred), - } + if gt in track and track[gt] != pred: + switches[idx][gt] = { + "frames": (idx - 1, idx), + "pred tracks (from, to)": (track[gt], pred), + } - track[gt] = pred + track[gt] = pred return switches @@ -240,6 +246,7 @@ def get_pymotmetrics( Args: data: A dictionary. Example provided below. key: The key within instances to look for track_ids (can be "gt_ids" or "tracker_ids"). + Returns: summary: A pandas DataFrame of all the pymot-metrics. From a03cca92290c610ee2cc8803ad29c11f3766647c Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 4 Dec 2023 10:22:58 -0800 Subject: [PATCH 16/40] fix small typos + docstrings --- biogtr/inference/track.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py index 6e5dbd9c..52ed18a1 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -24,6 +24,7 @@ def export_trajectories(frames_pred: list[Frame], save_path: str = None): Args: frames_pred: A list of Frames with predicted track ids. save_path: The path to save the predicted trajectories to. + Returns: A dictionary containing the predicted track id and centroid coordinates for each instance in the video. 
""" @@ -32,9 +33,9 @@ def export_trajectories(frames_pred: list[Frame], save_path: str = None): X, Y = [], [] pred_track_ids = [] for frame in frames_pred: - for i, instance in range(frame.instances): + for i, instance in enumerate(frame.instances): frame_ids.append(frame.frame_id.item()) - bbox = instance.bbox + bbox = instance.bbox.squeeze() y = (bbox[2] + bbox[0]) / 2 x = (bbox[3] + bbox[1]) / 2 X.append(x.item()) From 7c4c64561b02bb92a5f6f827318c24989e4c3242 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 4 Dec 2023 10:24:23 -0800 Subject: [PATCH 17/40] fix docstrings + small typo --- biogtr/inference/tracker.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 18e9018d..506bbc48 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -31,15 +31,17 @@ def __init__( """Initialize a tracker to run inference. Args: - window_size: the size of the window used during sliding inference - use_vis_feats: Whether or not to use visual feature extractor - overlap_thresh: the trajectory overlap threshold to be used for assignment - mult_thresh: Whether or not to use weight threshold - decay_time: weight for `decay_time` postprocessing + window_size: the size of the window used during sliding inference. + use_vis_feats: Whether or not to use visual feature extractor. + overlap_thresh: the trajectory overlap threshold to be used for assignment. + mult_thresh: Whether or not to use weight threshold. + decay_time: weight for `decay_time` postprocessing. iou: Either [None, '', "mult" or "max"] - Whether to use multiplicative or max iou reweighting - max_center_dist: distance threshold for filtering trajectory score matrix - persistent_tracking: whether to keep a buffer across chunks or not + Whether to use multiplicative or max iou reweighting. + max_center_dist: distance threshold for filtering trajectory score matrix. + persistent_tracking: whether to keep a buffer across chunks or not. + max_gap: the max number of frames a trajectory can be missing before termination. + verbose: Whether or not to turn on debug printing after each operation. 
""" self.track_queue = TrackQueue( window_size=window_size, max_gap=max_gap, verbose=verbose @@ -232,7 +234,7 @@ def _run_global_tracker( # (L=1, n_query, total_instances) with torch.no_grad(): - asso_output, embed = model(frames, query_frame=query_frame) + asso_output, embed = model(frames, query_frame=query_ind) # if model.transformer.return_embedding: # query_frame.embeddings = embed TODO add embedding to Instance Object # if query_frame == 1: @@ -262,7 +264,7 @@ def _run_global_tracker( [ x.get_pred_track_ids() for batch_idx, x in enumerate(frames) - if batch_idx != query_frame + if batch_idx != query_ind ], dim=0, ).view( @@ -280,8 +282,8 @@ def _run_global_tracker( query_inds = [ x for x in range( - sum(instances_per_frame[:query_frame]), - sum(instances_per_frame[: query_frame + 1]), + sum(instances_per_frame[:query_ind]), + sum(instances_per_frame[: query_ind + 1]), ) ] nonquery_inds = [i for i in range(total_instances) if i not in query_inds] @@ -308,7 +310,7 @@ def _run_global_tracker( # (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_k x n_traj traj_score = post_processing.weight_decay_time( - asso_nonquery, self.decay_time, reid_features, window_size, query_frame + asso_nonquery, self.decay_time, reid_features, window_size, query_ind ) traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj) From db5d83a8b8340ad6fd2a9e2a2d60c940f3636386 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:39:28 -0800 Subject: [PATCH 18/40] create checkpoint dir manually, use better checkpoint names --- biogtr/config.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/biogtr/config.py b/biogtr/config.py index f19d153f..a1e04725 100644 --- a/biogtr/config.py +++ b/biogtr/config.py @@ -10,6 +10,8 @@ from omegaconf import DictConfig, OmegaConf from pprint import pprint from typing import Union, Iterable +from pathlib import Path +import os import pytorch_lightning as pl import torch @@ -267,12 +269,22 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: else: dirpath = checkpoint_params["dirpath"] + + dirpath = Path(dirpath).resolve() + if not Path(dirpath).exists(): + try: + Path(dirpath).mkdir(parents=True, exist_ok=True) + except OSError as e: + print( + f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. 
\n {e}" + ) + _ = checkpoint_params.pop("dirpath") checkpointers = [] monitor = checkpoint_params.pop("monitor") for metric in monitor: checkpointer = pl.callbacks.ModelCheckpoint( - monitor=metric, dirpath=dirpath, **checkpoint_params + monitor=metric, dirpath=dirpath, filename=f"{{epoch}}-{{{metric}}}", **checkpoint_params ) checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}" checkpointers.append(checkpointer) From 2fcb057fe94147875b569077ae745428429dd517 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:43:03 -0800 Subject: [PATCH 19/40] add poses and easy conversion to sleap label objects --- biogtr/data_structures.py | 138 +++++++++++++++++++++++++++++-- biogtr/datasets/sleap_dataset.py | 36 ++++++-- 2 files changed, 158 insertions(+), 16 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index d0efac4e..3e1132c4 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -1,5 +1,7 @@ """Module containing data classes such as Instances and Frames.""" import torch +import sleap_io as sio +import numpy as np from numpy.typing import ArrayLike from typing import Union, List @@ -9,11 +11,16 @@ class Instance: def __init__( self, - gt_track_id: int = None, + gt_track_id: int = -1, pred_track_id: int = -1, bbox: ArrayLike = torch.empty((0, 4)), crop: ArrayLike = torch.tensor([]), features: ArrayLike = torch.tensor([]), + track_score: float = 0.0, + point_scores: ArrayLike = None, + instance_score:float = 0.0, + skeleton: sio.Skeleton = None, + pose: ArrayLike = None, device: str = None, ): """Initialize Instance. @@ -29,21 +36,26 @@ def __init__( if gt_track_id is not None: self._gt_track_id = torch.tensor([gt_track_id]) else: - self._gt_track_id = torch.tensor([]) + self._gt_track_id = torch.tensor([-1]) if pred_track_id is not None: self._pred_track_id = torch.tensor([pred_track_id]) else: self._pred_track_id = torch.tensor([]) + + if skeleton is None: + self._skeleton = sio.Skeleton(["centroid"]) + else: + self._skeleton = skeleton if not isinstance(bbox, torch.Tensor): self._bbox = torch.tensor(bbox) else: self._bbox = bbox - + if self._bbox.shape[0] and len(self._bbox.shape) == 1: self._bbox = self._bbox.unsqueeze(0) - + if not isinstance(crop, torch.Tensor): self._crop = torch.tensor(crop) else: @@ -61,6 +73,23 @@ def __init__( if self._features.shape[0] and len(self._features.shape) == 1: self._features = self._features.unsqueeze(0) + + if pose is not None: + self._pose = pose + + elif self.bbox.shape[0]: + self._pose = np.array([(self.bbox[:,-1] + self.bbox[:,1])/2,(self.bbox[:,-2] + self.bbox[:,0])/2]) + + else: + self._pose = np.empty((0,2)) + + self._track_score = track_score + self._instance_score = instance_score + + if point_scores is not None: + self._point_scores = point_scores + else: + self._point_scores = np.zeros_like(self.pose) self._device = device self.to(self._device) @@ -94,7 +123,28 @@ def to(self, map_location): self._features = self._features.to(map_location) self.device = map_location return self - + + def to_slp(self, track_lookup: dict = {}) -> (sio.PredictedInstance, dict[int, sio.Track]): + """Convert instance to sleap_io.PredictedInstance object + + Returns: A sleap_io.PredictedInstance with necessary metadata + """ + try: + track_id = self.pred_track_id.item() + if track_id not in track_lookup: + track_lookup[track_id] = sio.Track(name=self.pred_track_id.item()) + + track = track_lookup[track_id] + + return sio.PredictedInstance.from_numpy(points=self.pose, + skeleton = self.skeleton, 
+ point_scores=self.point_scores, + instance_score = self.instance_score, + tracking_score = self.track_score, + track = track), track_lookup + except Exception as e: + print(self.pose.shape, self.point_scores.shape) + raise(e) @property def device(self) -> str: """The device the instance is on. @@ -292,6 +342,51 @@ def has_features(self) -> bool: return False else: return True + @property + def pose(self) -> ArrayLike: + return self._pose + + @pose.setter + def pose(self, pose: ArrayLike) -> None: + self._pose = pose + + def has_pose(self) -> bool: + if self.pose.shape[0]: + return True + return False + + @property + def skeleton(self) -> sio.Skeleton: + return self._skeleton + + @skeleton.setter + def skeleton(self, skeleton: sio.Skeleton) -> None: + self._skeleton = skeleton + + @property + def point_scores(self) -> ArrayLike: + return self._point_scores + + @point_scores.setter + def point_scores(self, point_scores: ArrayLike) -> None: + self._point_scores = point_scores + + @property + def instance_score(self) -> float: + return self._instance_score + + @instance_score.setter + def instance_score(self, instance_score: float) -> None: + self._instance_score = instance_score + + @property + def track_score(self) -> float: + return self._track_score + + @track_score.setter + def track_score(self, track_score: float) -> None: + self._track_score = track_score + class Frame: @@ -301,6 +396,7 @@ def __init__( self, video_id: int, frame_id: int, + vid_file: str = "", img_shape: ArrayLike = [0, 0, 0], instances: List[Instance] = [], asso_output: ArrayLike = None, @@ -313,6 +409,7 @@ def __init__( Args: video_id: The video index in the dataset. frame_id: The index of the frame in a video. + vid_file: The path to the video the frame is from. img_shape: The shape of the original frame (not the crop). instances: A list of Instance objects that appear in the frame. asso_output: The association matrix between instances @@ -326,6 +423,8 @@ def __init__( """ self._video_id = torch.tensor([video_id]) self._frame_id = torch.tensor([frame_id]) + + self._video = sio.Video(vid_file) if isinstance(img_shape, torch.Tensor): self._img_shape = img_shape @@ -355,6 +454,7 @@ def __repr__(self) -> str: """ return ( "Frame(" + f"video={self._video.filename}, " f"video_id={self._video_id.item()}, " f"frame_id={self._frame_id.item()}, " f"img_shape={self._img_shape}, " @@ -395,7 +495,20 @@ def to(self, map_location: str): self._device = map_location return self - + + def to_slp(self, track_lookup = {}) -> (sio.LabeledFrame, dict[int, sio.Track]): + """Convert Frame to sleap_io.LabeledFrame object. + + Returns: A LabeledFrame object with necessary metadata. + """ + slp_instances = [] + for instance in self.instances: + slp_instance, track_lookup = instance.to_slp(track_lookup=track_lookup) + slp_instances.append(slp_instance) + return sio.LabeledFrame(video = self.video, + frame_idx = self.frame_id.item(), + instances = slp_instances), track_lookup + @property def device(self) -> str: """The device the frame is on. @@ -444,7 +557,7 @@ def frame_id(self) -> torch.Tensor: A torch tensor containing the index of the frame in the video. """ return self._frame_id - + @frame_id.setter def frame_id(self, frame_id: int) -> None: """Set the frame index of the frame. @@ -455,7 +568,15 @@ def frame_id(self, frame_id: int) -> None: frame_id: The int index of the frame in the full video. 
""" self._frame_id = torch.tensor([frame_id]) - + + @property + def video(self) -> sio.Video: + return self._video + + @video.setter + def video(self, video_filename: str) -> None: + self._video = sio.Video(video_filename) + @property def img_shape(self) -> torch.Tensor: """The shape of the pre-cropped frame. @@ -711,3 +832,4 @@ def get_features(self): if not self.has_instances(): return torch.tensor([]) return torch.cat([instance.features for instance in self.instances], dim=0) + diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 768e8213..8382c8db 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -140,10 +140,12 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict vid_reader = imageio.get_reader(video_name, "ffmpeg") img = vid_reader.get_data(0) + + skeleton = video.skeletons[-1] frames = [] for i, frame_ind in enumerate(frame_idx): - instances, gt_track_ids, shown_poses = [], [], [] + instances, gt_track_ids, poses, shown_poses, point_scores, instance_score = [], [], [], [], [], [] frame_ind = int(frame_ind) @@ -156,9 +158,13 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict continue for instance in lf: - gt_track_ids.append(video.tracks.index(instance.track)) + if instance.track is not None: + gt_track_id = video.tracks.index(instance.track) + else: + gt_track_id = -1 + gt_track_ids.append(gt_track_id) - shown_poses.append( + poses.append( dict( zip( [n.name for n in instance.skeleton.nodes], @@ -173,8 +179,14 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict for key, val in instance.items() if not np.isnan(val).any() } - for instance in shown_poses + for instance in poses ] + + point_scores.append(np.array([point.score if isinstance(point, sio.PredictedPoint) else 1.0 for point in instance.points.values()])) + if isinstance(instance, sio.PredictedInstance): + instance_score.append(instance.score) + else: + instance_score.append(1.0) # augmentations if self.augmentations is not None: for transform in self.augmentations: @@ -212,8 +224,8 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict img = tvf.to_tensor(img) - for i in range(len(gt_track_ids)): - pose = shown_poses[i] + for j in range(len(gt_track_ids)): + pose = shown_poses[j] """Check for anchor""" if self.anchor in pose: @@ -252,9 +264,16 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict ) crop = data_utils.crop_bbox(img, bbox) - + instance = Instance( - gt_track_id=gt_track_ids[i], pred_track_id=-1, crop=crop, bbox=bbox + gt_track_id=gt_track_ids[j], + pred_track_id=-1, + crop=crop, + bbox=bbox, + skeleton=skeleton, + pose = np.array(list(poses[j].values())), + point_scores = point_scores[j], + instance_score = instance_score[j] ) instances.append(instance) @@ -262,6 +281,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict frame = Frame( video_id=label_idx, frame_id=frame_ind, + vid_file = video_name, img_shape=img.shape, instances=instances, ) From 5975bc1bdcd534d0581695d60f5f7576fb5110e5 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:44:59 -0800 Subject: [PATCH 20/40] add flexibility to visualization such as control over sizes as well as transparency --- biogtr/visualize.py | 55 +++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/biogtr/visualize.py b/biogtr/visualize.py index 
6d404d44..fd0ec62c 100644 --- a/biogtr/visualize.py +++ b/biogtr/visualize.py @@ -10,9 +10,9 @@ import pandas as pd import numpy as np import cv2 +from matplotlib import pyplot - -palette = sns.color_palette("tab10") +palette = sns.color_palette("tab20") def fill_missing(data: np.ndarray, kind: str = "linear") -> np.ndarray: @@ -61,13 +61,14 @@ def annotate_video( labels: pd.DataFrame, key: str, color_palette=palette, - trails: bool = True, - boxes: int = 64, + trails: int = 2, + boxes: int = (64,64), names: bool = True, - centroids: bool = True, + centroids: int = 4, poses=False, save_path: str = "debug_animal", fps: int = 30, + alpha=0.2 ) -> list: """Annotate video frames with labels. @@ -96,9 +97,10 @@ def annotate_video( for i in tqdm(sorted(labels["Frame"].unique()), desc="Frame", unit="Frame"): frame = video.get_data(i) if frame.shape[0] == 1 or frame.shape[-1] == 1: - frame = cv2.cvtColor((frame * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB) + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) else: - frame = (frame * 255).astype(np.uint8).copy() + frame = frame.copy() + lf = labels[labels["Frame"] == i] for idx, instance in lf.iterrows(): if not trails: @@ -147,15 +149,19 @@ def annotate_video( frame = cv2.line(frame, source, target, track_color, 1) - if (boxes is not None and boxes > 0) or centroids: + if (boxes) or centroids: # Get coordinates for detected objects in the current frame. + if isinstance(boxes, int): + boxes = (boxes, boxes) + + box_w, box_h = boxes x = instance["X"] y = instance["Y"] min_x, min_y, max_x, max_y = ( - int(x - boxes / 2), - int(y - boxes / 2), - int(x + boxes / 2), - int(y + boxes / 2), + int(x - box_w / 2), + int(y - box_h / 2), + int(x + box_w / 2), + int(y + box_h / 2), ) midpt = (int(x), int(y)) @@ -180,7 +186,7 @@ def annotate_video( # print(instance[key]) # Bbox. - if boxes is not None and boxes > 0: + if boxes is not None: frame = cv2.rectangle( frame, (min_x, min_y), @@ -192,23 +198,24 @@ def annotate_video( # Track trail. if centroids: frame = cv2.circle( - frame, midpt, radius=4, color=track_color, thickness=-1 + frame, midpt, radius=centroids, color=track_color, thickness=-1 ) for i in range(0, len(track_trails[pred_track_id]) - 1): - frame = cv2.circle( - frame, + frame = cv2.addWeighted(cv2.circle( + frame.copy(), track_trails[pred_track_id][i], radius=4, color=track_color, thickness=-1, - ) - frame = cv2.line( - frame, - track_trails[pred_track_id][i], - track_trails[pred_track_id][i + 1], - color=track_color, - thickness=2, - ) + ), alpha, frame, 1-alpha, 0) + if trails: + frame = cv2.line( + frame, + track_trails[pred_track_id][i], + track_trails[pred_track_id][i + 1], + color=track_color, + thickness=trails, + ) # Track name. 
if names: From 6589d45bb679840f8d65ec9959964a2d896ebd8b Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:46:54 -0800 Subject: [PATCH 21/40] add random seed for reproducibility especially for chunk subsampling --- biogtr/datasets/base_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index 61b3e00b..dce74f86 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -50,8 +50,8 @@ def __init__( self.n_chunks = n_chunks self.seed = seed - # if self.seed is not None: - # np.random.seed(self.seed) + if self.seed is not None: + np.random.seed(self.seed) self.augmentations = ( data_utils.build_augmentations(augmentations) if augmentations else None From 305b0bff255e5ba5b4b06ac7f4d1825af9ecf69f Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:51:59 -0800 Subject: [PATCH 22/40] extend CTC dataset to multiple videos. fix small for loop variable duplication bug --- biogtr/datasets/cell_tracking_dataset.py | 17 +++++++++-------- biogtr/datasets/microscopy_dataset.py | 19 ++++++++++--------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 6a421e13..a2db0b97 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -85,12 +85,12 @@ def __init__( ) if gt_list is not None: - self.gt_list = pd.read_csv( - gt_list, + self.gt_list = [pd.read_csv( + gtf, delimiter=" ", header=None, names=["track_id", "start_frame", "end_frame", "parent_id"], - ) + ) for gtf in gt_list] else: self.gt_list = None @@ -121,6 +121,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram """ image = self.videos[label_idx] gt = self.labels[label_idx] + gt_list = self.gt_list[label_idx] frames = [] @@ -143,7 +144,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram if self.gt_list is None: unique_instances = np.unique(gt_sec) else: - unique_instances = self.gt_list["track_id"].unique() + unique_instances = gt_list["track_id"].unique() for instance in unique_instances: # not all instances are in the frame, and they also label the @@ -180,14 +181,14 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram img = torch.Tensor(img).unsqueeze(0) - for i in range(len(gt_track_ids)): - crop = data_utils.crop_bbox(img, bboxes[i]) + for j in range(len(gt_track_ids)): + crop = data_utils.crop_bbox(img, bboxes[j]) instances.append( Instance( - gt_track_id=gt_track_ids[i], + gt_track_id=gt_track_ids[j], pred_track_id=-1, - bbox=bboxes[i], + bbox=bboxes[j], crop=crop, ) ) diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 05d8aa73..98fe3ef1 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -130,16 +130,17 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram video = data_utils.LazyTiffStack(self.videos[label_idx]) frames = [] - for i in frame_idx: + for frame_id in frame_idx: + # print(i) instances, gt_track_ids, centroids = [], [], [] img = ( - video.get_section(i) + video.get_section(frame_id) if not isinstance(video, list) - else np.array(Image.open(video[i])) + else np.array(Image.open(video[frame_id])) ) - lf = labels[labels["FRAME"].astype(int) == i.item()] + lf = labels[labels["FRAME"].astype(int) == frame_id.item()] for 
instance in sorted(lf["TRACK_ID"].unique()): gt_track_ids.append(int(instance)) @@ -170,8 +171,8 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram if img.shape[2] == 3: img = img.T # todo: check for edge cases - for i in range(len(gt_track_ids)): - c = centroids[i] + for gt_id in range(len(gt_track_ids)): + c = centroids[gt_id] bbox = data_utils.pad_bbox( data_utils.get_bbox([int(c[0]), int(c[1])], self.crop_size), padding=self.padding, @@ -180,17 +181,17 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram instances.append( Instance( - gt_track_id=gt_track_ids[i], + gt_track_id=gt_track_ids[gt_id], pred_track_id=-1, bbox=bbox, crop=crop, ) ) - + frames.append( Frame( video_id=label_idx, - frame_id=i, + frame_id=frame_id, img_shape=img.shape, instances=instances, ) From 0b647fa9abad084338a46a0b56294b210800e1d7 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:57:17 -0800 Subject: [PATCH 23/40] handle missing detections case --- biogtr/datasets/eval_dataset.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index f3142b16..691a8736 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -41,20 +41,28 @@ def __getitem__(self, idx: int) -> List[Frame]: eval_frames = [] for gt_frame, pred_frame in zip(gt_batch, pred_batch): eval_instances = [] - for gt_instance, pred_instance in zip( - gt_frame.instances, pred_frame.instances - ): + for i, gt_instance in enumerate(gt_frame.instances): + + gt_track_id = gt_instance.gt_track_id + + try: + pred_track_id = pred_frame.instances[i].gt_track_id + pred_bbox = pred_frame.instances[i].bbox + except IndexError: + pred_track_id = -1 + pred_bbox = [-1,-1,-1,-1] eval_instances.append( Instance( - gt_track_id=gt_instance.gt_track_id, - pred_track_id=pred_instance.gt_track_id, - bbox=pred_instance.bbox, + gt_track_id=gt_track_id, + pred_track_id=pred_track_id, + bbox=pred_bbox, ) ) eval_frames.append( Frame( video_id=gt_frame.video_id, frame_id=gt_frame.frame_id, + vid_file=gt_frame.video.filename, img_shape=gt_frame.img_shape, instances=eval_instances, ) From 7b59bbc092b6a85612a5999e28129bc196129e52 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 12:59:54 -0800 Subject: [PATCH 24/40] use correct bounds for instance retention limits. 
Add max tracks --- biogtr/inference/track_queue.py | 9 +++++---- biogtr/inference/tracker.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index ee648bc5..fa721d79 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -27,7 +27,7 @@ def __init__(self, window_size: int, max_gap: int = 1, verbose: bool = False): self._queues = {} self._max_gap = max_gap self._curr_gap = {} - if self._max_gap >= 0 and self._max_gap <= self._window_size: + if self._max_gap <= self._window_size: self._max_gap = self._window_size self._curr_track = -1 self._verbose = verbose @@ -189,7 +189,8 @@ def add_frame(self, frame: Frame) -> None: vid_id = frame.video_id.item() frame_id = frame.frame_id.item() img_shape = frame.img_shape - frame_meta = (vid_id, frame_id, img_shape.cpu().tolist()) + vid_name = frame.video.filename + frame_meta = (vid_id, frame_id, vid_name, img_shape.cpu().tolist()) pred_tracks = [] for instance in frame.instances: @@ -239,10 +240,10 @@ def collate_tracks( else self._queues ) for track, instances in tracks_to_convert.items(): - for video_id, frame_id, img_shape, instance in instances: + for video_id, frame_id, vid_name, img_shape, instance in instances: if (video_id, frame_id) not in frames.keys(): frame = Frame( - video_id, frame_id, img_shape=img_shape, instances=[instance] + video_id, frame_id, img_shape=img_shape, instances=[instance], vid_file=vid_name ) frames[(video_id, frame_id)] = frame else: diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 506bbc48..a708cf2c 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -10,7 +10,7 @@ from biogtr.inference.boxes import Boxes from scipy.optimize import linear_sum_assignment from copy import deepcopy - +from math import inf class Tracker: """Tracker class used for assignment based on sliding inference from GTR.""" @@ -25,7 +25,8 @@ def __init__( iou: str = None, max_center_dist: float = None, persistent_tracking: bool = False, - max_gap: int = -1, + max_gap: int = inf, + max_tracks: int = inf, verbose=False, ): """Initialize a tracker to run inference. @@ -54,6 +55,7 @@ def __init__( self.max_center_dist = max_center_dist self.persistent_tracking = persistent_tracking self.verbose = verbose + self.max_tracks = max_tracks def __call__(self, model: GlobalTrackingTransformer, frames: list[Frame]): """Wrap around `track` to enable `tracker()` instead of `tracker.track()`. 
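The tracker.py hunk that follows gates each Hungarian match by the overlap threshold and stops spawning new identities once max_tracks is reached. A minimal, self-contained sketch of that acceptance rule (the scores, threshold, and max_tracks values are illustrative; the real code also scales the threshold by id_inds when mult_thresh is set and maps column j to unique_ids[j]):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    traj_score = np.array([[0.9, 0.1], [0.2, 0.7], [0.05, 0.1]])  # (n_query, n_traj)
    overlap_thresh, max_tracks, n_traj = 0.5, 10, traj_score.shape[1]

    match_i, match_j = linear_sum_assignment(-traj_score)  # negate to maximize total score
    track_ids = np.full(traj_score.shape[0], -1)
    for i, j in zip(match_i, match_j):
        # keep the match if it clears the threshold, or if no new tracks may be created
        if n_traj >= max_tracks or traj_score[i, j] > overlap_thresh:
            track_ids[i] = j
    for i in range(len(track_ids)):
        if track_ids[i] < 0:  # unmatched query instance starts a new trajectory
            track_ids[i] = n_traj
            n_traj += 1
    # track_ids -> [0, 1, 2]: two continued tracks, one newly created track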
@@ -369,11 +371,14 @@ def _run_global_tracker( thresh = ( overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh ) - if traj_score[i, j] > thresh: + if n_traj >= self.max_tracks or traj_score[i, j] > thresh: track_ids[i] = unique_ids[j] + query_frame.instances[i].track_score = traj_score[i,j] for i in range(n_query): if track_ids[i] < 0: + if self.verbose: + print("Creating new tracks") track_ids[i] = n_traj n_traj += 1 From dfd3691fc9d83af28a27d28d48729057ed5836c2 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 6 Feb 2024 13:05:31 -0800 Subject: [PATCH 25/40] fix gpu memory leak --- biogtr/inference/metrics.py | 4 ++-- biogtr/models/gtr_runner.py | 10 +++++++++- biogtr/models/transformer.py | 10 +++++++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index a0ccebfe..ce931329 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -39,8 +39,8 @@ def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: matches[match] = np.full(len(frames), 0) matches[match][idx] = 1 - else: - warnings.warn("No instances detected!") + # else: + # warnings.warn("No instances detected!") return matches, indices, video_id diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index e5038f5e..7b16f651 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -1,5 +1,6 @@ """Module containing training, validation and inference logic.""" import torch +import gc from biogtr.inference.tracker import Tracker from biogtr.inference import metrics from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer @@ -151,6 +152,7 @@ def _shared_eval_step(self, instances, mode): a dict containing the loss and any other metrics specified by `eval_metrics` """ try: + instances = [frame for frame in instances if frame.has_instances()] eval_metrics = self.metrics[mode] persistent_tracking = self.persistent_tracking[mode] @@ -168,11 +170,14 @@ def _shared_eval_step(self, instances, mode): instances_mm = metrics.to_track_eval(instances_pred) clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics) return_metrics.update(clearmot.to_dict()) + return_metrics['batch_size'] = len(instances) except Exception as e: print( f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}" ) raise (e) + gc.collect() + torch.cuda.empty_cache() return return_metrics def configure_optimizers(self) -> dict: @@ -214,5 +219,8 @@ def log_metrics(self, result: dict, mode: str) -> None: mode: One of {'train', 'test' or 'val'}. Used as prefix while logging. 
""" if result: + batch_size = result.pop("batch_size") for metric, val in result.items(): - self.log(f"{mode}_{metric}", val, on_step=True, on_epoch=True) + if isinstance(val, torch.TensorType): + val = val.item() + self.log(f"{mode}_{metric}", val, batch_size=batch_size) diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 867b138a..656ab72f 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -175,9 +175,13 @@ def forward(self, frames: list[Frame], query_frame: int = None): n_query: number of instances in current query/frame total_instances: number of instances in window """ - reid_features = torch.cat( - [frame.get_features() for frame in frames], dim=0 - ).unsqueeze(0) + try: + reid_features = torch.cat( + [frame.get_features() for frame in frames], dim=0 + ).unsqueeze(0) + except Exception as e: + print([[f.device for f in frame.get_features()] for frame in frames]) + raise(e) window_length = len(frames) instances_per_frame = [frame.num_detected for frame in frames] From 8d9ed9d11c8ba8daa14b9011418ff7c6e742bb82 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 09:36:17 -0700 Subject: [PATCH 26/40] move train metrics to its own section in config save config to wandblogger enable use of profiler use mixed precision in training --- biogtr/config.py | 17 +++++++++++------ biogtr/training/train.py | 3 ++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/biogtr/config.py b/biogtr/config.py index a1e04725..c8751857 100644 --- a/biogtr/config.py +++ b/biogtr/config.py @@ -104,9 +104,9 @@ def get_gtr_runner(self): model = GTRRunner.load_from_checkpoint( self.cfg.model.ckpt_path, tracker_cfg=tracker_params, - train_metrics=self.cfg.runner.train_metrics, - val_metrics=self.cfg.runner.val_metrics, - test_metrics=self.cfg.runner.test_metrics, + train_metrics=self.cfg.runner.metrics.train, + val_metrics=self.cfg.runner.metrics.val, + test_metrics=self.cfg.runner.metrics.test ) else: @@ -241,7 +241,7 @@ def get_logger(self): A Logger with specified params """ logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True) - return init_logger(logger_params) + return init_logger(logger_params, OmegaConf.to_container(self.cfg ,resolve=True)) def get_early_stopping(self) -> pl.callbacks.EarlyStopping: """Getter for lightning early stopping callback. 
@@ -313,11 +313,16 @@ def get_trainer( self.set_hparams({"trainer.accelerator": accelerator}) if "devices" not in self.cfg.trainer: self.set_hparams({"trainer.devices": devices}) - + trainer_params = self.cfg.trainer - + if "profiler" in trainer_params: + profiler = pl.profilers.AdvancedProfiler(filename="profile.txt") + trainer_params.pop("profiler") + else: + profiler = None return pl.Trainer( callbacks=callbacks, logger=logger, + profiler=profiler, **trainer_params, ) diff --git a/biogtr/training/train.py b/biogtr/training/train.py index 65169d27..5c6111dc 100644 --- a/biogtr/training/train.py +++ b/biogtr/training/train.py @@ -37,6 +37,7 @@ def main(cfg: DictConfig): Args: cfg: The config dict parsed by `hydra` """ + torch.set_float32_matmul_precision('medium') train_cfg = Config(cfg) # update with parameters for batch train job @@ -79,7 +80,7 @@ def main(cfg: DictConfig): if cfg.view_batch.no_train: return - model = train_cfg.get_gtr_runner() + model = train_cfg.get_gtr_runner() #TODO see if we can use torch.compile() logger = train_cfg.get_logger() From a46290d969b7e1cdc71211c9d7ce2c6d5215dc2e Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 09:40:48 -0700 Subject: [PATCH 27/40] add ability to store and visualize track scores --- biogtr/data_structures.py | 32 +++++++++++++++++++++++++------- biogtr/inference/track.py | 3 +++ biogtr/inference/track_queue.py | 11 ++++++++--- biogtr/visualize.py | 28 +++++++++++++++++++++++----- 4 files changed, 59 insertions(+), 15 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index 3e1132c4..9b1dffd2 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -2,6 +2,7 @@ import torch import sleap_io as sio import numpy as np +import warnings from numpy.typing import ArrayLike from typing import Union, List @@ -16,9 +17,9 @@ def __init__( bbox: ArrayLike = torch.empty((0, 4)), crop: ArrayLike = torch.tensor([]), features: ArrayLike = torch.tensor([]), - track_score: float = 0.0, + track_score: float = -1.0, point_scores: ArrayLike = None, - instance_score:float = 0.0, + instance_score:float = -1.0, skeleton: sio.Skeleton = None, pose: ArrayLike = None, device: str = None, @@ -355,6 +356,11 @@ def has_pose(self) -> bool: return True return False + @property + def shown_pose(self) -> ArrayLike: + pose = self.pose + return pose[~np.isnan(pose).any(axis=1)] + @property def skeleton(self) -> sio.Skeleton: return self._skeleton @@ -424,7 +430,11 @@ def __init__( self._video_id = torch.tensor([video_id]) self._frame_id = torch.tensor([frame_id]) - self._video = sio.Video(vid_file) + try: + self._video = sio.Video(vid_file) + except ValueError as e: + #warnings.warn(f"{e}") + self._video = vid_file if isinstance(img_shape, torch.Tensor): self._img_shape = img_shape @@ -454,7 +464,7 @@ def __repr__(self) -> str: """ return ( "Frame(" - f"video={self._video.filename}, " + f"video={self._video.filename if isinstance(self._video, sio.Video) else self._video}, " f"video_id={self._video_id.item()}, " f"frame_id={self._frame_id.item()}, " f"img_shape={self._img_shape}, " @@ -570,12 +580,16 @@ def frame_id(self, frame_id: int) -> None: self._frame_id = torch.tensor([frame_id]) @property - def video(self) -> sio.Video: + def video(self) -> Union[sio.Video, str]: return self._video @video.setter def video(self, video_filename: str) -> None: - self._video = sio.Video(video_filename) + try: + self._video = video_filename + except ValueError as e: + #warnings.warn(f"{e}") + self._video = video_filename 
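With track_score, instance_score, and point_scores stored on the instance, the tracker's assignment confidence can travel alongside the pose data it came from. A minimal construction sketch, following the Instance signature added above (the numeric values and the bounding-box layout are illustrative):

    import torch
    from biogtr.data_structures import Instance

    inst = Instance(
        gt_track_id=0,
        bbox=torch.tensor([[10.0, 12.0, 74.0, 76.0]]),  # illustrative box
        track_score=0.87,      # filled in by the tracker during assignment
        instance_score=0.95,   # pose-prediction score from SLEAP
    )
    print(inst.track_score, inst.instance_score)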
@property def img_shape(self) -> torch.Tensor: @@ -811,7 +825,11 @@ def get_crops(self) -> torch.Tensor: """ if not self.has_instances(): return torch.tensor([]) - return torch.cat([instance.crop for instance in self.instances], dim=0) + try: + return torch.cat([instance.crop for instance in self.instances], dim=0) + except Exception as e: + print(self) + raise(e) def has_features(self): """Check if any of frames instances has reid features already computed. diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py index 52ed18a1..e4d417d1 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -32,6 +32,7 @@ def export_trajectories(frames_pred: list[Frame], save_path: str = None): frame_ids = [] X, Y = [], [] pred_track_ids = [] + track_scores = [] for frame in frames_pred: for i, instance in enumerate(frame.instances): frame_ids.append(frame.frame_id.item()) @@ -40,12 +41,14 @@ def export_trajectories(frames_pred: list[Frame], save_path: str = None): x = (bbox[3] + bbox[1]) / 2 X.append(x.item()) Y.append(y.item()) + track_scores.append(instance.track_score) pred_track_ids.append(instance.pred_track_id.item()) save_dict["Frame"] = frame_ids save_dict["X"] = X save_dict["Y"] = Y save_dict["Pred_track_id"] = pred_track_ids + save_dict["Track_score"] = track_scores save_df = pd.DataFrame(save_dict) if save_path: save_df.to_csv(save_path, index=False) diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index fa721d79..ef616639 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -3,6 +3,7 @@ import warnings from biogtr.data_structures import Frame from collections import deque +import numpy as np class TrackQueue: @@ -13,7 +14,7 @@ class TrackQueue: and will be compared against later frames for assignment. """ - def __init__(self, window_size: int, max_gap: int = 1, verbose: bool = False): + def __init__(self, window_size: int, max_gap: int = np.inf, verbose: bool = False): """Initialize track queue. Args: @@ -122,7 +123,7 @@ def n_tracks(self) -> int: Returns: An int representing the current number of trajectories in the queue. """ - return len(self._queues) + return len(self._queues.keys()) @property def tracks(self) -> list: @@ -189,7 +190,11 @@ def add_frame(self, frame: Frame) -> None: vid_id = frame.video_id.item() frame_id = frame.frame_id.item() img_shape = frame.img_shape - vid_name = frame.video.filename + if isinstance(frame.video, str): + vid_name = frame.video + else: + vid_name = frame.video.filename + #traj_score = frame.get_traj_score() TODO: figure out better way to save trajectory scores. 
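export_trajectories now writes the per-instance track score alongside each centroid, which makes it possible to filter or color detections by assignment confidence downstream. A minimal sketch of consuming the saved file (the file name is illustrative; the columns come from the save_dict above):

    import pandas as pd

    df = pd.read_csv("trajectories.csv")    # written by export_trajectories(frames_pred, save_path=...)
    # columns: Frame, X, Y, Pred_track_id, Track_score
    low_conf = df[df["Track_score"] < 0.5]  # e.g. flag low-confidence assignments for review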
frame_meta = (vid_id, frame_id, vid_name, img_shape.cpu().tolist()) pred_tracks = [] diff --git a/biogtr/visualize.py b/biogtr/visualize.py index fd0ec62c..08543bb1 100644 --- a/biogtr/visualize.py +++ b/biogtr/visualize.py @@ -11,6 +11,7 @@ import numpy as np import cv2 from matplotlib import pyplot +import gc palette = sns.color_palette("tab20") @@ -64,9 +65,10 @@ def annotate_video( trails: int = 2, boxes: int = (64,64), names: bool = True, + track_scores=0.5, centroids: int = 4, poses=False, - save_path: str = "debug_animal", + save_path: str = "debug_animal.mp4", fps: int = 30, alpha=0.2 ) -> list: @@ -98,8 +100,8 @@ def annotate_video( frame = video.get_data(i) if frame.shape[0] == 1 or frame.shape[-1] == 1: frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) - else: - frame = frame.copy() + # else: + # frame = frame.copy() lf = labels[labels["Frame"] == i] for idx, instance in lf.iterrows(): @@ -169,6 +171,11 @@ def annotate_video( # assert idx < len(instance[key]) pred_track_id = instance[key] + + if "Track_score" in instance.index: + track_score = instance['Track_score'] + else: + track_scores = 0 # Add midpt to track trail. if pred_track_id not in list(track_trails.keys()): @@ -202,7 +209,7 @@ def annotate_video( ) for i in range(0, len(track_trails[pred_track_id]) - 1): frame = cv2.addWeighted(cv2.circle( - frame.copy(), + frame, #.copy(), track_trails[pred_track_id][i], radius=4, color=track_color, @@ -218,11 +225,20 @@ def annotate_video( ) # Track name. + name_str = "" + if names: + name_str += f"track_{pred_track_id}" + if names and track_scores: + name_str += " | " + if track_scores: + name_str += f"score: {track_score:0.3f}" + + if len(name_str) > 0: frame = cv2.putText( frame, # f"idx:{idx} | track_{pred_track_id}", - f"track_{pred_track_id}", + name_str, org=(int(min_x), max(0, int(min_y) - 10)), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.9, @@ -230,6 +246,8 @@ def annotate_video( thickness=2, ) writer.append_data(frame) + # if i % fps == 0: + # gc.collect() except Exception as e: writer.close() From e7a22f772b0502b3138c45304b08512ed55bd721 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 09:43:08 -0700 Subject: [PATCH 28/40] don't subsample chunks with replacement --- biogtr/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index dce74f86..c259a43b 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -89,7 +89,7 @@ def create_chunks(self) -> None: if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx): sample_idx = np.random.choice( - np.arange(len(self.chunked_frame_idx)), n_chunks + np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False ) self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx] From 1d5e1b721d78cbfd5e4b3f950267c4abe48a6196 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 09:53:05 -0700 Subject: [PATCH 29/40] add debug statements to tracker use simple tracker ids to initialize tracks instead of random use clone instead of deepcopy save out all asso_matrices --- biogtr/inference/tracker.py | 117 +++++++++++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 16 deletions(-) diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index a708cf2c..a4d33070 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -1,4 +1,5 @@ """Module containing logic for going from association -> assignment.""" + import torch import 
pandas as pd import warnings @@ -12,6 +13,7 @@ from copy import deepcopy from math import inf + class Tracker: """Tracker class used for assignment based on sliding inference from GTR.""" @@ -161,8 +163,15 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame f"Initializing track on clip ind {batch_idx} frame {frame_to_track.frame_id.item()}" ) + curr_track_id = 0 for i, instance in enumerate(frames[batch_idx].instances): - instance.pred_track_id = i + instance.pred_track_id = instance.gt_track_id + curr_track_id = instance.pred_track_id + + for i, instance in enumerate(frames[batch_idx].instances): + if instance.pred_track_id == -1: + instance.pred_track_id = curr_track_id + curr_track += 1 else: if ( @@ -220,12 +229,19 @@ def _run_global_tracker( _ = model.eval() query_frame = frames[query_ind] + + if self.verbose: + print(f"Frame {query_frame.frame_id.item()}") + instances_per_frame = [frame.num_detected for frame in frames] total_instances, window_size = sum(instances_per_frame), len( instances_per_frame ) # Number of instances in window; length of window. + if self.verbose: + print(f"total_instances: {total_instances}") + overlap_thresh = self.overlap_thresh mult_thresh = self.mult_thresh n_traj = self.track_queue.n_tracks @@ -249,6 +265,17 @@ def _run_global_tracker( ) # (window_size, n_query, N_i) asso_output = torch.cat(asso_output, dim=1).cpu() # (n_query, total_instances) + asso_output_df = pd.DataFrame( + asso_output.clone().numpy(), + columns=[f"Instance {i}" for i in range(asso_output.shape[-1])], + ) + + asso_output_df.index.name = "Instances" + asso_output_df.columns.name = "Instances" + + query_frame.add_traj_score("asso_output", asso_output_df) + query_frame.asso_output = asso_output + try: n_query = ( query_frame.num_detected @@ -261,6 +288,9 @@ def _run_global_tracker( total_instances - n_query ) # Number of instances in the window not including the current/query frame. + if self.verbose: + print(f"n_nonquery: {n_nonquery}") + print(f"n_query: {n_query}") try: instance_ids = torch.cat( [ @@ -289,18 +319,28 @@ def _run_global_tracker( ) ] nonquery_inds = [i for i in range(total_instances) if i not in query_inds] + asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery) + asso_nonquery_df = pd.DataFrame( + asso_nonquery.clone().numpy(), columns=nonquery_inds + ) + + asso_nonquery_df.index.name = "Current Frame Instances" + asso_nonquery_df.columns.name = "Nonquery Instances" + + query_frame.add_traj_score("asso_nonquery", asso_nonquery_df) + pred_boxes, _ = model_utils.get_boxes_times(frames) query_boxes = pred_boxes[query_inds] # n_k x 4 nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 # TODO: Insert postprocessing. 
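In the hunk below, the per-instance association scores are reduced to per-track scores through a one-hot membership matrix built from torch.unique over the window's instance ids. A minimal sketch of that reduction with illustrative shapes (3 query instances, 5 non-query instances, 2 trajectories):

    import torch

    asso_nonquery = torch.rand(3, 5)              # (n_query, n_nonquery) association scores
    instance_ids = torch.tensor([0, 0, 1, 1, 0])  # track id of each non-query instance
    unique_ids = torch.unique(instance_ids)       # tensor([0, 1])
    id_inds = (unique_ids[None, :] == instance_ids[:, None]).float()  # (n_nonquery, n_traj)
    traj_score = torch.mm(asso_nonquery, id_inds)  # (n_query, n_traj): scores summed per track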
- unique_ids = torch.tensor( - [self.track_queue.tracks], device=instance_ids.device - ).view( - n_traj - ) # (n_nonquery,) + unique_ids = torch.unique(instance_ids) # (n_nonquery,) + + if self.verbose: + print(f"Instance IDs: {instance_ids}") + print(f"unique ids: {unique_ids}") id_inds = ( unique_ids[None, :] == instance_ids[:, None] @@ -310,21 +350,32 @@ def _run_global_tracker( # reweighting hyper-parameters for association -> they use 0.9 - # (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_k x n_traj traj_score = post_processing.weight_decay_time( asso_nonquery, self.decay_time, reid_features, window_size, query_ind ) + if self.decay_time is not None and self.decay_time > 0: + decay_time_traj_score = pd.DataFrame( + traj_score.clone().numpy(), columns=nonquery_inds + ) + + decay_time_traj_score.index.name = "Query Instances" + decay_time_traj_score.columns.name = "Nonquery Instances" + + query_frame.add_traj_score("decay_time", decay_time_traj_score) + ################################################################################ + + # (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_k x n_traj traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj) - decay_time_traj_score = pd.DataFrame( - deepcopy((traj_score).numpy()), columns=unique_ids.cpu().numpy() + traj_score_df = pd.DataFrame( + traj_score.clone().numpy(), columns=unique_ids.cpu().numpy() ) - decay_time_traj_score.index.name = "Current Frame Instances" - decay_time_traj_score.columns.name = "Unique IDs" + traj_score_df.index.name = "Current Frame Instances" + traj_score_df.columns.name = "Unique IDs" - query_frame.add_traj_score("decay_time", decay_time_traj_score) + query_frame.add_traj_score("traj_score", traj_score_df) ################################################################################ # with iou -> combining with location in tracker, they set to True @@ -348,6 +399,16 @@ def _run_global_tracker( else: last_ious = traj_score.new_zeros(traj_score.shape) traj_score = post_processing.weight_iou(traj_score, self.iou, last_ious.cpu()) + + if self.iou is not None and self.iou != "": + iou_traj_score = pd.DataFrame( + traj_score.clone().numpy(), columns=unique_ids.cpu().numpy() + ) + + iou_traj_score.index.name = "Current Frame Instances" + iou_traj_score.columns.name = "Unique IDs" + + query_frame.add_traj_score("weight_iou", iou_traj_score) ################################################################################ # threshold for continuing a tracking or starting a new track -> they use 1.0 @@ -356,6 +417,25 @@ def _run_global_tracker( traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds ) + if self.max_center_dist is not None and self.max_center_dist > 0: + max_center_dist_traj_score = pd.DataFrame( + traj_score.clone().numpy(), columns=unique_ids.cpu().numpy() + ) + + max_center_dist_traj_score.index.name = "Current Frame Instances" + max_center_dist_traj_score.columns.name = "Unique IDs" + + query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score) + + ################################################################################ + scaled_traj_score = torch.softmax(traj_score, dim=1) + scaled_traj_score_df = pd.DataFrame( + scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy() + ) + scaled_traj_score_df.index.name = "Current Frame Instances" + scaled_traj_score_df.columns.name = "Unique IDs" + + query_frame.add_traj_score("scaled", scaled_traj_score_df) ################################################################################ match_i, 
match_j = linear_sum_assignment((-traj_score)) @@ -372,13 +452,18 @@ def _run_global_tracker( overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh ) if n_traj >= self.max_tracks or traj_score[i, j] > thresh: + if self.verbose: + print( + f"Assigning instance {i} to track {j} with id {unique_ids[j]}" + ) track_ids[i] = unique_ids[j] - query_frame.instances[i].track_score = traj_score[i,j] - + query_frame.instances[i].track_score = scaled_traj_score[i, j].item() + if self.verbose: + print(f"track_ids: {track_ids}") for i in range(n_query): if track_ids[i] < 0: if self.verbose: - print("Creating new tracks") + print(f"Creating new track {n_traj}") track_ids[i] = n_traj n_traj += 1 @@ -388,7 +473,7 @@ def _run_global_tracker( instance.pred_track_id = track_id final_traj_score = pd.DataFrame( - deepcopy((traj_score).numpy()), columns=unique_ids.cpu().numpy() + traj_score.clone().numpy(), columns=unique_ids.cpu().numpy() ) final_traj_score.index.name = "Current Frame Instances" final_traj_score.columns.name = "Unique IDs" From 67d879e6c4cc4f3429bf98d3978aac101e78d37d Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 09:56:19 -0700 Subject: [PATCH 30/40] only clear gpu after each epoch instead of each step don't track during training by default --- biogtr/models/gtr_runner.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index 7b16f651..34c8f3fb 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -23,7 +23,7 @@ def __init__( optimizer_cfg: dict = None, scheduler_cfg: dict = None, metrics: dict[str, list[str]] = { - "train": ["num_switches"], + "train": [], "val": ["num_switches"], "test": ["num_switches"], }, @@ -176,8 +176,7 @@ def _shared_eval_step(self, instances, mode): f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}" ) raise (e) - gc.collect() - torch.cuda.empty_cache() + return return_metrics def configure_optimizers(self) -> dict: @@ -224,3 +223,7 @@ def log_metrics(self, result: dict, mode: str) -> None: if isinstance(val, torch.TensorType): val = val.item() self.log(f"{mode}_{metric}", val, batch_size=batch_size) + + def on_validation_epoch_end(self): + gc.collect() + torch.cuda.empty_cache() From 5e53341a4a8dbab02b4b58634ffbad47ab30f597 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 11:10:18 -0700 Subject: [PATCH 31/40] lint --- biogtr/config.py | 17 +++++--- biogtr/datasets/base_dataset.py | 1 + biogtr/datasets/cell_tracking_dataset.py | 16 +++++--- biogtr/datasets/data_utils.py | 7 +++- biogtr/datasets/eval_dataset.py | 7 ++-- biogtr/datasets/microscopy_dataset.py | 11 +++-- biogtr/datasets/sleap_dataset.py | 37 +++++++++++++---- biogtr/datasets/tracking_dataset.py | 9 ++-- biogtr/inference/boxes.py | 1 + biogtr/inference/metrics.py | 1 + biogtr/inference/post_processing.py | 1 + biogtr/inference/tracker.py | 2 + biogtr/models/global_tracking_transformer.py | 1 + biogtr/models/model_utils.py | 1 + biogtr/models/transformer.py | 2 +- biogtr/training/losses.py | 1 + biogtr/training/train.py | 5 ++- biogtr/visualize.py | 43 ++++++++++++-------- tests/conftest.py | 1 + tests/fixtures/configs.py | 1 + tests/fixtures/datasets.py | 1 + tests/test_data_structures.py | 1 + tests/test_datasets.py | 1 + tests/test_inference.py | 1 + tests/test_models.py | 1 + tests/test_training.py | 1 + tests/test_version.py | 1 + 27 files changed, 120 insertions(+), 52 deletions(-) diff --git a/biogtr/config.py 
b/biogtr/config.py index c8751857..00a24a18 100644 --- a/biogtr/config.py +++ b/biogtr/config.py @@ -106,7 +106,7 @@ def get_gtr_runner(self): tracker_cfg=tracker_params, train_metrics=self.cfg.runner.metrics.train, val_metrics=self.cfg.runner.metrics.val, - test_metrics=self.cfg.runner.metrics.test + test_metrics=self.cfg.runner.metrics.test, ) else: @@ -241,7 +241,9 @@ def get_logger(self): A Logger with specified params """ logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True) - return init_logger(logger_params, OmegaConf.to_container(self.cfg ,resolve=True)) + return init_logger( + logger_params, OmegaConf.to_container(self.cfg, resolve=True) + ) def get_early_stopping(self) -> pl.callbacks.EarlyStopping: """Getter for lightning early stopping callback. @@ -269,7 +271,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: else: dirpath = checkpoint_params["dirpath"] - + dirpath = Path(dirpath).resolve() if not Path(dirpath).exists(): try: @@ -278,13 +280,16 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: print( f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}" ) - + _ = checkpoint_params.pop("dirpath") checkpointers = [] monitor = checkpoint_params.pop("monitor") for metric in monitor: checkpointer = pl.callbacks.ModelCheckpoint( - monitor=metric, dirpath=dirpath, filename=f"{{epoch}}-{{{metric}}}", **checkpoint_params + monitor=metric, + dirpath=dirpath, + filename=f"{{epoch}}-{{{metric}}}", + **checkpoint_params, ) checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}" checkpointers.append(checkpointer) @@ -313,7 +318,7 @@ def get_trainer( self.set_hparams({"trainer.accelerator": accelerator}) if "devices" not in self.cfg.trainer: self.set_hparams({"trainer.devices": devices}) - + trainer_params = self.cfg.trainer if "profiler" in trainer_params: profiler = pl.profilers.AdvancedProfiler(filename="profile.txt") diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index c259a43b..e7484ef8 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -1,4 +1,5 @@ """Module containing logic for loading datasets.""" + from biogtr.datasets import data_utils from biogtr.data_structures import Frame from torch.utils.data import Dataset diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index a2db0b97..8ca0c17b 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -1,4 +1,5 @@ """Module containing cell tracking challenge dataset.""" + from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset @@ -85,12 +86,15 @@ def __init__( ) if gt_list is not None: - self.gt_list = [pd.read_csv( - gtf, - delimiter=" ", - header=None, - names=["track_id", "start_frame", "end_frame", "parent_id"], - ) for gtf in gt_list] + self.gt_list = [ + pd.read_csv( + gtf, + delimiter=" ", + header=None, + names=["track_id", "start_frame", "end_frame", "parent_id"], + ) + for gtf in gt_list + ] else: self.gt_list = None diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py index 76b44972..d482fc23 100644 --- a/biogtr/datasets/data_utils.py +++ b/biogtr/datasets/data_utils.py @@ -1,4 +1,5 @@ """Module containing helper functions for datasets.""" + from PIL import Image from numpy.typing import ArrayLike from torchvision.transforms import functional as tvf @@ -475,8 +476,10 @@ def view_training_batch( 
else (axes[i] if num_crops == 1 else axes[i, j]) ) - ax.imshow(data.T) if isinstance(cmap, None) else ax.imshow( - data.T, cmap=cmap + ( + ax.imshow(data.T) + if isinstance(cmap, None) + else ax.imshow(data.T, cmap=cmap) ) ax.axis("off") diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index 691a8736..e2cbea2b 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -1,4 +1,5 @@ """Module containing wrapper for merging gt and pred datasets for evaluation.""" + from torch.utils.data import Dataset from biogtr.data_structures import Frame, Instance from typing import List @@ -42,15 +43,15 @@ def __getitem__(self, idx: int) -> List[Frame]: for gt_frame, pred_frame in zip(gt_batch, pred_batch): eval_instances = [] for i, gt_instance in enumerate(gt_frame.instances): - + gt_track_id = gt_instance.gt_track_id - + try: pred_track_id = pred_frame.instances[i].gt_track_id pred_bbox = pred_frame.instances[i].bbox except IndexError: pred_track_id = -1 - pred_bbox = [-1,-1,-1,-1] + pred_bbox = [-1, -1, -1, -1] eval_instances.append( Instance( gt_track_id=gt_track_id, diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 98fe3ef1..39a49b1d 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -1,4 +1,5 @@ """Module containing microscopy dataset.""" + from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset @@ -93,9 +94,11 @@ def __init__( ] self.frame_idx = [ - torch.arange(Image.open(video).n_frames) - if isinstance(video, str) - else torch.arange(len(video)) + ( + torch.arange(Image.open(video).n_frames) + if isinstance(video, str) + else torch.arange(len(video)) + ) for video in self.videos ] @@ -187,7 +190,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram crop=crop, ) ) - + frames.append( Frame( video_id=label_idx, diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 8382c8db..73ef5be0 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -1,4 +1,5 @@ """Module containing logic for loading sleap datasets.""" + import albumentations as A import torch import imageio @@ -140,12 +141,19 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict vid_reader = imageio.get_reader(video_name, "ffmpeg") img = vid_reader.get_data(0) - + skeleton = video.skeletons[-1] frames = [] for i, frame_ind in enumerate(frame_idx): - instances, gt_track_ids, poses, shown_poses, point_scores, instance_score = [], [], [], [], [], [] + ( + instances, + gt_track_ids, + poses, + shown_poses, + point_scores, + instance_score, + ) = ([], [], [], [], [], []) frame_ind = int(frame_ind) @@ -181,8 +189,19 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict } for instance in poses ] - - point_scores.append(np.array([point.score if isinstance(point, sio.PredictedPoint) else 1.0 for point in instance.points.values()])) + + point_scores.append( + np.array( + [ + ( + point.score + if isinstance(point, sio.PredictedPoint) + else 1.0 + ) + for point in instance.points.values() + ] + ) + ) if isinstance(instance, sio.PredictedInstance): instance_score.append(instance.score) else: @@ -264,16 +283,16 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict ) crop = data_utils.crop_bbox(img, bbox) - + instance = Instance( gt_track_id=gt_track_ids[j], 
pred_track_id=-1, crop=crop, bbox=bbox, skeleton=skeleton, - pose = np.array(list(poses[j].values())), - point_scores = point_scores[j], - instance_score = instance_score[j] + pose=np.array(list(poses[j].values())), + point_scores=point_scores[j], + instance_score=instance_score[j], ) instances.append(instance) @@ -281,7 +300,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict frame = Frame( video_id=label_idx, frame_id=frame_ind, - vid_file = video_name, + vid_file=video_name, img_shape=img.shape, instances=instances, ) diff --git a/biogtr/datasets/tracking_dataset.py b/biogtr/datasets/tracking_dataset.py index b80cd636..fdc54cac 100644 --- a/biogtr/datasets/tracking_dataset.py +++ b/biogtr/datasets/tracking_dataset.py @@ -1,4 +1,5 @@ """Module containing Lightning module wrapper around all other datasets.""" + from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset from biogtr.datasets.microscopy_dataset import MicroscopyDataset from biogtr.datasets.sleap_dataset import SleapDataset @@ -74,9 +75,11 @@ def train_dataloader(self) -> DataLoader: pin_memory=False, collate_fn=self.train_ds.no_batching_fn, num_workers=0, - generator=torch.Generator(device="cuda") - if torch.cuda.is_available() - else torch.Generator(), + generator=( + torch.Generator(device="cuda") + if torch.cuda.is_available() + else torch.Generator() + ), ) else: return self.train_dl diff --git a/biogtr/inference/boxes.py b/biogtr/inference/boxes.py index e6ed794f..ec123b18 100644 --- a/biogtr/inference/boxes.py +++ b/biogtr/inference/boxes.py @@ -1,4 +1,5 @@ """Module containing Boxes class.""" + from typing import List, Tuple import torch diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index ce931329..a25fa210 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -1,4 +1,5 @@ """Helper functions for calculating mot metrics.""" + import numpy as np import motmetrics as mm import torch diff --git a/biogtr/inference/post_processing.py b/biogtr/inference/post_processing.py index f15d28eb..8f86f21d 100644 --- a/biogtr/inference/post_processing.py +++ b/biogtr/inference/post_processing.py @@ -1,4 +1,5 @@ """Helper functions for post-processing association matrix pre-tracking.""" + import torch from biogtr.inference.boxes import Boxes from copy import deepcopy diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index a4d33070..9ff7f64e 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -44,6 +44,8 @@ def __init__( max_center_dist: distance threshold for filtering trajectory score matrix. persistent_tracking: whether to keep a buffer across chunks or not. max_gap: the max number of frames a trajectory can be missing before termination. + max_tracks: the maximum number of tracks that can be created while tracking. + We force the tracker to assign instances to a track instead of creating a new track if max_tracks has been reached. verbose: Whether or not to turn on debug printing after each operation. 
""" self.track_queue = TrackQueue( diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index ebfd5e5d..0743ce43 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -1,4 +1,5 @@ """Module containing GTR model used for training.""" + from biogtr.models.transformer import Transformer from biogtr.models.visual_encoder import VisualEncoder from biogtr.data_structures import Frame diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 95eb68cc..5413d437 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -1,4 +1,5 @@ """Module containing model helper functions.""" + from copy import deepcopy from typing import List, Tuple, Iterable from pytorch_lightning import loggers diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 656ab72f..dec1fc3f 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -181,7 +181,7 @@ def forward(self, frames: list[Frame], query_frame: int = None): ).unsqueeze(0) except Exception as e: print([[f.device for f in frame.get_features()] for frame in frames]) - raise(e) + raise (e) window_length = len(frames) instances_per_frame = [frame.num_detected for frame in frames] diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index e4d725d9..557b78e3 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -1,4 +1,5 @@ """Module containing different loss functions to be optimized.""" + from biogtr.data_structures import Frame from biogtr.models.model_utils import get_boxes_times from torch import nn diff --git a/biogtr/training/train.py b/biogtr/training/train.py index 5c6111dc..56a2d815 100644 --- a/biogtr/training/train.py +++ b/biogtr/training/train.py @@ -2,6 +2,7 @@ Used for training a single model or deploying a batch train job on RUNAI CLI """ + from biogtr.config import Config from biogtr.datasets.tracking_dataset import TrackingDataset from biogtr.datasets.data_utils import view_training_batch @@ -37,7 +38,7 @@ def main(cfg: DictConfig): Args: cfg: The config dict parsed by `hydra` """ - torch.set_float32_matmul_precision('medium') + torch.set_float32_matmul_precision("medium") train_cfg = Config(cfg) # update with parameters for batch train job @@ -80,7 +81,7 @@ def main(cfg: DictConfig): if cfg.view_batch.no_train: return - model = train_cfg.get_gtr_runner() #TODO see if we can use torch.compile() + model = train_cfg.get_gtr_runner() # TODO see if we can use torch.compile() logger = train_cfg.get_logger() diff --git a/biogtr/visualize.py b/biogtr/visualize.py index 08543bb1..bafcf147 100644 --- a/biogtr/visualize.py +++ b/biogtr/visualize.py @@ -1,4 +1,5 @@ """Helper functions for visualizing tracking.""" + from scipy.interpolate import interp1d from copy import deepcopy from tqdm import tqdm @@ -63,14 +64,14 @@ def annotate_video( key: str, color_palette=palette, trails: int = 2, - boxes: int = (64,64), + boxes: int = (64, 64), names: bool = True, track_scores=0.5, centroids: int = 4, poses=False, save_path: str = "debug_animal.mp4", fps: int = 30, - alpha=0.2 + alpha=0.2, ) -> list: """Annotate video frames with labels. @@ -102,7 +103,7 @@ def annotate_video( frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) # else: # frame = frame.copy() - + lf = labels[labels["Frame"] == i] for idx, instance in lf.iterrows(): if not trails: @@ -155,7 +156,7 @@ def annotate_video( # Get coordinates for detected objects in the current frame. 
if isinstance(boxes, int): boxes = (boxes, boxes) - + box_w, box_h = boxes x = instance["X"] y = instance["Y"] @@ -171,9 +172,9 @@ def annotate_video( # assert idx < len(instance[key]) pred_track_id = instance[key] - + if "Track_score" in instance.index: - track_score = instance['Track_score'] + track_score = instance["Track_score"] else: track_scores = 0 @@ -205,16 +206,26 @@ def annotate_video( # Track trail. if centroids: frame = cv2.circle( - frame, midpt, radius=centroids, color=track_color, thickness=-1 + frame, + midpt, + radius=centroids, + color=track_color, + thickness=-1, ) for i in range(0, len(track_trails[pred_track_id]) - 1): - frame = cv2.addWeighted(cv2.circle( - frame, #.copy(), - track_trails[pred_track_id][i], - radius=4, - color=track_color, - thickness=-1, - ), alpha, frame, 1-alpha, 0) + frame = cv2.addWeighted( + cv2.circle( + frame, # .copy(), + track_trails[pred_track_id][i], + radius=4, + color=track_color, + thickness=-1, + ), + alpha, + frame, + 1 - alpha, + 0, + ) if trails: frame = cv2.line( frame, @@ -226,14 +237,14 @@ def annotate_video( # Track name. name_str = "" - + if names: name_str += f"track_{pred_track_id}" if names and track_scores: name_str += " | " if track_scores: name_str += f"score: {track_score:0.3f}" - + if len(name_str) > 0: frame = cv2.putText( frame, diff --git a/tests/conftest.py b/tests/conftest.py index 434bb560..bf6e6498 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Config for pytests.""" + from tests.fixtures.configs import * from tests.fixtures.datasets import * from tests.fixtures.torch import * diff --git a/tests/fixtures/configs.py b/tests/fixtures/configs.py index 8d172d2a..3cf06840 100644 --- a/tests/fixtures/configs.py +++ b/tests/fixtures/configs.py @@ -1,4 +1,5 @@ """Test config paths.""" + import os import pytest diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 572aa094..db574099 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,4 +1,5 @@ """Fixtures for testing biogtr.""" + import pytest from pathlib import Path diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py index d44bf9db..31f249b4 100644 --- a/tests/test_data_structures.py +++ b/tests/test_data_structures.py @@ -1,4 +1,5 @@ """Tests for Instance, Frame, and TrackQueue Object""" + from biogtr.data_structures import Instance, Frame from biogtr.inference.track_queue import TrackQueue import torch diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c9950681..0ea6521d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,4 +1,5 @@ """Test dataset logic.""" + from biogtr.datasets.base_dataset import BaseDataset from biogtr.datasets.data_utils import get_max_padding from biogtr.datasets.microscopy_dataset import MicroscopyDataset diff --git a/tests/test_inference.py b/tests/test_inference.py index 30d4f99b..a38a5c96 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,4 +1,5 @@ """Test inference logic.""" + import torch import pytest import numpy as np diff --git a/tests/test_models.py b/tests/test_models.py index ea62a835..6d1abd3e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ """Test model modules.""" + import pytest import torch import numpy as np diff --git a/tests/test_training.py b/tests/test_training.py index f12583fa..c04b1eca 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,4 +1,5 @@ """Test training logic.""" + import os import pytest import torch diff 
--git a/tests/test_version.py b/tests/test_version.py index 3f9e7e0e..6bde7e48 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,4 +1,5 @@ """Test version.""" + import biogtr From e37fc35b5256c22218418877f40e26c2de403439 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 11:12:23 -0700 Subject: [PATCH 32/40] lint + fix docstrings --- biogtr/data_structures.py | 209 +++++++++++++++++++++++--------- biogtr/inference/track_queue.py | 8 +- biogtr/models/gtr_runner.py | 11 +- 3 files changed, 169 insertions(+), 59 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index 9b1dffd2..a711b738 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -1,4 +1,5 @@ """Module containing data classes such as Instances and Frames.""" + import torch import sleap_io as sio import numpy as np @@ -19,7 +20,7 @@ def __init__( features: ArrayLike = torch.tensor([]), track_score: float = -1.0, point_scores: ArrayLike = None, - instance_score:float = -1.0, + instance_score: float = -1.0, skeleton: sio.Skeleton = None, pose: ArrayLike = None, device: str = None, @@ -32,6 +33,11 @@ def __init__( bbox: The bounding box coordinate of the instance. Defaults to an empty tensor. crop: The crop of the instance. features: The reid features extracted from the CNN backbone used in the transformer. + track_score: The track score output from the association matrix. + point_scores: The point scores from sleap. + instance_score: The instance scores from sleap. + skeleton: The sleap skeleton used for the instance. + pose: The pose matrix for the instance containing nodes x 2 points. device: String representation of the device the instance should be on. """ if gt_track_id is not None: @@ -43,7 +49,7 @@ def __init__( self._pred_track_id = torch.tensor([pred_track_id]) else: self._pred_track_id = torch.tensor([]) - + if skeleton is None: self._skeleton = sio.Skeleton(["centroid"]) else: @@ -53,10 +59,10 @@ def __init__( self._bbox = torch.tensor(bbox) else: self._bbox = bbox - + if self._bbox.shape[0] and len(self._bbox.shape) == 1: self._bbox = self._bbox.unsqueeze(0) - + if not isinstance(crop, torch.Tensor): self._crop = torch.tensor(crop) else: @@ -74,20 +80,25 @@ def __init__( if self._features.shape[0] and len(self._features.shape) == 1: self._features = self._features.unsqueeze(0) - + if pose is not None: self._pose = pose - + elif self.bbox.shape[0]: - self._pose = np.array([(self.bbox[:,-1] + self.bbox[:,1])/2,(self.bbox[:,-2] + self.bbox[:,0])/2]) - + self._pose = np.array( + [ + (self.bbox[:, -1] + self.bbox[:, 1]) / 2, + (self.bbox[:, -2] + self.bbox[:, 0]) / 2, + ] + ) + else: - self._pose = np.empty((0,2)) - + self._pose = np.empty((0, 2)) + self._track_score = track_score self._instance_score = instance_score - - if point_scores is not None: + + if point_scores is not None: self._point_scores = point_scores else: self._point_scores = np.zeros_like(self.pose) @@ -124,28 +135,39 @@ def to(self, map_location): self._features = self._features.to(map_location) self.device = map_location return self - - def to_slp(self, track_lookup: dict = {}) -> (sio.PredictedInstance, dict[int, sio.Track]): - """Convert instance to sleap_io.PredictedInstance object - + + def to_slp( + self, track_lookup: dict[int, sio.Track] = {} + ) -> tuple[sio.PredictedInstance, dict[int, sio.Track]]: + """Convert instance to sleap_io.PredictedInstance object. + + Args: + track_lookup: A track look up dictionary containing track_id:sio.Track. 
Returns: A sleap_io.PredictedInstance with necessary metadata + and a track_lookup dictionary to persist tracks. """ try: track_id = self.pred_track_id.item() if track_id not in track_lookup: track_lookup[track_id] = sio.Track(name=self.pred_track_id.item()) - + track = track_lookup[track_id] - - return sio.PredictedInstance.from_numpy(points=self.pose, - skeleton = self.skeleton, - point_scores=self.point_scores, - instance_score = self.instance_score, - tracking_score = self.track_score, - track = track), track_lookup + + return ( + sio.PredictedInstance.from_numpy( + points=self.pose, + skeleton=self.skeleton, + point_scores=self.point_scores, + instance_score=self.instance_score, + tracking_score=self.track_score, + track=track, + ), + track_lookup, + ) except Exception as e: print(self.pose.shape, self.point_scores.shape) - raise(e) + raise (e) + @property def device(self) -> str: """The device the instance is on. @@ -343,56 +365,111 @@ def has_features(self) -> bool: return False else: return True - @property + + @property def pose(self) -> ArrayLike: + """Get the pose of the instance. + + Returns: + A nodes x 2 array containing the pose coordinates. + """ return self._pose - + @pose.setter def pose(self, pose: ArrayLike) -> None: + """Set the pose of the instance. + + Args: + pose: A nodes x 2 array containing the pose coordinates. + """ self._pose = pose - + def has_pose(self) -> bool: + """Check if the instance has a pose. + + Returns True if the instance has a pose. + """ if self.pose.shape[0]: return True return False - + @property def shown_pose(self) -> ArrayLike: + """Get the pose with shown nodes only. + + Returns: A shown_nodes x 2 pose containing nonnan values from `pose`. + """ pose = self.pose return pose[~np.isnan(pose).any(axis=1)] @property def skeleton(self) -> sio.Skeleton: + """Get the skeleton associated with the instance. + + Returns: The sio.Skeleton associated with the instance. + """ return self._skeleton - + @skeleton.setter def skeleton(self, skeleton: sio.Skeleton) -> None: + """Set the skeleton associated with the instance. + + Args: + skeleton: The sio.Skeleton associated with the instance. + """ self._skeleton = skeleton - + @property def point_scores(self) -> ArrayLike: + """Get the point scores associated with the pose prediction. + + Returns: a vector of shape n containing the point scores outputed from sleap associated with pose predictions. + """ return self._point_scores - + @point_scores.setter def point_scores(self, point_scores: ArrayLike) -> None: + """Set the point scores associated with the pose prediction. + + Args: + point_scores: a vector of shape n containing the point scores + outputted from sleap associated with pose predictions. + """ self._point_scores = point_scores - + @property def instance_score(self) -> float: + """Get the pose prediction score associated with the instance. + + Returns: a float from 0-1 representing an instance_score. + """ return self._instance_score - + @instance_score.setter def instance_score(self, instance_score: float) -> None: + """Set the pose prediction score associated with the instance. + + Args: + instance_score: a float from 0-1 representing an instance_score. + """ self._instance_score = instance_score - + @property def track_score(self) -> float: + """Get the track_score of the instance. + + Returns: A float from 0-1 representing the output used in the tracker for assignment. 
+ """ return self._track_score - + @track_score.setter def track_score(self, track_score: float) -> None: + """Set the track_score of the instance. + + Args: + track_score: A float from 0-1 representing the output used in the tracker for assignment. + """ self._track_score = track_score - class Frame: @@ -429,11 +506,11 @@ def __init__( """ self._video_id = torch.tensor([video_id]) self._frame_id = torch.tensor([frame_id]) - + try: self._video = sio.Video(vid_file) except ValueError as e: - #warnings.warn(f"{e}") + # warnings.warn(f"{e}") self._video = vid_file if isinstance(img_shape, torch.Tensor): @@ -505,20 +582,31 @@ def to(self, map_location: str): self._device = map_location return self - - def to_slp(self, track_lookup = {}) -> (sio.LabeledFrame, dict[int, sio.Track]): + + def to_slp( + self, track_lookup: dict[int : sio.Track] = {} + ) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]: """Convert Frame to sleap_io.LabeledFrame object. - - Returns: A LabeledFrame object with necessary metadata. + + Args: + track_lookup: A lookup dictionary containing the track_id and sio.Track for persistence + + Returns: A tuple containing a LabeledFrame object with necessary metadata and + a lookup dictionary containing the track_id and sio.Track for persistence """ slp_instances = [] for instance in self.instances: slp_instance, track_lookup = instance.to_slp(track_lookup=track_lookup) slp_instances.append(slp_instance) - return sio.LabeledFrame(video = self.video, - frame_idx = self.frame_id.item(), - instances = slp_instances), track_lookup - + return ( + sio.LabeledFrame( + video=self.video, + frame_idx=self.frame_id.item(), + instances=slp_instances, + ), + track_lookup, + ) + @property def device(self) -> str: """The device the frame is on. @@ -567,7 +655,7 @@ def frame_id(self) -> torch.Tensor: A torch tensor containing the index of the frame in the video. """ return self._frame_id - + @frame_id.setter def frame_id(self, frame_id: int) -> None: """Set the frame index of the frame. @@ -578,19 +666,33 @@ def frame_id(self, frame_id: int) -> None: frame_id: The int index of the frame in the full video. """ self._frame_id = torch.tensor([frame_id]) - + @property def video(self) -> Union[sio.Video, str]: + """Get the video associated with the frame. + + Returns: An sio.Video object representing the video or a placeholder string + if it is not possible to create the sio.Video + """ return self._video - + @video.setter def video(self, video_filename: str) -> None: + """Set the video associated with the frame. + + Note: we try to store the video in an sio.Video object. + However, if this is not possible (e.g. incompatible format or missing filepath) + then we simply store the string. + + Args: + video_filename: string path to video_file + """ try: self._video = video_filename except ValueError as e: - #warnings.warn(f"{e}") + # warnings.warn(f"{e}") self._video = video_filename - + @property def img_shape(self) -> torch.Tensor: """The shape of the pre-cropped frame. @@ -829,7 +931,7 @@ def get_crops(self) -> torch.Tensor: return torch.cat([instance.crop for instance in self.instances], dim=0) except Exception as e: print(self) - raise(e) + raise (e) def has_features(self): """Check if any of frames instances has reid features already computed. 
@@ -850,4 +952,3 @@ def get_features(self): if not self.has_instances(): return torch.tensor([]) return torch.cat([instance.features for instance in self.instances], dim=0) - diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index ef616639..43a4353c 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -194,7 +194,7 @@ def add_frame(self, frame: Frame) -> None: vid_name = frame.video else: vid_name = frame.video.filename - #traj_score = frame.get_traj_score() TODO: figure out better way to save trajectory scores. + # traj_score = frame.get_traj_score() TODO: figure out better way to save trajectory scores. frame_meta = (vid_id, frame_id, vid_name, img_shape.cpu().tolist()) pred_tracks = [] @@ -248,7 +248,11 @@ def collate_tracks( for video_id, frame_id, vid_name, img_shape, instance in instances: if (video_id, frame_id) not in frames.keys(): frame = Frame( - video_id, frame_id, img_shape=img_shape, instances=[instance], vid_file=vid_name + video_id, + frame_id, + img_shape=img_shape, + instances=[instance], + vid_file=vid_name, ) frames[(video_id, frame_id)] = frame else: diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index 34c8f3fb..b703915e 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -1,4 +1,5 @@ """Module containing training, validation and inference logic.""" + import torch import gc from biogtr.inference.tracker import Tracker @@ -170,13 +171,13 @@ def _shared_eval_step(self, instances, mode): instances_mm = metrics.to_track_eval(instances_pred) clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics) return_metrics.update(clearmot.to_dict()) - return_metrics['batch_size'] = len(instances) + return_metrics["batch_size"] = len(instances) except Exception as e: print( f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}" ) raise (e) - + return return_metrics def configure_optimizers(self) -> dict: @@ -223,7 +224,11 @@ def log_metrics(self, result: dict, mode: str) -> None: if isinstance(val, torch.TensorType): val = val.item() self.log(f"{mode}_{metric}", val, batch_size=batch_size) - + def on_validation_epoch_end(self): + """Execute hook for validation end. + + Currently, we simply clear the gpu cache and do garbage collection. + """ gc.collect() torch.cuda.empty_cache() From 4c61fbe224a9e4e59211e7bd8b40c950606bc845 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 13:04:13 -0700 Subject: [PATCH 33/40] fix typing in cell tracking dataset, save pose as dictionary fix tests --- biogtr/data_structures.py | 34 +++++++++++++----------- biogtr/datasets/cell_tracking_dataset.py | 14 ++++++---- biogtr/models/model_utils.py | 24 ++++++++++------- tests/test_datasets.py | 6 ++++- tests/test_training.py | 2 +- 5 files changed, 48 insertions(+), 32 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index a711b738..3c9ca117 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -22,7 +22,7 @@ def __init__( point_scores: ArrayLike = None, instance_score: float = -1.0, skeleton: sio.Skeleton = None, - pose: ArrayLike = None, + pose: dict[str, ArrayLike] = None, device: str = None, ): """Initialize Instance. @@ -37,7 +37,7 @@ def __init__( point_scores: The point scores from sleap. instance_score: The instance scores from sleap. skeleton: The sleap skeleton used for the instance. - pose: The pose matrix for the instance containing nodes x 2 points. 
+ pose: A dictionary containing the node name and corresponding point. device: String representation of the device the instance should be on. """ if gt_track_id is not None: @@ -85,15 +85,17 @@ def __init__( self._pose = pose elif self.bbox.shape[0]: - self._pose = np.array( - [ - (self.bbox[:, -1] + self.bbox[:, 1]) / 2, - (self.bbox[:, -2] + self.bbox[:, 0]) / 2, - ] - ) + self._pose = { + "centroid": np.array( + [ + (self.bbox[:, -1] + self.bbox[:, 1]) / 2, + (self.bbox[:, -2] + self.bbox[:, 0]) / 2, + ] + ) + } else: - self._pose = np.empty((0, 2)) + self._pose = {} self._track_score = track_score self._instance_score = instance_score @@ -367,16 +369,16 @@ def has_features(self) -> bool: return True @property - def pose(self) -> ArrayLike: + def pose(self) -> dict[str, ArrayLike]: """Get the pose of the instance. Returns: - A nodes x 2 array containing the pose coordinates. + A dictionary containing the node and corresponding x,y points """ return self._pose @pose.setter - def pose(self, pose: ArrayLike) -> None: + def pose(self, pose: dict[str, ArrayLike]) -> None: """Set the pose of the instance. Args: @@ -389,18 +391,18 @@ def has_pose(self) -> bool: Returns True if the instance has a pose. """ - if self.pose.shape[0]: + if len(self.pose): return True return False @property - def shown_pose(self) -> ArrayLike: + def shown_pose(self) -> dict[str, ArrayLike]: """Get the pose with shown nodes only. - Returns: A shown_nodes x 2 pose containing nonnan values from `pose`. + Returns: A dictionary filtered by nodes that are shown (points are not nan). """ pose = self.pose - return pose[~np.isnan(pose).any(axis=1)] + return {node: point for node, point in pose.items() if not np.isna(point).any()} @property def skeleton(self) -> sio.Skeleton: diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 8ca0c17b..4b784fd4 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -18,8 +18,8 @@ class CellTrackingDataset(BaseDataset): def __init__( self, - raw_images: list[str], - gt_images: list[str], + raw_images: list[list[str]], + gt_images: list[list[str]], padding: int = 5, crop_size: int = 20, chunk: bool = False, @@ -28,7 +28,7 @@ def __init__( augmentations: Optional[dict] = None, n_chunks: Union[int, float] = 1.0, seed: int = None, - gt_list: str = None, + gt_list: list[str] = None, ): """Initialize CellTrackingDataset. 
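# A minimal sketch of the pose-as-dictionary convention introduced above: node
# names map to (x, y) points, and when no pose is supplied a single "centroid"
# node is derived from the (y1, x1, y2, x2) bounding box. The coordinates below
# are made-up illustration values.
import numpy as np

y1, x1, y2, x2 = 10.0, 20.0, 50.0, 80.0
pose = {
    "centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2]),  # (x, y) box center
    "thorax": np.array([np.nan, np.nan]),  # an undetected node
}
# shown_pose keeps only the nodes whose points are not NaN
shown = {node: point for node, point in pose.items() if not np.isnan(point).any()}
assert list(shown) == ["centroid"]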
@@ -125,7 +125,11 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram """ image = self.videos[label_idx] gt = self.labels[label_idx] - gt_list = self.gt_list[label_idx] + + if self.gt_list is not None: + gt_list = self.gt_list[label_idx] + else: + gt_list = None frames = [] @@ -145,7 +149,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram np.uint8 ) - if self.gt_list is None: + if gt_list is None: unique_instances = np.unique(gt_sec) else: unique_instances = gt_list["track_id"].unique() diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 5413d437..2e32cb22 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -1,5 +1,4 @@ """Module containing model helper functions.""" - from copy import deepcopy from typing import List, Tuple, Iterable from pytorch_lightning import loggers @@ -27,7 +26,7 @@ def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: bbox[:, [1, 3]] /= h boxes.append(bbox) - times.append(torch.full((bbox.shape[0],), fidx)) + times.append(torch.full((bbox.shape[0],), fidx)) #- len(frames) //2 boxes = torch.cat(boxes, dim=0) # N x 4 times = torch.cat(times, dim=0).to(boxes.device) # N @@ -134,18 +133,19 @@ def init_scheduler(optimizer: torch.optim.Optimizer, config: dict): return scheduler_class(optimizer, **scheduler_params) -def init_logger(config: dict): +def init_logger(logger_params: dict, config: dict=None): """Initialize logger based on config parameters. Allows more flexibility in choosing which logger to use. Args: - config: logger hyperparameters + logger_params: logger hyperparameters + config: rest of hyperparameters to log (mostly used for WandB) Returns: logger: A logger with specified params (or None). 
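# A minimal sketch of calling the reworked init_logger above: logger_params
# selects and configures the logger, while the full hyperparameter dict is only
# forwarded (as `config`) for WandB logging. Every key and value here is an
# illustrative assumption, not taken from the biogtr configs.
from biogtr.models.model_utils import init_logger

wandb_params = {
    "logger_type": "WandbLogger",
    "project": "biogtr",  # passed through to pytorch_lightning.loggers.WandbLogger
    "name": "tracking-run",
}
cfg = {"lr": 1e-4, "batch_size": 8}  # stand-in for the rest of the hyperparameters
logger = init_logger(wandb_params, config=cfg)

# For other loggers, config is simply not used:
csv_logger = init_logger({"logger_type": "CSVLogger", "save_dir": "./logs"})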
""" - logger_type = config.pop("logger_type", None) + logger_type = logger_params.pop("logger_type", None) valid_loggers = [ "CSVLogger", @@ -155,10 +155,16 @@ def init_logger(config: dict): if logger_type in valid_loggers: logger_class = getattr(loggers, logger_type) - try: - return logger_class(**config) - except Exception as e: - print(e, logger_type) + if logger_class == loggers.WandbLogger: + try: + return logger_class(config=config, **logger_params) + except Exception as e: + print(e, logger_type) + else: + try: + return logger_class(**logger_params) + except Exception as e: + print(e, logger_type) else: print( f"{logger_type} not one of {valid_loggers} or set to None, skipping logging" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 0ea6521d..97882a80 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -194,13 +194,17 @@ def test_cell_tracking_dataset(cell_tracking): clip_length = 8 + # print(cell_tracking[0]) + # print(cell_tracking[1]) + # print(cell_tracking[2]) + train_ds = CellTrackingDataset( raw_images=[cell_tracking[0]], gt_images=[cell_tracking[1]], crop_size=128, chunk=True, clip_length=clip_length, - gt_list=cell_tracking[2], + gt_list=[cell_tracking[2]], ) instances = next(iter(train_ds)) diff --git a/tests/test_training.py b/tests/test_training.py index c04b1eca..9120af48 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -102,7 +102,7 @@ def test_basic_gtr_runner(): gtr_runner.train() assert gtr_runner.model.training metrics = gtr_runner.training_step([batch], i) - assert "loss" in metrics and "num_switches" in metrics + assert "loss" in metrics assert metrics["loss"].requires_grad for j, batch in enumerate(train_ds): From 52d9ae8f3dd68574036f43a14a3c68b145c18531 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 13:19:24 -0700 Subject: [PATCH 34/40] lint --- biogtr/models/model_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 2e32cb22..adf1a662 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -1,4 +1,5 @@ """Module containing model helper functions.""" + from copy import deepcopy from typing import List, Tuple, Iterable from pytorch_lightning import loggers @@ -26,7 +27,7 @@ def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: bbox[:, [1, 3]] /= h boxes.append(bbox) - times.append(torch.full((bbox.shape[0],), fidx)) #- len(frames) //2 + times.append(torch.full((bbox.shape[0],), fidx)) boxes = torch.cat(boxes, dim=0) # N x 4 times = torch.cat(times, dim=0).to(boxes.device) # N @@ -133,7 +134,7 @@ def init_scheduler(optimizer: torch.optim.Optimizer, config: dict): return scheduler_class(optimizer, **scheduler_params) -def init_logger(logger_params: dict, config: dict=None): +def init_logger(logger_params: dict, config: dict = None): """Initialize logger based on config parameters. Allows more flexibility in choosing which logger to use. 
From 8a560c89006d05128619ef56c592101932448479 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 13:46:29 -0700 Subject: [PATCH 35/40] add open-cv-headless to environment.yml --- biogtr/data_structures.py | 16 +++++++++++++++- environment.yml | 5 +++-- environment_cpu.yml | 1 + 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index 3c9ca117..edcd7a7e 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -384,7 +384,21 @@ def pose(self, pose: dict[str, ArrayLike]) -> None: Args: pose: A nodes x 2 array containing the pose coordinates. """ - self._pose = pose + if pose is not None: + self._pose = pose + + elif self.bbox.shape[0]: + self._pose = { + "centroid": np.array( + [ + (self.bbox[:, -1] + self.bbox[:, 1]) / 2, + (self.bbox[:, -2] + self.bbox[:, 0]) / 2, + ] + ) + } + + else: + self._pose = {} def has_pose(self) -> bool: """Check if the instance has a pose. diff --git a/environment.yml b/environment.yml index 85696a79..310e66b6 100644 --- a/environment.yml +++ b/environment.yml @@ -8,8 +8,8 @@ channels: dependencies: - python=3.9 - - pytorch-cuda=11.8 - - cudatoolkit=11.8 + - pytorch-cuda=12.1 + #- cudatoolkit=12.1 - cudnn - pytorch - torchvision @@ -20,6 +20,7 @@ dependencies: - albumentations - pip - pip: + - opencv-python-headless - matplotlib - sleap-io - "--editable=.[dev]" diff --git a/environment_cpu.yml b/environment_cpu.yml index 1b2a6841..2eaf73f7 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -18,6 +18,7 @@ dependencies: - albumentations - pip - pip: + - opencv-python-headless - matplotlib - sleap-io - "--editable=.[dev]" From 96048aa0fc3c8ec9918cc0dcd4bc7a9f4391c93f Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 14:13:56 -0700 Subject: [PATCH 36/40] use conda-forge opencv instead of opencv-headless --- biogtr/data_structures.py | 21 +++++---------------- environment.yml | 3 +-- environment_cpu.yml | 2 +- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index edcd7a7e..3d241f81 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -85,14 +85,9 @@ def __init__( self._pose = pose elif self.bbox.shape[0]: - self._pose = { - "centroid": np.array( - [ - (self.bbox[:, -1] + self.bbox[:, 1]) / 2, - (self.bbox[:, -2] + self.bbox[:, 0]) / 2, - ] - ) - } + + y1, x1, y2, x2 = self.bbox.squeeze() + self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} else: self._pose = {} @@ -388,14 +383,8 @@ def pose(self, pose: dict[str, ArrayLike]) -> None: self._pose = pose elif self.bbox.shape[0]: - self._pose = { - "centroid": np.array( - [ - (self.bbox[:, -1] + self.bbox[:, 1]) / 2, - (self.bbox[:, -2] + self.bbox[:, 0]) / 2, - ] - ) - } + y1, x1, y2, x2 = self.bbox.squeeze() + self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} else: self._pose = {} diff --git a/environment.yml b/environment.yml index 310e66b6..3637a7e2 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ channels: dependencies: - python=3.9 - pytorch-cuda=12.1 - #- cudatoolkit=12.1 + - conda-forge::opencv <4.9.0 - cudnn - pytorch - torchvision @@ -20,7 +20,6 @@ dependencies: - albumentations - pip - pip: - - opencv-python-headless - matplotlib - sleap-io - "--editable=.[dev]" diff --git a/environment_cpu.yml b/environment_cpu.yml index 2eaf73f7..8e22da37 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -8,6 +8,7 @@ channels: dependencies: - python=3.9 + - 
conda-forge::opencv <4.9.0 - pytorch - cpuonly - torchvision @@ -18,7 +19,6 @@ dependencies: - albumentations - pip - pip: - - opencv-python-headless - matplotlib - sleap-io - "--editable=.[dev]" From d7d2d095c2b315f66b9f58075585dcf6e8017713 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Fri, 19 Apr 2024 15:41:02 -0700 Subject: [PATCH 37/40] * commit suggestions from * @coderabbit's reviews check box coords to make sure x1x2 or y1y2 are not negative save embeddings to model for gradient update --- biogtr/data_structures.py | 15 +++++++-------- biogtr/datasets/data_utils.py | 14 ++++++-------- biogtr/inference/tracker.py | 1 - biogtr/models/embedding.py | 15 ++++++++++++--- biogtr/models/gtr_runner.py | 2 +- tests/test_models.py | 3 ++- 6 files changed, 28 insertions(+), 22 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index 3d241f81..a6ce32da 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -3,7 +3,6 @@ import torch import sleap_io as sio import numpy as np -import warnings from numpy.typing import ArrayLike from typing import Union, List @@ -22,7 +21,7 @@ def __init__( point_scores: ArrayLike = None, instance_score: float = -1.0, skeleton: sio.Skeleton = None, - pose: dict[str, ArrayLike] = None, + pose: dict[str, ArrayLike] = np.array([]), device: str = None, ): """Initialize Instance. @@ -162,8 +161,10 @@ def to_slp( track_lookup, ) except Exception as e: - print(self.pose.shape, self.point_scores.shape) - raise (e) + print( + f"Pose shape: {self.pose.shape}, Pose score shape {self.point_scores.shape}" + ) + raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") @property def device(self) -> str: @@ -514,8 +515,7 @@ def __init__( try: self._video = sio.Video(vid_file) - except ValueError as e: - # warnings.warn(f"{e}") + except ValueError: self._video = vid_file if isinstance(img_shape, torch.Tensor): @@ -694,8 +694,7 @@ def video(self, video_filename: str) -> None: """ try: self._video = video_filename - except ValueError as e: - # warnings.warn(f"{e}") + except ValueError: self._video = video_filename @property diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py index d482fc23..b900caf1 100644 --- a/biogtr/datasets/data_utils.py +++ b/biogtr/datasets/data_utils.py @@ -67,9 +67,11 @@ def get_bbox(center: ArrayLike, size: Union[int, tuple[int]]) -> torch.Tensor: size = (size, size) cx, cy = center[0], center[1] - bbox = torch.Tensor( - [-size[-1] // 2 + cy, -size[0] // 2 + cx, size[-1] // 2 + cy, size[0] // 2 + cx] - ) + y1 = max(0, -size[-1] // 2 + cy) + x1 = max(0, -size[0] // 2 + cx) + y2 = size[-1] // 2 + cy if y1 != 0 else size[1] + x2 = size[0] // 2 + cx if x1 != 0 else size[0] + bbox = torch.Tensor([y1, x1, y2, x2]) return bbox @@ -476,11 +478,7 @@ def view_training_batch( else (axes[i] if num_crops == 1 else axes[i, j]) ) - ( - ax.imshow(data.T) - if isinstance(cmap, None) - else ax.imshow(data.T, cmap=cmap) - ) + (ax.imshow(data.T) if cmap is None else ax.imshow(data.T, cmap=cmap)) ax.axis("off") except Exception as e: diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 9ff7f64e..44674aad 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -10,7 +10,6 @@ from biogtr.inference import post_processing from biogtr.inference.boxes import Boxes from scipy.optimize import linear_sum_assignment -from copy import deepcopy from math import inf diff --git a/biogtr/models/embedding.py b/biogtr/models/embedding.py index 819ec533..364d4c8f 100644 
--- a/biogtr/models/embedding.py +++ b/biogtr/models/embedding.py @@ -17,7 +17,8 @@ def __init__(self): """Initialize embeddings.""" super().__init__() # empty init for flexibility - pass + self.pos_lookup = None + self.temp_lookup = None def _torch_int_div( self, tensor1: torch.Tensor, tensor2: torch.Tensor @@ -126,7 +127,11 @@ def _learned_pos_embedding( self.learn_pos_emb_num = params["learn_pos_emb_num"] self.over_boxes = params["over_boxes"] - pos_lookup = torch.nn.Embedding(self.learn_pos_emb_num * 4, self.features // 4) + if self.pos_lookup is None: + self.pos_lookup = torch.nn.Embedding( + self.learn_pos_emb_num * 4, self.features // 4 + ) + pos_lookup = self.pos_lookup N = boxes.shape[0] boxes = boxes.view(N, 4) @@ -187,8 +192,12 @@ def _learned_temp_embedding( self.features = params["features"] self.learn_temp_emb_num = params["learn_temp_emb_num"] - temp_lookup = torch.nn.Embedding(self.learn_temp_emb_num, self.features) + if self.temp_lookup is None: + self.temp_lookup = torch.nn.Embedding( + self.learn_temp_emb_num, self.features + ) + temp_lookup = self.temp_lookup N = times.shape[0] l, r, lw, rw = self._compute_weights(times, self.learn_temp_emb_num) diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index b703915e..b0b75301 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -221,7 +221,7 @@ def log_metrics(self, result: dict, mode: str) -> None: if result: batch_size = result.pop("batch_size") for metric, val in result.items(): - if isinstance(val, torch.TensorType): + if isinstance(val, torch.Tensor): val = val.item() self.log(f"{mode}_{metric}", val, batch_size=batch_size) diff --git a/tests/test_models.py b/tests/test_models.py index 6d1abd3e..ceae0bc5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,6 @@ import pytest import torch -import numpy as np from biogtr.data_structures import Frame, Instance from biogtr.models.attention_head import MLP, ATTWeightHead from biogtr.models.embedding import Embedding @@ -124,6 +123,7 @@ def test_embedding_kwargs(): lp_args = {"learn_pos_emb_num": 100, "over_boxes": False} + emb = Embedding() lp_with_args = emb._learned_pos_embedding(boxes, **lp_args) assert not torch.equal(lp_no_args, lp_with_args) @@ -134,6 +134,7 @@ def test_embedding_kwargs(): lt_args = {"learn_temp_emb_num": 100} + emb = Embedding() lt_with_args = emb._learned_temp_embedding(times, **lt_args) assert not torch.equal(lt_no_args, lt_with_args) From 2eef70f300b12aab5b533cd94eba4881dc844063 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 22 Apr 2024 11:21:17 -0700 Subject: [PATCH 38/40] add @coderabbit's suggestions --- biogtr/datasets/data_utils.py | 11 +++++++---- biogtr/inference/metrics.py | 12 ++++++------ biogtr/inference/post_processing.py | 4 ++-- biogtr/models/model_utils.py | 3 ++- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py index 76b44972..027937e3 100644 --- a/biogtr/datasets/data_utils.py +++ b/biogtr/datasets/data_utils.py @@ -1,4 +1,5 @@ """Module containing helper functions for datasets.""" + from PIL import Image from numpy.typing import ArrayLike from torchvision.transforms import functional as tvf @@ -62,7 +63,7 @@ def get_bbox(center: ArrayLike, size: Union[int, tuple[int]]) -> torch.Tensor: Returns: A torch tensor in form y1, x1, y2, x2 """ - if type(size) == int: + if isinstance(size, int): size = (size, size) cx, cy = center[0], center[1] @@ -116,7 +117,7 @@ def pose_bbox(points: np.ndarray, 
bbox_size: Union[tuple[int], int]) -> torch.Te Returns: Bounding box in [y1, x1, y2, x2] format. """ - if type(bbox_size) == int: + if isinstance(bbox_size, int): bbox_size = (bbox_size, bbox_size) # print(points) minx = np.nanmin(points[:, 0], axis=-1) @@ -475,8 +476,10 @@ def view_training_batch( else (axes[i] if num_crops == 1 else axes[i, j]) ) - ax.imshow(data.T) if isinstance(cmap, None) else ax.imshow( - data.T, cmap=cmap + ( + ax.imshow(data.T) + if isinstance(cmap, None) + else ax.imshow(data.T, cmap=cmap) ) ax.axis("off") diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index a0ccebfe..00c35940 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -1,9 +1,9 @@ """Helper functions for calculating mot metrics.""" + import numpy as np import motmetrics as mm import torch from biogtr.data_structures import Frame -import warnings from typing import Union, Iterable # from biogtr.inference.post_processing import _pairwise_iou @@ -14,7 +14,7 @@ def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: """Get comparison between predicted and gt trajectory labels. Args: - instances: a list of Frames containing the video_id, frame_id, + frames: a list of Frames containing the video_id, frame_id, gt labels and predicted labels Returns: @@ -67,9 +67,9 @@ def get_switches(matches: dict, indices: list) -> dict: switches[idx] = {} col = matches[:, i] - indices = np.where(col == 1)[0] + match_indices = np.where(col == 1)[0] match_i = [ - (m.split(" ")[0], m.split(" ")[-1]) for m in matches_key[indices] + (m.split(" ")[0], m.split(" ")[-1]) for m in matches_key[match_indices] ] for m in match_i: @@ -105,7 +105,7 @@ def to_track_eval(frames: list[Frame]) -> dict: """Reformats frames the output from `sliding_inference` to be used by `TrackEval`. Args: - instances: A list of Frames. `See biogtr.data_structures for more info`. + frames: A list of Frames. `See biogtr.data_structures for more info`. Returns: data: A dictionary. Example provided below. @@ -123,7 +123,7 @@ def to_track_eval(frames: list[Frame]) -> dict: "gt_ids": (L, *), # Ragged np.array "tracker_ids": (L, ^), # Ragged np.array "similarity_scores": (L, *, ^), # Ragged np.array - "num_timsteps": L, + "num_timesteps": L, } """ unique_gt_ids = [] diff --git a/biogtr/inference/post_processing.py b/biogtr/inference/post_processing.py index f15d28eb..26837150 100644 --- a/biogtr/inference/post_processing.py +++ b/biogtr/inference/post_processing.py @@ -1,7 +1,7 @@ """Helper functions for post-processing association matrix pre-tracking.""" + import torch from biogtr.inference.boxes import Boxes -from copy import deepcopy def weight_decay_time( @@ -156,7 +156,7 @@ def filter_max_center_dist( .long() .bool() ) # n_k x M - asso_output_filtered = deepcopy(asso_output) + asso_output_filtered = asso_output.clone() asso_output_filtered[~valid_assn] = 0 # n_k x M return asso_output_filtered else: diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 95eb68cc..ddb95c7d 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -1,4 +1,5 @@ """Module containing model helper functions.""" + from copy import deepcopy from typing import List, Tuple, Iterable from pytorch_lightning import loggers @@ -10,7 +11,7 @@ def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: """Extract the bounding boxes and frame indices from the input list of instances. 
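# A minimal, self-contained sketch of why the Embedding changes above keep
# pos_lookup / temp_lookup on the module ("save embeddings to model for
# gradient update"): an nn.Embedding rebuilt inside every call is discarded
# immediately, so its weights can never be trained, whereas a single lazily
# created instance stored as an attribute is registered as a submodule and
# receives gradients. The class below is a toy stand-in, not biogtr's Embedding.
import torch


class ToyTemporalEmbedding(torch.nn.Module):
    def __init__(self, num_times: int = 16, features: int = 8):
        super().__init__()
        self.temp_lookup = None  # created on first use, then reused every call
        self.num_times = num_times
        self.features = features

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        if self.temp_lookup is None:
            self.temp_lookup = torch.nn.Embedding(self.num_times, self.features)
        return self.temp_lookup(times)


emb = ToyTemporalEmbedding()
out = emb(torch.tensor([0, 1, 2]))
out.sum().backward()
assert emb.temp_lookup.weight.grad is not None  # the lookup weights now get gradients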
Args: - instances (List[Dict]): List of instance dictionaries + frames (List[Frame]): List of frame objects containing metadata and instances. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors containing the From c351008e74a5b27a239e31ac665ba9548d46f870 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 22 Apr 2024 11:34:44 -0700 Subject: [PATCH 39/40] lint --- biogtr/data_structures.py | 1 + biogtr/datasets/base_dataset.py | 1 + biogtr/datasets/cell_tracking_dataset.py | 1 + biogtr/datasets/eval_dataset.py | 1 + biogtr/datasets/microscopy_dataset.py | 9 ++++++--- biogtr/datasets/sleap_dataset.py | 1 + biogtr/datasets/tracking_dataset.py | 9 ++++++--- biogtr/inference/boxes.py | 1 + biogtr/inference/tracker.py | 1 + biogtr/models/global_tracking_transformer.py | 1 + biogtr/training/losses.py | 1 + biogtr/training/train.py | 1 + biogtr/visualize.py | 1 + tests/conftest.py | 1 + tests/fixtures/configs.py | 1 + tests/fixtures/datasets.py | 1 + tests/test_data_structures.py | 1 + tests/test_datasets.py | 1 + tests/test_inference.py | 1 + tests/test_models.py | 1 + tests/test_training.py | 1 + tests/test_version.py | 1 + 22 files changed, 32 insertions(+), 6 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index d0efac4e..813f8867 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -1,4 +1,5 @@ """Module containing data classes such as Instances and Frames.""" + import torch from numpy.typing import ArrayLike from typing import Union, List diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index 61b3e00b..585e176d 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -1,4 +1,5 @@ """Module containing logic for loading datasets.""" + from biogtr.datasets import data_utils from biogtr.data_structures import Frame from torch.utils.data import Dataset diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 6a421e13..a4d0d056 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -1,4 +1,5 @@ """Module containing cell tracking challenge dataset.""" + from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index f3142b16..026d5641 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -1,4 +1,5 @@ """Module containing wrapper for merging gt and pred datasets for evaluation.""" + from torch.utils.data import Dataset from biogtr.data_structures import Frame, Instance from typing import List diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 05d8aa73..0b9f3792 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -1,4 +1,5 @@ """Module containing microscopy dataset.""" + from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset @@ -93,9 +94,11 @@ def __init__( ] self.frame_idx = [ - torch.arange(Image.open(video).n_frames) - if isinstance(video, str) - else torch.arange(len(video)) + ( + torch.arange(Image.open(video).n_frames) + if isinstance(video, str) + else torch.arange(len(video)) + ) for video in self.videos ] diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 768e8213..ed814804 100644 --- a/biogtr/datasets/sleap_dataset.py +++ 
b/biogtr/datasets/sleap_dataset.py @@ -1,4 +1,5 @@ """Module containing logic for loading sleap datasets.""" + import albumentations as A import torch import imageio diff --git a/biogtr/datasets/tracking_dataset.py b/biogtr/datasets/tracking_dataset.py index b80cd636..fdc54cac 100644 --- a/biogtr/datasets/tracking_dataset.py +++ b/biogtr/datasets/tracking_dataset.py @@ -1,4 +1,5 @@ """Module containing Lightning module wrapper around all other datasets.""" + from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset from biogtr.datasets.microscopy_dataset import MicroscopyDataset from biogtr.datasets.sleap_dataset import SleapDataset @@ -74,9 +75,11 @@ def train_dataloader(self) -> DataLoader: pin_memory=False, collate_fn=self.train_ds.no_batching_fn, num_workers=0, - generator=torch.Generator(device="cuda") - if torch.cuda.is_available() - else torch.Generator(), + generator=( + torch.Generator(device="cuda") + if torch.cuda.is_available() + else torch.Generator() + ), ) else: return self.train_dl diff --git a/biogtr/inference/boxes.py b/biogtr/inference/boxes.py index e6ed794f..ec123b18 100644 --- a/biogtr/inference/boxes.py +++ b/biogtr/inference/boxes.py @@ -1,4 +1,5 @@ """Module containing Boxes class.""" + from typing import List, Tuple import torch diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 506bbc48..a389eb1e 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -1,4 +1,5 @@ """Module containing logic for going from association -> assignment.""" + import torch import pandas as pd import warnings diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index ebfd5e5d..0743ce43 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -1,4 +1,5 @@ """Module containing GTR model used for training.""" + from biogtr.models.transformer import Transformer from biogtr.models.visual_encoder import VisualEncoder from biogtr.data_structures import Frame diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index e4d725d9..557b78e3 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -1,4 +1,5 @@ """Module containing different loss functions to be optimized.""" + from biogtr.data_structures import Frame from biogtr.models.model_utils import get_boxes_times from torch import nn diff --git a/biogtr/training/train.py b/biogtr/training/train.py index 65169d27..049cacf1 100644 --- a/biogtr/training/train.py +++ b/biogtr/training/train.py @@ -2,6 +2,7 @@ Used for training a single model or deploying a batch train job on RUNAI CLI """ + from biogtr.config import Config from biogtr.datasets.tracking_dataset import TrackingDataset from biogtr.datasets.data_utils import view_training_batch diff --git a/biogtr/visualize.py b/biogtr/visualize.py index 6d404d44..bf7fa5d6 100644 --- a/biogtr/visualize.py +++ b/biogtr/visualize.py @@ -1,4 +1,5 @@ """Helper functions for visualizing tracking.""" + from scipy.interpolate import interp1d from copy import deepcopy from tqdm import tqdm diff --git a/tests/conftest.py b/tests/conftest.py index 434bb560..bf6e6498 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Config for pytests.""" + from tests.fixtures.configs import * from tests.fixtures.datasets import * from tests.fixtures.torch import * diff --git a/tests/fixtures/configs.py b/tests/fixtures/configs.py index 8d172d2a..3cf06840 100644 --- a/tests/fixtures/configs.py +++ 
b/tests/fixtures/configs.py @@ -1,4 +1,5 @@ """Test config paths.""" + import os import pytest diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 572aa094..db574099 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,4 +1,5 @@ """Fixtures for testing biogtr.""" + import pytest from pathlib import Path diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py index d44bf9db..31f249b4 100644 --- a/tests/test_data_structures.py +++ b/tests/test_data_structures.py @@ -1,4 +1,5 @@ """Tests for Instance, Frame, and TrackQueue Object""" + from biogtr.data_structures import Instance, Frame from biogtr.inference.track_queue import TrackQueue import torch diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c9950681..0ea6521d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,4 +1,5 @@ """Test dataset logic.""" + from biogtr.datasets.base_dataset import BaseDataset from biogtr.datasets.data_utils import get_max_padding from biogtr.datasets.microscopy_dataset import MicroscopyDataset diff --git a/tests/test_inference.py b/tests/test_inference.py index 30d4f99b..a38a5c96 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,4 +1,5 @@ """Test inference logic.""" + import torch import pytest import numpy as np diff --git a/tests/test_models.py b/tests/test_models.py index ea62a835..6d1abd3e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ """Test model modules.""" + import pytest import torch import numpy as np diff --git a/tests/test_training.py b/tests/test_training.py index f12583fa..c04b1eca 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,4 +1,5 @@ """Test training logic.""" + import os import pytest import torch diff --git a/tests/test_version.py b/tests/test_version.py index 3f9e7e0e..6bde7e48 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,4 +1,5 @@ """Test version.""" + import biogtr From 4fed899306fca6a883e613e4662d1711cfc258d8 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 22 Apr 2024 11:36:49 -0700 Subject: [PATCH 40/40] use batch_size when logging metrics to prevent lightning "cant infer batch_size" error --- biogtr/models/gtr_runner.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index e5038f5e..7b83c8bd 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -1,4 +1,5 @@ """Module containing training, validation and inference logic.""" + import torch from biogtr.inference.tracker import Tracker from biogtr.inference import metrics @@ -85,7 +86,7 @@ def training_step( A dict containing the train loss plus any other metrics specified """ result = self._shared_eval_step(train_batch[0], mode="train") - self.log_metrics(result, "train") + self.log_metrics(result, len(train_batch[0]), "train") return result @@ -103,7 +104,7 @@ def validation_step( A dict containing the val loss plus any other metrics specified """ result = self._shared_eval_step(val_batch[0], mode="val") - self.log_metrics(result, "val") + self.log_metrics(result, len(val_batch[0]), "val") return result @@ -119,7 +120,7 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]: A dict containing the val loss plus any other metrics specified """ result = self._shared_eval_step(test_batch[0], mode="test") - self.log_metrics(result, "test") + self.log_metrics(result, len(test_batch[0]), "test") 
return result @@ -206,13 +207,20 @@ def configure_optimizers(self) -> dict: }, } - def log_metrics(self, result: dict, mode: str) -> None: + def log_metrics(self, result: dict, batch_size: int, mode: str) -> None: """Log metrics computed during evaluation. Args: result: A dict containing metrics to be logged. + batch_size: the size of the batch used to compute the metrics mode: One of {'train', 'test' or 'val'}. Used as prefix while logging. """ if result: for metric, val in result.items(): - self.log(f"{mode}_{metric}", val, on_step=True, on_epoch=True) + self.log( + f"{mode}_{metric}", + val, + batch_size=batch_size, + on_step=True, + on_epoch=True, + )
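# A minimal, self-contained sketch of how the batch_size threading in this last
# patch fits together: each step hook measures its batch and passes the size to
# log_metrics, which forwards it to self.log so Lightning does not have to
# infer a batch size from plain scalar values. ToyRunner is an illustrative
# stand-in for GTRRunner, not part of the patch.
import torch
from pytorch_lightning import LightningModule


class ToyRunner(LightningModule):
    def training_step(self, train_batch, batch_idx):
        result = {"loss": torch.tensor(1.0, requires_grad=True), "sw_cnt": 3.0}
        self.log_metrics(result, len(train_batch[0]), "train")
        return result

    def log_metrics(self, result: dict, batch_size: int, mode: str) -> None:
        for metric, val in result.items():
            if isinstance(val, torch.Tensor):
                val = val.item()
            self.log(
                f"{mode}_{metric}",
                val,
                batch_size=batch_size,
                on_step=True,
                on_epoch=True,
            )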