Training updates (#17)
* Update training scripts

* Add README for unet training
constantinpape authored Jan 15, 2025
1 parent 87688cd commit 170e449
Showing 5 changed files with 137 additions and 68 deletions.
1 change: 1 addition & 0 deletions scripts/README.md
@@ -13,6 +13,7 @@ conda install -c conda-forge mobie_utils
 ## Training
 
 Contains the scripts for training a U-Net that predicts foreground probabilities and normalized object distances.
+It also contains documentation for how to run training on new annotated data.
 
 
 ## Prediction
7 changes: 7 additions & 0 deletions scripts/data_transfer/README.md
@@ -33,3 +33,10 @@ Try to automate via https://github.com/jborean93/smbprotocol see `sync_smb.py` f

 For transferring back MoBIE results.
 ...
+
+# Data Transfer Huisken
+
+See "Transfer via smbclient" above:
+```
+smbclient \\\\wfs-biologie-spezial.top.gwdg.de\\UBM1-all\$\\ -U GWDG\\pape41
+```
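A minimal sketch of what the automation via smbprotocol could look like, using its high-level `smbclient` module (server, credentials, and file names here are placeholders; `sync_smb.py` is the authoritative script):

```python
# Hypothetical sketch, not the actual sync_smb.py: list a share and copy one file.
import smbclient
import smbclient.shutil

# Register the session once; later smbclient calls to this server reuse it.
smbclient.register_session(
    "wfs-biologie-spezial.top.gwdg.de", username="pape41", password="..."
)

share = r"\\wfs-biologie-spezial.top.gwdg.de\UBM1-all$"
for name in smbclient.listdir(share):
    print(name)

# smbclient.shutil mirrors shutil and accepts UNC as well as local paths.
smbclient.shutil.copyfile(share + r"\example_folder\example_stack.tif", "example_stack.tif")
```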
14 changes: 14 additions & 0 deletions scripts/training/README.md
@@ -0,0 +1,14 @@
+# 3D U-Net Training for Cochlea Data
+
+This folder contains the scripts for training a 3D U-Net for cell segmentation in the cochlea data.
+It contains two relevant scripts:
+- `check_training_data.py`, which visualizes the training data and annotations in napari.
+- `train_distance_unet.py`, which trains the 3D U-Net.
+
+Both scripts accept the argument `-i /path/to/data` to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training. The scripts consider all tif files in the sub-folders of the root folder for training.
+They load the **image data** according to the following rules (an example layout follows below):
+- Files ending in `_annotations.tif` or `_cp_masks.tif` are not considered as image data.
+- All other files are considered as image data if a corresponding file ending in `_annotations.tif` can be found; files without one are excluded, and the scripts print the names of all excluded files.
+
+The training script saves the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
+For further options run `python check_training_data.py -h` / `python train_distance_unet.py -h`.
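For illustration, a root folder that satisfies these rules could look as follows (all folder and file names are hypothetical): `crop_02.tif` would be excluded, and `crop_03_cp_masks.tif` would not be considered as image data.

```
/path/to/data/
├── cochlea_1/
│   ├── crop_01.tif                 # image data, has annotations
│   ├── crop_01_annotations.tif     # labels for crop_01.tif
│   └── crop_02.tif                 # excluded: no matching _annotations.tif
└── cochlea_2/
    ├── crop_03.tif
    ├── crop_03_annotations.tif
    └── crop_03_cp_masks.tif        # ignored: not image data
```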
59 changes: 37 additions & 22 deletions scripts/training/check_training_data.py
@@ -1,42 +1,57 @@
+import argparse
 import os
 from glob import glob
 
 import imageio.v3 as imageio
 import napari
 import numpy as np
 
-root = "/home/pape/Work/data/moser/lightsheet"
+from train_distance_unet import get_image_and_label_paths
+from tqdm import tqdm
+
+# Root folder on my laptop.
+# This is just for convenience, so that I don't have to pass
+# the root argument during development.
+ROOT_CP = "/home/pape/Work/data/moser/lightsheet"
 
 
-def check_visually(check_downsampled=False):
-    if check_downsampled:
-        images = sorted(glob(os.path.join(root, "images_s2", "*.tif")))
-        masks = sorted(glob(os.path.join(root, "masks_s2", "*.tif")))
-    else:
-        images = sorted(glob(os.path.join(root, "images", "*.tif")))
-        masks = sorted(glob(os.path.join(root, "masks", "*.tif")))
-    assert len(images) == len(masks)
-
-    for im, mask in zip(images, masks):
-        print(im)
+def check_visually(images, labels):
+    for im, label in tqdm(zip(images, labels), total=len(images)):
+
         vol = imageio.imread(im)
-        seg = imageio.imread(mask).astype("uint32")
+        seg = imageio.imread(label).astype("uint32")
 
         v = napari.Viewer()
-        v.add_image(vol)
-        v.add_labels(seg)
+        v.add_image(vol, name="pv-channel")
+        v.add_labels(seg, name="annotations")
+        folder, name = os.path.split(im)
+        folder = os.path.basename(folder)
+        v.title = f"{folder}/{name}"
         napari.run()
 
 
-def check_labels():
-    masks = sorted(glob(os.path.join(root, "masks", "*.tif")))
-    for mask_path in masks:
-        labels = imageio.imread(mask_path)
+def check_labels(images, labels):
+    for label_path in labels:
+        labels = imageio.imread(label_path)
         n_labels = len(np.unique(labels))
-        print(mask_path, n_labels)
+        print(label_path, n_labels)
 
 
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--root", "-i", help="The root folder with the annotated training crops.",
+        default=ROOT_CP,
+    )
+    parser.add_argument("--check_labels", "-l", action="store_true")
+    args = parser.parse_args()
+    root = args.root
+
+    images, labels = get_image_and_label_paths(root)
+
+    check_visually(images, labels)
+    if args.check_labels:
+        check_labels(images, labels)
+
+
 if __name__ == "__main__":
-    check_visually(True)
-    # check_labels()
+    main()
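Based on the argument parser above, typical invocations of the rewritten script would be:

```
python check_training_data.py -i /path/to/data        # step through crops in napari
python check_training_data.py -i /path/to/data -l     # additionally print label counts
```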
124 changes: 78 additions & 46 deletions scripts/training/train_distance_unet.py
@@ -1,54 +1,72 @@
 import argparse
 import os
+from datetime import datetime
 from glob import glob
 
 import torch_em
 
 from torch_em.model import UNet3d
 
-# DATA_ROOT = "/home/pape/Work/data/moser/lightsheet"
-DATA_ROOT = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
+ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
 
 
+def get_image_and_label_paths(root):
+    exclude_names = ["annotations", "cp_masks"]
+    all_image_paths = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True))
+    all_image_paths = [
+        path for path in all_image_paths if not any(exclude in path for exclude in exclude_names)
+    ]
+
+    image_paths, label_paths = [], []
+    label_extensions = ["_annotations.tif"]
+    for path in all_image_paths:
+        folder, fname = os.path.split(path)
+        fname = os.path.splitext(fname)[0]
+        label_path = None
+        for ext in label_extensions:
+            candidate_label_path = os.path.join(folder, f"{fname}{ext}")
+            if os.path.exists(candidate_label_path):
+                label_path = candidate_label_path
+                break
+
+        if label_path is None:
+            print("Did not find annotations for", path)
+            print("This image will not be used for training.")
+        else:
+            image_paths.append(path)
+            label_paths.append(label_path)
+
+    assert len(image_paths) == len(label_paths)
+    return image_paths, label_paths
+
+
-def get_paths(image_paths, label_paths, split, filter_empty):
+def select_paths(image_paths, label_paths, split, filter_empty):
     if filter_empty:
         image_paths = [imp for imp in image_paths if "empty" not in imp]
         label_paths = [imp for imp in label_paths if "empty" not in imp]
     assert len(image_paths) == len(label_paths)
 
     n_files = len(image_paths)
 
-    train_fraction = 0.8
-    val_fraction = 0.1
+    train_fraction = 0.85
 
     n_train = int(train_fraction * n_files)
-    n_val = int(val_fraction * n_files)
     if split == "train":
         image_paths = image_paths[:n_train]
         label_paths = label_paths[:n_train]
 
     elif split == "val":
-        image_paths = image_paths[n_train:(n_train + n_val)]
-        label_paths = label_paths[n_train:(n_train + n_val)]
+        image_paths = image_paths[n_train:]
+        label_paths = label_paths[n_train:]
 
     return image_paths, label_paths
 
 
-def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"]):
-    image_paths, label_paths = [], []
-
-    if "default" in train_on:
-        all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images", "*.tif")))
-        all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks", "*.tif")))
-        this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty)
-        image_paths.extend(this_image_paths)
-        label_paths.extend(this_label_paths)
-
-    if "downsampled" in train_on:
-        all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images_s2", "*.tif")))
-        all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks_s2", "*.tif")))
-        this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty)
-        image_paths.extend(this_image_paths)
-        label_paths.extend(this_label_paths)
+def get_loader(root, split, patch_shape, batch_size, filter_empty):
+    image_paths, label_paths = get_image_and_label_paths(root)
+    this_image_paths, this_label_paths = select_paths(image_paths, label_paths, split, filter_empty)
+
+    assert len(this_image_paths) == len(this_label_paths)
+    assert len(this_image_paths) > 0
 
     label_transform = torch_em.transform.label.PerObjectDistanceTransform(
         distances=True, boundary_distances=True, foreground=True,
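To make the new split logic in `select_paths` concrete: with the train fraction raised to 0.85 and the separate validation fraction dropped, everything after the train slice now goes to validation. A toy example (the crop count is hypothetical):

```python
# Hypothetical count of annotated crops, to illustrate the 0.85 split above.
n_files = 20
n_train = int(0.85 * n_files)   # 17 crops for "train" (indices 0..16)
n_val = n_files - n_train       # 3 crops for "val" (indices 17..19)
```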
@@ -59,7 +77,7 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
     elif split == "val":
         n_samples = 20 * batch_size
 
-    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.95)
+    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
     loader = torch_em.default_segmentation_loader(
         raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None,
         batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
@@ -69,26 +87,45 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
     return loader
 
 
-def main(check_loaders=False):
-    # Parameters for training:
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--root", "-i", help="The root folder with the annotated training crops.",
+        default=ROOT_CLUSTER,
+    )
+    parser.add_argument(
+        "--batch_size", "-b", help="The batch size for training. Set to 8 by default. "
+        "You may need to choose a smaller batch size to train on your GPU.",
+        default=8, type=int,
+    )
+    parser.add_argument(
+        "--check_loaders", "-l", action="store_true",
+        help="Visualize the data loader output instead of starting a training run."
+    )
+    parser.add_argument(
+        "--filter_empty", "-f", action="store_true",
+        help="Whether to exclude blocks with empty annotations from the training process."
+    )
+    parser.add_argument(
+        "--name", help="Optional name for the model to be trained. If not given the current date is used."
+    )
+    args = parser.parse_args()
+    root = args.root
+    batch_size = args.batch_size
+    check_loaders = args.check_loaders
+    filter_empty = args.filter_empty
+    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
 
+    # Parameters for training on A100.
     n_iterations = 1e5
-    batch_size = 8
-    filter_empty = False
-    train_on = ["downsampled"]
-    # train_on = ["downsampled", "default"]
 
-    patch_shape = (32, 128, 128) if "downsampled" in train_on else (64, 128, 128)
+    patch_shape = (64, 128, 128)
 
     # The U-Net.
     model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")
 
     # Create the training loader with train and val set.
-    train_loader = get_loader(
-        "train", patch_shape, batch_size, filter_empty=filter_empty, train_on=train_on
-    )
-    val_loader = get_loader(
-        "val", patch_shape, batch_size, filter_empty=filter_empty, train_on=train_on
-    )
+    train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)
+    val_loader = get_loader(root, "val", patch_shape, batch_size, filter_empty=filter_empty)
 
     if check_loaders:
         from torch_em.util.debug import check_loader
@@ -99,12 +136,7 @@ def main(check_loaders=False):
     loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True)
 
     # Create the trainer.
-    name = "cochlea_distance_unet"
-    if filter_empty:
-        name += "-filter-empty"
-    if train_on == ["downsampled"]:
-        name += "-train-downsampled"
-
+    name = f"cochlea_distance_unet_{run_name}"
     trainer = torch_em.default_segmentation_trainer(
         name=name,
         model=model,
@@ -123,4 +155,4 @@


 if __name__ == "__main__":
-    main(check_loaders=False)
+    main()
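The trainer writes its checkpoints to `checkpoints/cochlea_distance_unet_<run_name>`. A sketch of loading the result for inference, assuming torch_em's usual checkpoint layout with the weights stored under a `"model_state"` key (an assumption about torch_em internals, not something this commit shows):

```python
import torch
from torch_em.model import UNet3d

# Rebuild the architecture exactly as in train_distance_unet.py.
model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")

# Assumed checkpoint layout: a dict with the weights under "model_state".
ckpt = torch.load(
    "checkpoints/cochlea_distance_unet_20250115/best.pt",
    map_location="cpu", weights_only=False,
)
model.load_state_dict(ckpt["model_state"])
model.eval()

with torch.no_grad():
    x = torch.randn(1, 1, 64, 128, 128)  # dummy volume with the training patch shape
    out = model(x)[0]  # 3 channels: foreground probabilities and the two distance maps
```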
