Use sleap-io video backend #79

Open
wants to merge 17 commits into base: aadi/batch-inference-eval
2 changes: 1 addition & 1 deletion docs/configs/index.md
@@ -1,3 +1,3 @@
# DREEM Config API

We utilize `.yaml` based configs with `hydra` and `omegaconf` for config parsing.
We utilize `.yaml` based configs with [`hydra`](https://hydra.cc) and [`omegaconf`](https://omegaconf.readthedocs.io/en/2.3_branch/) for config parsing.
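For illustration, a minimal sketch of parsing such a config with `omegaconf` (the keys below are hypothetical, not the actual DREEM schema):

```python
from omegaconf import OmegaConf

# Hypothetical config snippet; field names are illustrative only.
yaml_cfg = """
model:
  d_model: 1024
dataset:
  crop_size: 128
"""
cfg = OmegaConf.create(yaml_cfg)
print(cfg.model.d_model)      # 1024
print(cfg.dataset.crop_size)  # 128
```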
2 changes: 1 addition & 1 deletion docs/configs/training.md
@@ -2,7 +2,7 @@

Here, we describe the hyperparameters used for setting up training. Please see [here](./training.md#example-config) for an example training config.

> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` e.g
> Note: to use defaults, simply leave the field blank or omit the key. Using `null` will initialize the value to `None`, which we use to turn off certain features such as logging and early stopping, e.g.
> ```YAML
> model:
> d_model: #defaults to 1024
3 changes: 2 additions & 1 deletion dreem/datasets/base_dataset.py
@@ -3,6 +3,7 @@
from dreem.datasets import data_utils
from dreem.io import Frame
from torch.utils.data import Dataset
from typing import Union
import numpy as np
import torch

@@ -15,7 +16,7 @@ def __init__(
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: int,
crop_size: Union[int, list[int]],
chunk: bool,
clip_length: int,
mode: str,
74 changes: 59 additions & 15 deletions dreem/datasets/sleap_dataset.py
@@ -6,7 +6,9 @@
import numpy as np
import sleap_io as sio
import random
from pathlib import Path
import logging
from typing import Union, Optional
from dreem.io import Instance, Frame
from dreem.datasets import data_utils, BaseDataset
from torchvision.transforms import functional as tvf
@@ -21,8 +23,9 @@ def __init__(
self,
slp_files: list[str],
video_files: list[str],
data_dirs: Optional[list[str]] = None,
padding: int = 5,
crop_size: int = 128,
crop_size: Union[int, list[int]] = 128,
anchors: int | list[str] | str = "",
chunk: bool = True,
clip_length: int = 500,
@@ -32,14 +35,19 @@ def __init__(
n_chunks: int | float = 1.0,
seed: int | None = None,
verbose: bool = False,
normalize_image: bool = True,
):
"""Initialize SleapDataset.

Args:
slp_files: a list of .slp files storing tracking annotations
video_files: a list of paths to video files
data_dirs: a path, or a list of paths to data directories. If provided, crop_size should be a list of integers
with the same length as data_dirs.
padding: amount of padding around object crops
crop_size: the size of the object crops
crop_size: the size of the object crops. Can be either:
- An integer specifying a single crop size for all objects
- A list of integers specifying different crop sizes for different data directories
anchors: One of:
* a string indicating a single node to center crops around
* a list of skeleton node names to be used as the center of crops
@@ -64,6 +72,7 @@ def __init__(
Can either be a fraction of the dataset (i.e., (0, 1.0]) or a number of chunks
seed: set a seed for reproducibility
verbose: boolean representing whether to print
normalize_image: whether to normalize the image to [0, 1]
"""
super().__init__(
slp_files,
@@ -79,6 +88,9 @@ def __init__(
)

self.slp_files = slp_files
self.data_dirs = (
data_dirs # empty list, list of paths, or string of single path
)
self.video_files = video_files
self.padding = padding
self.crop_size = crop_size
@@ -88,14 +100,33 @@ def __init__(
self.handle_missing = handle_missing.lower()
self.n_chunks = n_chunks
self.seed = seed

self.normalize_image = normalize_image
if self.data_dirs is None:
self.data_dirs = []
if isinstance(anchors, int):
self.anchors = anchors
elif isinstance(anchors, str):
self.anchors = [anchors]
else:
self.anchors = anchors

if not isinstance(self.data_dirs, list):
self.data_dirs = [self.data_dirs]

if not isinstance(self.crop_size, list):
# make a list so it's handled consistently if multiple crop sizes are used
if len(self.data_dirs) > 0: # for test mode, data_dirs is []
self.crop_size = [self.crop_size] * len(self.data_dirs)
else:
self.crop_size = [self.crop_size]

if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
raise ValueError(
f"If a list of crop sizes or data directories are given,"
f"they must have the same length but got {len(self.crop_size)} "
f"and {len(self.data_dirs)}"
)

if (
isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
@@ -106,7 +137,7 @@ def __init__(
# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]
self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
self.vid_readers = {}
# do we need this? would need to update with sleap-io

# for label in self.labels:
@@ -140,6 +171,17 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

video_name = self.video_files[label_idx]

# get the correct crop size based on the video
video_par_path = Path(video_name).parent
if len(self.data_dirs) > 0:
crop_size = self.crop_size[0]
for j, data_dir in enumerate(self.data_dirs):
if Path(data_dir) == video_par_path:
crop_size = self.crop_size[j]
break
else:
crop_size = self.crop_size[0]

vid_reader = self.videos[label_idx]

# img = vid_reader.get_data(0)
@@ -162,12 +204,12 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
lf = video[frame_ind]

try:
img = vid_reader.get_data(int(lf.frame_idx))
except IndexError as e:
logger.warning(
f"Could not read frame {frame_ind} from {video_name} due to {e}"
)
continue
img = lf.image
except FileNotFoundError as e:
if video_name not in self.vid_readers:
self.vid_readers[video_name] = sio.load_video(video_name)
vid_reader = self.vid_readers[video_name]
img = vid_reader[lf.frame_idx]

if len(img.shape) == 2:
img = np.expand_dims(img, -1)
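For context, a minimal sketch of the sleap-io reading pattern used above (file names are hypothetical): frames are read from the labeled frame first, with `sio.load_video` as the fallback when the linked video file cannot be found.

```python
import sleap_io as sio

# Hypothetical files; mirrors the try/except fallback above.
labels = sio.load_slp("labels.slp")
lf = labels[0]  # a LabeledFrame
try:
    img = lf.image  # read through the video linked in the labels
except FileNotFoundError:
    video = sio.load_video("video.mp4")
    img = video[lf.frame_idx]  # index the fallback reader directly
print(img.shape)  # e.g. (height, width, channels)
```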
@@ -179,7 +221,9 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
) # convert to grayscale to rgb

if np.issubdtype(img.dtype, np.integer): # convert int to float
img = img.astype(np.float32) / 255
img = img.astype(np.float32)
if self.normalize_image:
img = img / 255

n_instances_dropped = 0

@@ -316,15 +360,15 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
data_utils.get_bbox(centroid, crop_size),
padding=self.padding,
)

if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
crop_size + 2 * self.padding,
crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
@@ -370,5 +414,5 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

def __del__(self):
"""Handle file closing before garbage collection."""
for reader in self.videos:
for reader in self.vid_readers.values():
reader.close()
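As a usage sketch of the per-directory crop sizes added above (paths, sizes, and the import path are assumptions): each video's parent directory is matched against `data_dirs` to pick the corresponding entry of `crop_size`.

```python
from dreem.datasets import SleapDataset  # assumed import path

# Hypothetical data layout: one crop size per data directory, matched by position.
dataset = SleapDataset(
    slp_files=["mice/labels.slp", "flies/labels.slp"],
    video_files=["mice/video.mp4", "flies/video.mp4"],
    data_dirs=["mice", "flies"],
    crop_size=[256, 64],  # crop_size[i] applies to videos under data_dirs[i]
    anchors="centroid",
    clip_length=32,
)
```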
6 changes: 3 additions & 3 deletions dreem/inference/eval.py
@@ -47,9 +47,9 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:

checkpoint = eval_cfg.cfg.ckpt_path

logger.info(f"Testing model saved at {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)
logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

@@ -61,7 +61,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"persistent_tracking", False
)
logger.info(f"Computing the following metrics:")
logger.info(model.metrics.test)
logger.info(model.metrics["test"])
model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
logger.info(f"Saving results to {model.test_results['save_path']}")

54 changes: 32 additions & 22 deletions dreem/inference/post_processing.py
@@ -123,44 +123,54 @@ def weight_iou(
def filter_max_center_dist(
asso_output: torch.Tensor,
max_center_dist: float = 0,
k_boxes: torch.Tensor | None = None,
nonk_boxes: torch.Tensor | None = None,
id_inds: torch.Tensor | None = None,
curr_frame_boxes: torch.Tensor | None = None,
prev_frame_boxes: torch.Tensor | None = None,
) -> torch.Tensor:
"""Filter trajectory score by distances between objects across frames.

Args:
asso_output: An N_t x N association matrix
max_center_dist: The euclidean distance threshold between bboxes
k_boxes: The bounding boxes in the current frame
nonk_boxes: the boxes not in the current frame
id_inds: track ids
curr_frame_boxes: the raw bbox coords of the current frame instances
prev_frame_boxes: the raw bbox coords of the previous frame instances

Returns:
An N_t x N association matrix
"""
if max_center_dist is not None and max_center_dist > 0:
assert (
k_boxes is not None and nonk_boxes is not None and id_inds is not None
), "Need `k_boxes`, `nonk_boxes`, and `id_ind` to filter by `max_center_dist`"
k_ct = (k_boxes[:, :, :2] + k_boxes[:, :, 2:]) / 2
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k

nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
dim=-1
) # n_k x Np

norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1) # n_k x Np
curr_frame_boxes is not None
and prev_frame_boxes is not None
and id_inds is not None
), "Need `curr_frame_boxes`, `prev_frame_boxes`, and `id_ind` to filter by `max_center_dist`"

k_ct = (curr_frame_boxes[:, :, :2] + curr_frame_boxes[:, :, 2:]) / 2
# k_s = ((curr_frame_boxes[:, :, 2:] - curr_frame_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k
# nonk boxes are only from previous frame rather than entire window
nonk_ct = (prev_frame_boxes[:, :, :2] + prev_frame_boxes[:, :, 2:]) / 2

# pairwise euclidean distance in units of pixels
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(dim=-1) ** (
1 / 2
) # n_k x n_nonk
# norm_dist = dist / (k_s[:, None, :] + 1e-8)

valid = dist.squeeze() < max_center_dist # n_k x n_nonk
# handle the case where id_inds and valid are single values
# handle this better
if valid.ndim == 0: valid = valid.unsqueeze(0)
if valid.ndim == 1:
if id_inds.shape[0] == 1:
valid_mult = valid.float().unsqueeze(-1)
else:
valid_mult = valid.float().unsqueeze(0)
else:
valid_mult = valid.float()

valid = norm_dist < max_center_dist # n_k x Np
valid_assn = (
torch.mm(valid.float(), id_inds.to(valid.device))
.clamp_(max=1.0)
.long()
.bool()
torch.mm(valid_mult, id_inds.to(valid.device)).clamp_(max=1.0).long().bool()
) # n_k x M
asso_output_filtered = asso_output.clone()
asso_output_filtered[~valid_assn] = 0 # n_k x M
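To illustrate the distance gating in isolation (the tensors below are made up), the core of `filter_max_center_dist` reduces to thresholding pairwise center distances between current- and previous-frame boxes:

```python
import torch

# Two current-frame boxes and two previous-frame boxes, shaped (n, 1, 4) as [x1, y1, x2, y2].
curr = torch.tensor([[[0.0, 0.0, 10.0, 10.0]], [[100.0, 100.0, 110.0, 110.0]]])
prev = torch.tensor([[[1.0, 1.0, 11.0, 11.0]], [[200.0, 200.0, 210.0, 210.0]]])

curr_ct = (curr[:, :, :2] + curr[:, :, 2:]) / 2  # (n_curr, 1, 2) box centers
prev_ct = (prev[:, :, :2] + prev[:, :, 2:]) / 2  # (n_prev, 1, 2)

# Pairwise euclidean distance in pixels, (n_curr, n_prev) after squeezing the anchor dim.
dist = (((curr_ct[:, None] - prev_ct[None]) ** 2).sum(dim=-1) ** 0.5).squeeze(-1)

valid = dist < 30.0  # associations farther than max_center_dist get zeroed out
print(valid)  # [[True, False], [False, False]]
```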
22 changes: 18 additions & 4 deletions dreem/inference/track.py
@@ -5,6 +5,7 @@
from dreem.models import GTRRunner
from omegaconf import DictConfig
from pathlib import Path
from datetime import datetime

import hydra
import os
@@ -14,9 +15,20 @@
import sleap_io as sio
import logging


logger = logging.getLogger("dreem.inference")


def get_timestamp() -> str:
"""Get current timestamp.

Returns:
the current timestamp in MM-DD-YYYY-HH-MM-SS format
"""
date_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
return date_time


def export_trajectories(
frames_pred: list["dreem.io.Frame"], save_path: str | None = None
) -> pd.DataFrame:
@@ -117,9 +129,9 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:

checkpoint = pred_cfg.cfg.ckpt_path

logger.info(f"Running inference with model from {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)
logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
tracker_cfg = pred_cfg.get_tracker_cfg()

model.tracker_cfg = tracker_cfg
@@ -140,8 +152,10 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
)
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = track(model, trainer, dataloader)
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
logger.info(f"Saving results to {outpath}...")
outpath = os.path.join(
outdir, f"{Path(label_file).stem}.dreem_inference.{get_timestamp()}.slp"
)

preds.save(outpath)

return preds
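For example (hypothetical inputs; the timestamp stands in for what `get_timestamp()` might return), the naming pattern above would place a tracked `val.slp` at a timestamped path like this:

```python
import os
from pathlib import Path

# Hypothetical values illustrating the output naming scheme.
outdir = "./results"
label_file = "data/val.slp"
timestamp = "03-05-2024-14-30-05"
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.{timestamp}.slp")
print(outpath)  # ./results/val.dreem_inference.03-05-2024-14-30-05.slp
```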