From da6c84739bf98496d66f90f5835ffb06464aed9a Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 16 Jan 2025 20:55:30 +0100 Subject: [PATCH] Finsih micro-sam training iteration --- scripts/training/README.md | 2 +- scripts/training/train_micro_sam.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/scripts/training/README.md b/scripts/training/README.md index bc498d4..d43ddad 100644 --- a/scripts/training/README.md +++ b/scripts/training/README.md @@ -14,4 +14,4 @@ They will load the **image data** according to the following rules: The training script will save the trained model in `checkpoints/cochlea_distance_unet_`, e.g. `checkpoints/cochlea_distance_unet_20250115`. For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`. -The script `train_micro_sam.py` works similar to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/`. +The script `train_micro_sam.py` works similar to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/cochlea_micro_sam_`. diff --git a/scripts/training/train_micro_sam.py b/scripts/training/train_micro_sam.py index 0660b79..c94b18a 100644 --- a/scripts/training/train_micro_sam.py +++ b/scripts/training/train_micro_sam.py @@ -2,7 +2,6 @@ from datetime import datetime import numpy as np -import torch_em from micro_sam.training import default_sam_loader, train_sam from train_distance_unet import get_image_and_label_paths, select_paths @@ -36,7 +35,7 @@ def main(): root = args.root run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name - name = f"cochlea_distance_unet_{run_name}" + name = f"cochlea_micro_sam_{run_name}" n_objects_per_batch = args.n_objects_per_batch image_paths, label_paths = get_image_and_label_paths(root) @@ -44,20 +43,19 @@ def main(): val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True) patch_shape = (1, 256, 256) - sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10) max_sampling_attempts = 2500 train_loader = default_sam_loader( raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None, patch_shape=patch_shape, with_segmentation_decoder=True, - raw_transform=raw_transform, sampler=sampler, min_size=10, + raw_transform=raw_transform, num_workers=6, batch_size=1, is_train=True, max_sampling_attempts=max_sampling_attempts, ) val_loader = default_sam_loader( raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None, patch_shape=patch_shape, with_segmentation_decoder=True, - raw_transform=raw_transform, sampler=sampler, min_size=10, + raw_transform=raw_transform, num_workers=6, batch_size=1, is_train=False, max_sampling_attempts=max_sampling_attempts, ) @@ -65,6 +63,7 @@ def main(): train_sam( name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader, n_epochs=50, n_objects_per_batch=n_objects_per_batch, + save_root=".", )