Aadi/track local queues #20
Conversation
fix bug where tracks dont persist across chunks add verbosity for debugging(temp)
update base train yaml
fix some bugs with data structures
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@              Coverage Diff               @@
##   aadi-sleap-anchors    #20    +/-   ##
=====================================================
  Coverage     ?    68.24%
=====================================================
  Files        ?    28
  Lines        ?    2148
  Branches     ?    0
=====================================================
  Hits         ?    1466
  Misses       ?    682
  Partials     ?    0
Important
Auto Review Skipped
Auto reviews are disabled on base/target branches other than the default branch. Please add the base/target branch pattern to the list of additional branches to be reviewed in the settings.
Please check the settings in the CodeRabbit UI or the
You can disable this status message by setting the
Walkthrough
The changes across the codebase primarily involve the introduction and integration of
Changes
@coderabbitai review
Review Status
Actionable comments generated: 27
Configuration used: CodeRabbit UI
Files ignored due to filter (2)
- biogtr/training/configs/base.yaml
- tests/configs/base.yaml
Files selected for processing (33)
- .gitignore (1 hunks)
- biogtr/config.py (5 hunks)
- biogtr/data_structures.py (1 hunks)
- biogtr/datasets/base_dataset.py (6 hunks)
- biogtr/datasets/cell_tracking_dataset.py (5 hunks)
- biogtr/datasets/data_utils.py (6 hunks)
- biogtr/datasets/eval_dataset.py (1 hunks)
- biogtr/datasets/microscopy_dataset.py (5 hunks)
- biogtr/datasets/sleap_dataset.py (9 hunks)
- biogtr/datasets/tracking_dataset.py (2 hunks)
- biogtr/inference/__init__.py (1 hunks)
- biogtr/inference/boxes.py (2 hunks)
- biogtr/inference/metrics.py (9 hunks)
- biogtr/inference/post_processing.py (1 hunks)
- biogtr/inference/track.py (6 hunks)
- biogtr/inference/track_queue.py (1 hunks)
- biogtr/inference/tracker.py (6 hunks)
- biogtr/models/attention_head.py (2 hunks)
- biogtr/models/embedding.py (6 hunks)
- biogtr/models/global_tracking_transformer.py (2 hunks)
- biogtr/models/gtr_runner.py (11 hunks)
- biogtr/models/model_utils.py (3 hunks)
- biogtr/models/transformer.py (9 hunks)
- biogtr/training/losses.py (2 hunks)
- biogtr/training/train.py (3 hunks)
- biogtr/visualize.py (5 hunks)
- tests/fixtures/configs.py (1 hunks)
- tests/fixtures/torch.py (1 hunks)
- tests/test_data_structures.py (1 hunks)
- tests/test_datasets.py (12 hunks)
- tests/test_inference.py (4 hunks)
- tests/test_models.py (6 hunks)
- tests/test_training.py (4 hunks)
Files skipped from review due to trivial changes (8)
- .gitignore
- biogtr/datasets/tracking_dataset.py
- biogtr/inference/__init__.py
- biogtr/inference/track_queue.py
- biogtr/models/attention_head.py
- biogtr/models/embedding.py
- tests/fixtures/configs.py
- tests/fixtures/torch.py
Additional comments: 81
biogtr/config.py (5)
- 45-47: The update to the `__str__` method's docstring is a minor change that improves clarity.
- 93-99: The variable assignments within the `get_gtr_runner` method appear unchanged, contrary to the summary which mentions reformatting.
- 191-197: The addition of a trailing comma after `**dataloader_params` in the `get_dataloader` method aligns with PEP 8's recommendation for multi-line collections.
- 283-287: The addition of a trailing comma after the `accelerator` parameter in the `get_trainer` method aligns with PEP 8's recommendation for function signatures.
- 298-304: The `get_trainer` method has been updated with logic to set `accelerator` and `devices` in the trainer configuration if they are not present, which is not mentioned in the summary.
biogtr/datasets/cell_tracking_dataset.py (4)
- 2-12: The import statements have been correctly updated to include `Instance` and `Frame` from `biogtr.data_structures`, which is in line with the PR objectives and summary.
- 27-31: The `__init__` method of `CellTrackingDataset` has been updated with new parameters and the super call has been updated accordingly, which is consistent with the PR objectives and summary.
- 180-204: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [111-204]
  The `get_instances` method has been updated to return a list of `Frame` objects, and the logic within the method has been updated to create and return `Frame` objects. This change is consistent with the PR objectives and summary.
- 111-111: The summary states that the input type of the `get_instances` method has been changed from a list of integers to a list of `Frame` objects. However, the code indicates that the input type remains a list of integers (`label_idx: List[int], frame_idx: List[int]`). This discrepancy should be clarified.
biogtr/datasets/data_utils.py (4)
- 67-73: The summary mentions a modification in the calculation of bounding box coordinates, but the provided hunk shows no changes in the logic of the `get_bbox` function. The code within the function is identical to the previous version, and the function signature remains unchanged.
- 119-140: The summary indicates a revision in the calculation of the bounding box around an instance pose in the `pose_bbox` function, but the provided hunk shows no changes in the logic. The code within the function is identical to the previous version, and the function signature remains unchanged.
- 213-218: The summary mentions a correction of a variable name from `xml_path` to `data_path` in the `parse_trackmate` function, but the provided hunk shows that the variable `data_path` is already being used consistently. There is no indication of a previous variable named `xml_path`.
- 447-453: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [447-481]
  The summary mentions a typo fix in the documentation string of the `view_training_batch` function, but the provided hunk does not show any changes in the documentation string. However, the `if` condition within the function has been corrected to properly check for the type of `cmap`.
- ax.imshow(data.T) if isinstance(cmap, None) else ax.imshow(data.T, cmap=cmap)
+ ax.imshow(data.T, cmap=cmap) if cmap is not None else ax.imshow(data.T)
biogtr/datasets/microscopy_dataset.py (6)
- 2-8: The import changes are consistent with the PR objectives and the AI-generated summaries, which aim to improve tracking capabilities and data handling through the introduction of new data structures.
- 25-28: The method signature changes in the `__init__` method are consistent with the PR objectives and the AI-generated summaries.
- 94-100: The logic to set `self.frame_idx` using `torch.arange` and `Image.open(video).n_frames` or `len(video)` is correct and supports different types of video input.
- 104-112: The `get_indices` method is correctly implemented and does not present any issues.
- 170-199: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [114-199]
  The `get_instances` method has been refactored to use the new `Frame` and `Instance` data structures, which aligns with the PR objectives and the AI-generated summaries.
- 25-28: Please confirm if the commented-out line for setting the random seed (`np.random.seed(self.seed)`) is intentionally left in the code. If it's not needed, it should be removed to avoid confusion.
biogtr/datasets/sleap_dataset.py (6)
- 5-12: The addition of imports for `warnings` and the `Frame` and `Instance` classes from `biogtr.data_structures` aligns with the PR's objective to enhance data structure usage for better tracking and data manipulation.
- 29-33: The addition of the `verbose` argument to the `__init__` method is consistent with the PR's objective to enhance tracking capabilities and provide more control over logging and output.
- 205-270: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [140-270]
  The modifications to the data processing logic, including the handling of frames and instances using the new `Frame` and `Instance` data structures, align with the PR's objectives to improve object tracking across gaps in individual instances.
- 102-106: The `get_indices` method remains unchanged, which is consistent with the summary and indicates that the PR's changes did not affect this part of the code.
- 205-270: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [140-270]
  The changes to the `get_instances` method, including the use of the new `Frame` and `Instance` data structures, are significant and align with the PR's objectives to improve the tracking system's data handling capabilities.
- 205-270: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [167-270]
  The changes to the augmentation logic, including the handling of keypoints and the application of augmentations to the images and poses, are consistent with the PR's objectives to improve data handling and tracking capabilities.
biogtr/inference/boxes.py (3)
- 1-3: The update to the import statement is correct as `Union` is not used in the file.
- 58-59: The docstring correction in the `area` method is a minor grammatical improvement.
- 56-62: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [5-62]
  No further issues or changes are observed in the provided hunks.
biogtr/inference/track.py (3)
- 21-47: The changes to the `export_trajectories` function reflect the updated usage of `Frame` objects instead of dictionaries, aligning with the PR objectives and the summary provided. The logic within the function has been correctly adjusted to handle the new data structures.
- 84-98: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [55-95]
  The `inference` function has been updated to handle `Frame` objects, which is consistent with the PR objectives and the summary. The logic within the function has been adapted to work with the new data structures, and the changes are correctly implemented.
- 107-112: The `main` function's logic has been updated to work with the new `Frame` objects and the updated `inference` function. The changes are consistent with the PR objectives and the summary, and the function appears to handle the new data structures correctly.
biogtr/inference/tracker.py (6)
- 24-105: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [15-56]
  The addition of `max_gap` and `verbose` parameters to the `Tracker` class's `__init__` method is correctly implemented and aligns with the PR objectives to enhance tracking capabilities.
- 58-58: The `__call__` method has been correctly updated to accept a list of `Frame` objects.
- 84-102: The logic within the `track` method has been updated to handle instances of the `Frame` class and their associated attributes.
- 121-187: The `sliding_inference` method has been updated to handle instances of the `Frame` class and their associated attributes.
- 215-328: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [189-392]
  The `_run_global_tracker` method has been updated to handle instances of the `Frame` class and their associated attributes.
- 1-12: The additional imports for `warnings`, `Frame`, and `TrackQueue` are correctly implemented.
biogtr/models/global_tracking_transformer.py (4)
- 1-7: The import of `Frame` from `biogtr.data_structures` is correctly added to support the new data structure usage in the `forward` method.
- 101-123: The `forward` method has been correctly updated to accept a list of `Frame` objects, and the logic within the method has been adapted to work with these new data structures.
- 111-123: The logic to extract features from frames only if they do not already have them is a good use of the `Frame` object's methods and ensures that features are not recomputed unnecessarily.
- 121-123: The return statement correctly provides the association predictions and embeddings, which is consistent with the method's documentation and the expected output of the `forward` method.
biogtr/models/gtr_runner.py (8)
- 21-33: The default values for `metrics` and `persistent_tracking` have been updated to improve clarity and functionality. The documentation has been adjusted to reflect these changes.
- 60-72: The `forward` method now checks for detected instances using the `num_detected` attribute before proceeding with the model prediction, which is a logical improvement.
- 56-80: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [74-90]
  The `training_step` method has been updated with improved documentation and a more structured approach to logging metrics.
- 86-98: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [92-108]
  The `validation_step` method has been updated with improved documentation and a more structured approach to logging metrics.
- 102-114: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [110-124]
  The `test_step` method has been updated with improved documentation and a more structured approach to logging metrics.
- 120-130: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [126-141]
  The `predict_step` method has been updated to always use persistent tracking during inference, which aligns with the objective of improving tracking across individual instance gaps.
- 205-218: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [178-207]
  The `configure_optimizers` method has been structured to be overridden by config, with a default setup for the optimizer and scheduler provided. This allows for flexibility and customization.
- 209-218: The `log_metrics` method has been enhanced with type annotations and improved documentation, which aids in code readability and maintainability.
biogtr/models/transformer.py (7)
- 11-21: The documentation at the top of the file has been updated to reflect the addition of fixed embeddings over boxes, aligning with the PR's objective to enhance tracking capabilities.
- 162-163: The `_reset_parameters` method has been updated to initialize model weights using the Xavier uniform distribution for tensors with more than one dimension. This is a standard practice for initializing neural network weights and should help with convergence during training.
- 214-245: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [165-272]
  The `forward` method of the `Transformer` class has been significantly refactored to accommodate the new `Frame` objects and the `query_frame` parameter. The method now concatenates features from the `Frame` objects and handles positional embeddings based on the `embedding_meta` configuration. The logic for handling query indices and the attention mechanism has been updated accordingly. This aligns with the PR's objective to improve object tracking across individual instance gaps.
- 294-298: The `forward` method of the `TransformerEncoder` class has been updated to optionally accept a positional embedding tensor. This change is consistent with the overall PR's objective of enhancing the tracking system's capabilities.
- 339-345: The `forward` method of the `TransformerDecoder` class has been updated to accept positional embeddings for both the target sequence and the memory sequence. This is in line with the changes made to the `Transformer` class and supports the PR's goal of improving tracking.
- 416-420: The `forward` method of the `TransformerEncoderLayer` class has been updated to include positional embeddings in the input sequence if provided. This change is consistent with the updates to the `TransformerEncoder` and `TransformerDecoder` classes and supports the PR's enhancements.
- 490-494: The `forward` method of the `TransformerDecoderLayer` class has been updated to handle positional embeddings for both the target and memory sequences. This change is consistent with the updates to the `TransformerDecoder` class and supports the PR's enhancements.
biogtr/training/losses.py (4)
- 2-2: The import change from dictionaries to `Frame` objects is consistent with the PR objectives and the summary provided.
- 36-38: The update to the `forward` method signature to accept `Frame` objects aligns with the PR objectives and the summary.
- 49-50: The attribute access changes to utilize `Frame` object methods are consistent with the PR objectives and the summary.
- 53-53: The change in the function call to `get_boxes_times` to pass `Frame` objects is consistent with the PR objectives and the summary.
biogtr/training/train.py (1)
- 41-52: The refactoring of the input prompt logic for obtaining the task index is a good fallback mechanism for when the `POD_INDEX` environment variable is not set. This ensures that the batch job can still proceed with manual input.
biogtr/visualize.py (4)
- 2-13: The summary indicates that the import statement for `matplotlib.pyplot` was removed, but this change is not visible in the provided hunks. Please verify if this change occurred in a part of the file not shown here.
- 18-19: The change in the docstring for `fill_missing` from "Fills" to "Fill" is a minor grammatical correction and is acceptable.
- 67-71: The addition of the `poses` parameter to the `annotate_video` function may require updates to all calls to this function to accommodate the new parameter, especially if they rely on positional arguments.
- 280-282: The refactoring of the `main` function's docstring to a one-liner is a stylistic change and is acceptable.
tests/test_datasets.py (5)
- 54-58: The update from direct dictionary access to using the `get_gt_track_ids()` method aligns with the PR's objective to encapsulate attributes and provide methods for easier manipulation and access.
- 65-71: The changes to the `n_chunks` parameter in the `SleapDataset` instantiation are not mentioned in the summary. This could be an oversight, as altering the chunking behavior can have significant effects on dataset partitioning and should be documented.
- 54-58: The changes to the assertions in the test functions are not mentioned in the summary. These changes are significant as they validate the expected behavior of the dataset with the new data structures.
- 229-229: The changes to the assertions comparing `gt_track_ids_1` and `gt_track_ids_2` are not mentioned in the summary. These changes are significant as they ensure the consistency of the dataset's behavior with and without the `gt_list` provided.
- 384-390: The changes to the augmentation tests in the `test_augmentations` function are not mentioned in the summary. These changes are significant as they test the robustness of the dataset under transformations and should be documented.
tests/test_inference.py (4)
- 5-7: The import of `Frame` and `Instance` from `biogtr.data_structures` is consistent with the PR objectives and the changes described in the summary.
- 19-37: The refactoring to use `Frame` and `Instance` objects for tracking logic is consistent with the PR objectives and the summary.
- 62-81: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [78-152]
  The post-processing tests appear to be unchanged and not directly related to the new data structures. Verify if the commented-out `pytest.mark.parametrize` is intended to be removed or if it's part of a refactor that should be completed.
- 155-188: The creation of `Frame` and `Instance` objects for metrics testing is consistent with the PR objectives and the summary.
tests/test_models.py (3)
- 208-225: The test `test_transformer_basic` has been updated to use the new `Frame` and `Instance` data structures. This change is consistent with the PR objectives and the summary provided.
- 270-286: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [270-307]
  The test `test_transformer_embedding` has been updated to use the new `Frame` and `Instance` data structures and includes additional assertions to check the size of the output embeddings. This change is consistent with the PR objectives and the summary provided.
- 314-331: Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [314-353]
  The test `test_tracking_transformer` has been updated to use the new `Frame` and `Instance` data structures and includes additional assertions to check the size of the output embeddings. This change is consistent with the PR objectives and the summary provided.
tests/test_training.py (4)
- 2-8: The imports and setup in the beginning of the file are consistent with the PR objectives and the summary provided.
- 19-39: The changes in the `test_asso_loss` function correctly reflect the use of the new `Frame` and `Instance` data structures, aligning with the PR objectives and the summary.
- 46-77: The setup of the `test_basic_gtr_runner` function with the new `Frame` and `Instance` data structures is consistent with the PR objectives and the summary.
- 98-120: The calls to `training_step`, `validation_step`, and `test_step` within the `test_basic_gtr_runner` function are updated to pass a list containing `batch` instead of `batch` directly, which is in line with the PR objectives and the summary.
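On the `cmap` check noted under biogtr/datasets/data_utils.py: the original condition `isinstance(cmap, None)` cannot work, because `isinstance` expects a type (or tuple of types) as its second argument and `None` is a value. The short sketch below illustrates the difference. The helper name `choose_imshow_kwargs` is hypothetical and exists only for this illustration; it is not part of biogtr.

```python
def choose_imshow_kwargs(cmap=None):
    """Hypothetical helper: pass cmap through only when one was given."""
    # Identity comparison with None is the check used in the corrected line.
    return {"cmap": cmap} if cmap is not None else {}


# isinstance() needs a type as its second argument; None is not a type,
# so the original form of the check raises TypeError at runtime.
try:
    isinstance("viridis", None)
    raised = False
except TypeError:
    raised = True

no_cmap = choose_imshow_kwargs()
with_cmap = choose_imshow_kwargs("gray")
```

The corrected one-liner in the hunk follows the same pattern: test `cmap is not None` and only then forward `cmap` to `imshow`.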
  # for now set equal since detections are fixed
- pred_box, pred_time = get_boxes_times(instances)
+ pred_box, pred_time = get_boxes_times(frames)
  target_box, target_time = pred_box, pred_time

  # todo: we should maybe reconsider how we label gt instances. The second
Consider removing the commented-out code that references the old implementation using dictionaries to clean up the file and avoid confusion.
biogtr/models/model_utils.py
Outdated
  """Module containing model helper functions."""
  from copy import deepcopy
- from typing import Dict, List, Tuple, Iterable
+ from typing import List, Tuple, Iterable
  from pytorch_lightning import loggers
+ from biogtr.data_structures import Frame
  import torch


- def get_boxes_times(instances: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor]:
-     """Extracts the bounding boxes and frame indices from the input list of instances.
+ def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Extract the bounding boxes and frame indices from the input list of instances.

      Args:
          instances (List[Dict]): List of instance dictionaries
Note: This review was outside the patches, and no patch overlapping with it was found. Original lines [63-64]
Potential bug in error handling.
- if optimizer_class is None:
+ optimizer_class = None
Note: This review was outside the patches, and no patch overlapping with it was found. Original lines [83-84]
Potential bug in error handling.
- if scheduler_class is None:
+ scheduler_class = None
def test_instance():
    """Test Instance object logic."""

    gt_track_id = 0
    pred_track_id = 0
    bbox = torch.randn((1, 4))
    crop = torch.randn((1, 3, 128, 128))
    features = torch.randn((1, 64))

    instance = Instance(
        gt_track_id=gt_track_id,
        pred_track_id=pred_track_id,
        bbox=bbox,
        crop=crop,
        features=features,
    )

    assert instance.has_gt_track_id()
    assert instance.gt_track_id.item() == gt_track_id
    assert instance.has_pred_track_id()
    assert instance.pred_track_id.item() == pred_track_id
    assert instance.has_bbox()
    assert torch.equal(instance.bbox, bbox)
    assert instance.has_features()
    assert torch.equal(instance.features, features)

    instance.gt_track_id = 1
    instance.pred_track_id = 1
    instance.bbox = torch.randn((1, 4))
    instance.crop = torch.randn((1, 3, 128, 128))
    instance.features = torch.randn((1, 64))

    assert instance.has_gt_track_id()
    assert instance.gt_track_id.item() != gt_track_id
    assert instance.has_pred_track_id()
    assert instance.pred_track_id.item() != pred_track_id
    assert instance.has_bbox()
    assert not torch.equal(instance.bbox, bbox)
    assert instance.has_features()
    assert not torch.equal(instance.features, features)

    instance.gt_track_id = None
    instance.pred_track_id = -1
    instance.bbox = None
    instance.crop = None
    instance.features = None

    assert not instance.has_gt_track_id()
    assert instance.gt_track_id.shape[0] == 0
    assert not instance.has_pred_track_id()
    assert instance.pred_track_id.item() != pred_track_id
    assert not instance.has_bbox()
    assert not torch.equal(instance.bbox, bbox)
    assert not instance.has_features()
    assert not torch.equal(instance.features, features)
The tests for the `Instance` class seem comprehensive, covering the creation of an instance, attribute checks, and the behavior when setting attributes to `None`. However, it would be beneficial to add comments explaining the purpose of each assertion block for better readability and maintainability.
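One way the suggested per-block comments could look is sketched below. To keep the example self-contained it uses a deliberately tiny stand-in class; this `Instance` is hypothetical and is not the `biogtr.data_structures.Instance` under review, so only the commenting pattern carries over.

```python
from typing import Optional


class Instance:
    """Hypothetical stand-in (not the biogtr class): one optional track ID."""

    def __init__(self, gt_track_id: Optional[int] = None) -> None:
        self.gt_track_id = gt_track_id

    def has_gt_track_id(self) -> bool:
        return self.gt_track_id is not None


def test_instance_commented() -> None:
    # Construction: the value passed in is stored and reported as present.
    instance = Instance(gt_track_id=0)
    assert instance.has_gt_track_id()
    assert instance.gt_track_id == 0

    # Reassignment: a new value replaces the original one.
    instance.gt_track_id = 1
    assert instance.has_gt_track_id()
    assert instance.gt_track_id != 0

    # Clearing: setting the attribute to None marks it as absent.
    instance.gt_track_id = None
    assert not instance.has_gt_track_id()


test_instance_commented()
```

Each comment names the scenario (construction, reassignment, clearing) rather than restating the assertion, which is the part a reader cannot recover from the code alone.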
def test_frame():
    n_detected = 2
    n_traj = 3
    video_id = 0
    frame_id = 0
    img_shape = torch.tensor([3, 1024, 1024])
    asso_output = torch.randn(n_detected, 16)
    traj_score = torch.randn(n_detected, n_traj)
    matches = ([0, 1], [0, 1])

    instances = []
    for i in range(n_detected):
        instances.append(
            Instance(
                gt_track_id=i,
                pred_track_id=i,
                bbox=torch.randn(1, 4),
                crop=torch.randn(1, 3, 64, 64),
                features=torch.randn(1, 64),
            )
        )
    frame = Frame(
        video_id=video_id, frame_id=frame_id, img_shape=img_shape, instances=instances
    )

    assert frame.video_id.item() == video_id
    assert frame.frame_id.item() == frame_id
    assert torch.equal(frame.img_shape, img_shape)
    assert frame.num_detected == n_detected
    assert frame.has_instances()
    assert len(frame.instances) == n_detected
    assert frame.has_gt_track_ids()
    assert len(frame.get_gt_track_ids()) == n_detected
    assert frame.has_pred_track_ids()
    assert len(frame.get_pred_track_ids()) == n_detected
    assert not frame.has_matches()
    assert not frame.has_asso_output()
    assert not frame.has_traj_score()

    frame.asso_output = asso_output
    frame.add_traj_score("initial", traj_score)
    frame.matches = matches

    assert frame.has_matches()
    assert frame.matches == matches
    assert frame.has_asso_output()
    assert torch.equal(frame.asso_output, asso_output)
    assert frame.has_traj_score()
    assert torch.equal(frame.get_traj_score("initial"), traj_score)

    frame.instances = []

    assert frame.video_id.item() == video_id
    assert frame.num_detected == 0
    assert not frame.has_instances()
    assert len(frame.instances) == 0
    assert not frame.has_gt_track_ids()
    assert not len(frame.get_gt_track_ids())
    assert not frame.has_pred_track_ids()
    assert len(frame.get_pred_track_ids()) == 0
    assert frame.has_matches()
    assert frame.has_asso_output()
    assert frame.has_traj_score()
The tests for the `Frame` class cover a variety of scenarios, including the creation of a frame with instances, attribute checks, and the behavior when modifying the `instances` attribute. Similar to the `Instance` tests, adding comments to explain the purpose of each assertion block would improve readability.
def test_track_queue():
    window_size = 8
    max_gap = 10
    img_shape = (3, 1024, 1024)
    n_instances_per_frame = [2] * window_size

    frames = []
    instances_per_frame = []

    tq = TrackQueue(window_size, max_gap)
    for i in range(window_size):
        instances = []
        for j in range(n_instances_per_frame[i]):
            instances.append(Instance(gt_track_id=j, pred_track_id=j))
        instances_per_frame.append(instances)
        frame = Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances)
        frames.append(frame)

        tq.add_frame(frame)

    assert len(tq) == sum(n_instances_per_frame[1:])
    assert tq.n_tracks == max(n_instances_per_frame)
    assert tq.tracks == [i for i in range(max(n_instances_per_frame))]
    assert len(tq.collate_tracks()) == window_size - 1
    assert all([gap == 0 for gap in tq._curr_gap.values()])
    assert tq.curr_track == max(n_instances_per_frame) - 1

    tq.add_frame(
        Frame(
            video_id=0,
            frame_id=window_size + 1,
            img_shape=img_shape,
            instances=[Instance(gt_track_id=0, pred_track_id=0)],
        )
    )

    assert len(tq._queues[0]) == window_size - 1
    assert tq._curr_gap[0] == 0
    assert tq._curr_gap[max(n_instances_per_frame) - 1] == 1

    tq.add_frame(
        Frame(
            video_id=0,
            frame_id=window_size + 1,
            img_shape=img_shape,
            instances=[
                Instance(gt_track_id=1, pred_track_id=1),
                Instance(
                    gt_track_id=max(n_instances_per_frame),
                    pred_track_id=max(n_instances_per_frame),
                ),
            ],
        )
    )

    assert len(tq._queues[max(n_instances_per_frame)]) == 1
    assert tq._curr_gap[1] == 0
    assert tq._curr_gap[0] == 1

    for i in range(max_gap):
        tq.add_frame(
            Frame(
                video_id=0,
                frame_id=window_size + i + 1,
                img_shape=img_shape,
                instances=[Instance(gt_track_id=0, pred_track_id=0)],
            )
        )

    assert tq.n_tracks == 1
    assert tq.curr_track == max(n_instances_per_frame)
    assert 0 in tq._queues.keys()

    tq.end_tracks()

    assert len(tq) == 0
The tests for the `TrackQueue` class check the queue's behavior with different numbers of frames and instances, as well as the handling of gaps. It would be good to add assertions to check the correctness of the `collate_tracks` method's output and to ensure that the `end_tracks` method correctly finalizes the tracking. Additionally, comments explaining the test scenarios would enhance understanding.
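To make the gap handling and the kind of assertions requested above concrete, here is a minimal sketch built on a toy queue. `TinyTrackQueue` is a hypothetical stand-in written for this illustration, not biogtr's `TrackQueue`: its method names, its use of bare frame ids, and the `>= max_gap` drop rule are all assumptions.

```python
from collections import deque
from typing import Dict, List


class TinyTrackQueue:
    """Hypothetical stand-in for illustration only; not biogtr's TrackQueue."""

    def __init__(self, window_size: int, max_gap: int) -> None:
        self.window_size = window_size
        self.max_gap = max_gap
        self.queues: Dict[int, deque] = {}
        self.curr_gap: Dict[int, int] = {}

    def add_frame(self, frame_id: int, track_ids: List[int]) -> None:
        # Append the frame to every observed track and reset that track's gap.
        for track_id in track_ids:
            queue = self.queues.setdefault(track_id, deque(maxlen=self.window_size))
            queue.append(frame_id)
            self.curr_gap[track_id] = 0
        # Tracks absent from this frame accumulate a gap; assumed rule: drop
        # a track once its gap reaches max_gap.
        for track_id in list(self.queues):
            if track_id not in track_ids:
                self.curr_gap[track_id] += 1
                if self.curr_gap[track_id] >= self.max_gap:
                    del self.queues[track_id], self.curr_gap[track_id]

    def collate(self) -> List[int]:
        # Sorted frame ids still held by at least one live track.
        return sorted({f for q in self.queues.values() for f in q})

    def end_tracks(self) -> None:
        self.queues.clear()
        self.curr_gap.clear()


tq = TinyTrackQueue(window_size=3, max_gap=2)
for frame_id in range(3):
    tq.add_frame(frame_id, [0, 1])

# Scenario: track 1 misses one frame, so its gap grows but it stays alive.
tq.add_frame(3, [0])
assert tq.curr_gap[1] == 1

# Scenario: track 1 misses a second frame, reaches max_gap and is dropped.
tq.add_frame(4, [0])
assert sorted(tq.queues) == [0]

# Check the collated content itself, not only its length.
collated = tq.collate()
assert collated == [2, 3, 4]

# Ending all tracks should leave nothing behind.
tq.end_tracks()
assert len(tq.queues) == 0
```

The point is the shape of the checks: compare the collated output against the exact expected frames, and assert on the internal state after `end_tracks` rather than only on the overall length.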
biogtr/inference/metrics.py
Outdated
  Args:
-     instances: a list of dicts where each dict corresponds to a frame and
-         contains the video_id, frame_id, gt labels and predicted labels
+     instances: a list of Frames containing the video_id, frame_id,
+         gt labels and predicted labels
The docstring argument name `instances` should be updated to `frames` to match the function signature.
- instances: a list of Frames containing the video_id, frame_id,
+ frames: a list of Frames containing the video_id, frame_id,
biogtr/inference/metrics.py
Outdated
else:
    warnings.warn("No instances detected!")
Consider providing more context in the warning message, such as the video_id or frame_id, to help with debugging.
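One way to act on this suggestion is sketched below. It assumes the caller has a `video_id` and `frame_id` at hand (for example from a `Frame`); the helper name `warn_no_instances` and the message format are hypothetical and not part of biogtr.

```python
import warnings


def warn_no_instances(video_id: int, frame_id: int) -> None:
    """Warn that a frame has no instances, naming the frame for easier debugging."""
    # Hypothetical helper: the message wording is an assumption, not biogtr's.
    warnings.warn(
        f"No instances detected in video {video_id}, frame {frame_id}!",
        stacklevel=2,  # attribute the warning to the caller, not this helper
    )
```

Including the identifiers in the message lets a reader of the log jump straight to the offending frame instead of re-running the evaluation with extra prints.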
biogtr/datasets/eval_dataset.py
Outdated
  def __getitem__(self, idx: int) -> List[Frame]:
      """Get an element of the dataset.

      Args:
          idx: the index of the batch. Note this is not the index of the video
              or the frame.

      Returns:
-         A list of dicts where each dict corresponds a frame in the chunk and
-         each value is a `torch.Tensor`. Dict elements are the video id, frame id, and gt/pred track ids
+         A list of Frames where frames contain instances w gt and pred track ids + bboxes.
      """
-     labels = [{"video_id": gt_frame['video_id'],
-                "frame_id": gt_frame['video_id'],
-                "gt_track_ids": gt_frame['gt_track_ids'],
-                "pred_track_ids": pred_frame['gt_track_ids'],
-                "bboxes": pred_frame["bboxes"]
-                } for gt_frame, pred_frame in zip(self.gt_dataset[idx], self.pred_dataset[idx])]
-
-     return labels
\ No newline at end of file
+     gt_batch = self.gt_dataset[idx]
+     pred_batch = self.pred_dataset[idx]
+
+     eval_frames = []
+     for gt_frame, pred_frame in zip(gt_batch, pred_batch):
+         eval_instances = []
+         for gt_instance, pred_instance in zip(
+             gt_frame.instances, pred_frame.instances
+         ):
+             eval_instances.append(
+                 Instance(
+                     gt_track_id=gt_instance.gt_track_id,
+                     pred_track_id=pred_instance.gt_track_id,
+                     bbox=pred_instance.bbox,
+                 )
+             )
+         eval_frames.append(
+             Frame(
+                 video_id=gt_frame.video_id,
+                 frame_id=gt_frame.frame_id,
+                 img_shape=gt_frame.img_shape,
+                 instances=eval_instances,
+             )
+         )
+
+     return eval_frames
The `__getitem__` method has been refactored to return a list of `Frame` objects. Ensure that the logic correctly creates `Frame` and `Instance` objects and that the ground truth and predicted track IDs are assigned to the correct attributes. Specifically, line 50 seems to incorrectly assign `pred_instance.gt_track_id` to `pred_track_id` instead of `pred_instance.pred_track_id`.
eval_instances.append(
Instance(
gt_track_id=gt_instance.gt_track_id,
- pred_track_id=pred_instance.gt_track_id,
+ pred_track_id=pred_instance.pred_track_id,
bbox=pred_instance.bbox,
)
)
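The pairing logic under discussion can be sketched in isolation as follows. This is a minimal illustration, not biogtr's code: the `Instance` dataclass here is a simplified stand-in with assumed fields, and `pair_instances` is a hypothetical helper. It also assumes the prediction side stores its IDs in `pred_track_id`; if the prediction dataset instead loads predicted labels into `gt_track_id`, the original line may be the intended behavior.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Instance:
    """Simplified stand-in for biogtr's Instance (assumed fields only)."""

    gt_track_id: Optional[int] = None
    pred_track_id: int = -1


def pair_instances(
    gt_instances: List[Instance], pred_instances: List[Instance]
) -> List[Instance]:
    """Take the gt ID from the gt side and the predicted ID from the pred side."""
    return [
        Instance(gt_track_id=gt.gt_track_id, pred_track_id=pred.pred_track_id)
        for gt, pred in zip(gt_instances, pred_instances)
    ]


# Two instances whose predicted IDs are swapped relative to ground truth.
gt = [Instance(gt_track_id=0), Instance(gt_track_id=1)]
pred = [
    Instance(gt_track_id=0, pred_track_id=1),
    Instance(gt_track_id=1, pred_track_id=0),
]
paired = pair_instances(gt, pred)
```

With the swapped example above, reading `gt_track_id` from the prediction side would hide the identity switch, which is the kind of error the review comment is guarding against.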
biogtr/data_structures.py
Outdated
gt_track_id: int = None,
pred_track_id: int = -1,
The default values for `gt_track_id` and `pred_track_id` are inconsistent with their expected types. Consider using `0` or another placeholder that aligns with the expected integer type.
- gt_track_id: int = None,
+ gt_track_id: int = 0,
- pred_track_id: int = -1,
+ pred_track_id: int = 0,
biogtr/data_structures.py
Outdated
    self._gt_track_id = torch.tensor([gt_track_id])
else:
    self._gt_track_id = torch.tensor([])

if pred_track_id is not None:
    self._pred_track_id = torch.tensor([pred_track_id])
else:
    self._pred_track_id = torch.tensor([])
Consider whether it is necessary to store `gt_track_id` and `pred_track_id` as tensors since they represent single integer values.
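For comparison, a plain-integer variant might look like the sketch below. The class `TrackIds` and its method are hypothetical and exist only to illustrate the trade-off; the real `Instance` may have good reasons to keep tensors (for example, concatenating IDs across instances), which this sketch does not address.

```python
from typing import Optional


class TrackIds:
    """Hypothetical holder that keeps track IDs as plain optional ints."""

    def __init__(
        self,
        gt_track_id: Optional[int] = None,
        pred_track_id: Optional[int] = None,
    ) -> None:
        # None marks "no ID assigned", so 0 stays available as a real track ID.
        self.gt_track_id = gt_track_id
        self.pred_track_id = pred_track_id

    def has_gt_track_id(self) -> bool:
        """Return True when a ground-truth ID has been set."""
        return self.gt_track_id is not None
```

Using `None` as the "missing" marker with an `Optional[int]` annotation keeps the annotation and the default consistent without reserving a valid integer as a placeholder.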
save config to wandblogger; enable use of profiler; use mixed precision in training
use simple tracker ids to initialize tracks instead of random; use clone instead of deepcopy; save out all asso_matrices
don't track during training by default
save pose as dictionary; fix tests
check box coords to make sure x1x2 or y1y2 are not negative; save embeddings to model for gradient update
@coderabbitai review
Actionable comments posted: 27
Actionable comments outside the diff hunks (8)

tests/test_training.py (1)
Line range hint 45-45: Address the unused local variable to clean up the code.
- feats = 128
The variable `feats` is assigned but never used. Consider removing it if it's not needed.

biogtr/inference/track.py (1)
Line range hint 137-137: Correct the f-string to include placeholders where necessary.
- print(f"Pod Index: {index}")
+ print(f"Pod Index: {index}") # Ensure that `index` is defined and used correctly in the f-string.
This change ensures that the f-string is used correctly, displaying the value of `index`.

tests/test_inference.py (1)
Line range hint 91-91: Address the unused local variable to clean up the code.
- N = N_t * T
The variable `N` is assigned but never used. Consider removing it if it's not needed.

biogtr/models/model_utils.py (2)
Line range hint 63-64: Address potential bug in error handling.
- if optimizer_class is None:
+ optimizer_class = None
This change corrects a potential bug where the `optimizer_class` variable might not be set correctly, leading to runtime errors.

Line range hint 83-84: Address potential bug in error handling.
- if scheduler_class is None:
+ scheduler_class = None
This change corrects a potential bug where the `scheduler_class` variable might not be set correctly, leading to runtime errors.

biogtr/datasets/microscopy_dataset.py (1)
Line range hint 86-86: Avoid using lambda expressions for defining functions due to potential issues with serialization and readability. Consider defining a proper function instead.
- parser = lambda x: data_utils.parse_synthetic(x, source=source)
+ def parser(x):
+     return data_utils.parse_synthetic(x, source=source)

biogtr/datasets/data_utils.py (2)
Line range hint 66-66: Use `isinstance()` for type checking.
- if type(size) == int:
+ if isinstance(size, int):
Replace type comparison with `isinstance()` for more Pythonic and reliable type checking.

Line range hint 342-342: Avoid using bare `except`.
- except:
+ except Exception as e:
Specify the exception type to avoid catching unexpected exceptions, which can make debugging more difficult.
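The last two recommendations can be illustrated with a small standalone sketch. The functions `as_pair` and `safe_parse_int` are hypothetical and unrelated to `data_utils.py`; note also that the second one catches specific exception types, which is a narrower choice than the `except Exception as e` proposed in the review.

```python
def as_pair(size):
    """Return a 2-tuple from either a single int or a 2-item sequence."""
    # isinstance() also accepts subclasses of int, unlike `type(size) == int`.
    if isinstance(size, int):
        return (size, size)
    return tuple(size)


def safe_parse_int(text, default=-1):
    """Parse an int, falling back to `default` on bad input."""
    try:
        return int(text)
    except (TypeError, ValueError):
        # A bare `except:` would also swallow KeyboardInterrupt and SystemExit.
        return default
```

Naming the exception types keeps unrelated failures visible instead of silently turning them into the fallback value.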
- def track(self, model: GlobalTrackingTransformer, instances: list[dict], all_instances: list = None):
+ def track(self, model: GlobalTrackingTransformer, frames: list[dict]):
The method signature for `track` should accept a list of `Frame` objects instead of a list of dictionaries. This change aligns with the PR objectives and ensures type consistency.
- def track(self, model: GlobalTrackingTransformer, frames: list[dict]):
+ def track(self, model: GlobalTrackingTransformer, frames: list[Frame]):
self.track_queue.clear()

if self.verbose:
    warnings.warn(f"Clearing Queue after tracking")
The warning message in the verbose logging uses an f-string without placeholders. Consider revising the message or removing the f-string if not needed.
- warnings.warn(f"Clearing Queue after tracking")
+ warnings.warn("Clearing Queue after tracking")
biogtr/inference/metrics.py
Outdated
from biogtr.inference.boxes import Boxes
import torch
from biogtr.data_structures import Frame
import warnings
Remove unused import `warnings`.
- import warnings
75da884 to 4fed899
Summary by CodeRabbit
New Features
Documentation
Refactor
`Frame` and `Instance` classes for enhanced data management.
`TrackQueue` class for improved tracking functionality.
Style
Tests
Bug Fixes
Chores