Commit

Finish micro-sam training iteration
constantinpape committed Jan 16, 2025
1 parent 2fd0344 · commit da6c847
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion scripts/training/README.md
@@ -14,4 +14,4 @@ They will load the **image data** according to the following rules:
The training script will save the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`.

The script `train_micro_sam.py` works similar to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/`.
The script `train_micro_sam.py` works similar to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/cochlea_micro_sam_<CURRENT_DATE>`.
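
As a quick reference (not part of this commit), the checkpoint directory name mentioned above follows the date-based naming scheme visible in the script diff below; a minimal sketch:

```python
from datetime import datetime

# Minimal sketch of the checkpoint naming scheme described in the README
# (mirrors the logic visible in train_micro_sam.py below; illustrative only).
run_name = datetime.now().strftime("%Y%m%d")  # e.g. "20250116"
name = f"cochlea_micro_sam_{run_name}"        # e.g. "cochlea_micro_sam_20250116"

# With save_root=".", the trained model ends up under checkpoints/<name>.
print(f"checkpoints/{name}")
```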
9 changes: 4 additions & 5 deletions scripts/training/train_micro_sam.py
@@ -2,7 +2,6 @@
from datetime import datetime

import numpy as np
import torch_em
from micro_sam.training import default_sam_loader, train_sam
from train_distance_unet import get_image_and_label_paths, select_paths

@@ -36,35 +35,35 @@ def main():

root = args.root
run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
name = f"cochlea_distance_unet_{run_name}"
name = f"cochlea_micro_sam_{run_name}"
n_objects_per_batch = args.n_objects_per_batch

image_paths, label_paths = get_image_and_label_paths(root)
train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)

patch_shape = (1, 256, 256)
sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10)
max_sampling_attempts = 2500

train_loader = default_sam_loader(
raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
patch_shape=patch_shape, with_segmentation_decoder=True,
raw_transform=raw_transform, sampler=sampler, min_size=10,
raw_transform=raw_transform,
num_workers=6, batch_size=1, is_train=True,
max_sampling_attempts=max_sampling_attempts,
)
val_loader = default_sam_loader(
raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
patch_shape=patch_shape, with_segmentation_decoder=True,
raw_transform=raw_transform, sampler=sampler, min_size=10,
raw_transform=raw_transform,
num_workers=6, batch_size=1, is_train=False,
max_sampling_attempts=max_sampling_attempts,
)

train_sam(
name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
n_epochs=50, n_objects_per_batch=n_objects_per_batch,
save_root=".",
)


