Skip to content

Commit

Permalink
callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordan-Pierce committed Jan 3, 2025
1 parent ae3b957 commit dd18482
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 14 deletions.
18 changes: 14 additions & 4 deletions tests/test_yolo_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

"""Tests for `yolo_tiler` package."""

import yolo_tiler
from yolo_tiler import YoloTiler, TileConfig, TileProgress


def progress_callback(progress: TileProgress):
print(f"Processing {progress.current_image} in {progress.current_set} set: "
f"tile {progress.current_tile}/{progress.total_tiles}")


src = "./tests/segmentation"
dst = "./tests/segmentation_tiled"

config = yolo_tiler.TileConfig(
config = TileConfig(
slice_wh=(640, 480), # Slice width and height
overlap_wh=(0.1, 0.1), # Overlap width and height (10% overlap in this example, or 64x48 pixels)
ext=".png",
Expand All @@ -18,10 +24,14 @@
margins=(10, 10, 10, 10), # Left, top, right, bottom
)

tiler = yolo_tiler.YoloTiler(

# Create tiler with callback
tiler = YoloTiler(
source=src,
target=dst,
config=config,
callback=progress_callback
)

tiler.run()
# Run tiling process
tiler.run()
2 changes: 1 addition & 1 deletion yolo_tiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Top-level package for yolo-tiling."""

from yolo_tiler.yolo_tiler import YoloTiler, TileConfig
from yolo_tiler.yolo_tiler import YoloTiler, TileConfig, TileProgress

__version__ = "0.0.3"
__author__ = "Jordan Pierce"
Expand Down
75 changes: 66 additions & 9 deletions yolo_tiler/yolo_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import logging
import math
import random
from dataclasses import dataclass
from pathlib import Path
from shutil import copyfile
from typing import List, Tuple, Optional, Union, Generator

from dataclasses import dataclass
from typing import List, Tuple, Optional, Union, Generator, Callable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -86,6 +87,15 @@ def get_effective_area(self, image_width: int, image_height: int) -> Tuple[int,
return x_min, y_min, x_max, y_max


@dataclass
class TileProgress:
"""Data class to track tiling progress"""
current_tile: int
total_tiles: int
current_set: str
current_image: str


class YoloTiler:
"""
A class to tile YOLO dataset images and their corresponding annotations.
Expand All @@ -95,18 +105,21 @@ class YoloTiler:
def __init__(self,
source: Union[str, Path],
target: Union[str, Path],
config: TileConfig):
config: TileConfig,
callback: Optional[Callable[[TileProgress], None]] = None):
"""
Initialize YoloTiler with source and target directories.
Args:
source: Source directory containing YOLO dataset
target: Target directory for sliced dataset
config: TileConfig object containing tiling parameters
callback: Optional callback function to report progress
"""
self.source = Path(source)
self.target = Path(target)
self.config = config
self.callback = callback
self.logger = self._setup_logger()

self.subfolders = ['train/',
Expand Down Expand Up @@ -145,6 +158,22 @@ def _validate_yolo_structure(self, folder: Path) -> None:
if not (folder / subfolder / 'labels').exists():
raise ValueError(f"Required folder {folder / subfolder / 'labels'} does not exist")

def _count_total_tiles(self, image_size: Tuple[int, int]) -> int:
"""Count total number of tiles for an image"""
img_w, img_h = image_size
slice_w, slice_h = self.config.slice_wh
overlap_w, overlap_h = self.config.overlap_wh

# Calculate effective step sizes
step_w = self._calculate_step_size(slice_w, overlap_w)
step_h = self._calculate_step_size(slice_h, overlap_h)

# Generate tile positions using numpy for faster calculations
x_coords = self._generate_tile_positions(img_w, step_w)
y_coords = self._generate_tile_positions(img_h, step_h)

return len(x_coords) * len(y_coords)

def _calculate_step_size(self, slice_size: int, overlap: Union[int, float]) -> int:
"""Calculate effective step size for tiling."""
if isinstance(overlap, float):
Expand Down Expand Up @@ -178,10 +207,6 @@ def _calculate_tile_positions(self,
step_w = self._calculate_step_size(slice_w, overlap_w)
step_h = self._calculate_step_size(slice_h, overlap_h)

# Calculate number of tiles in each dimension
num_tiles_w = self._calculate_num_tiles(img_w, step_w)
num_tiles_h = self._calculate_num_tiles(img_h, step_h)

# Generate tile positions using numpy for faster calculations
x_coords = self._generate_tile_positions(img_w, step_w)
y_coords = self._generate_tile_positions(img_h, step_h)
Expand Down Expand Up @@ -329,6 +354,9 @@ def tile_image(self, image_path: Path, label_path: Path, folder: str) -> None:
effective_width = x_max - x_min
effective_height = y_max - y_min

# Calculate total tiles for progress tracking
total_tiles = self._count_total_tiles((effective_width, effective_height))

# Process annotations
boxes = []
with open(label_path) as f:
Expand Down Expand Up @@ -360,6 +388,17 @@ def tile_image(self, image_path: Path, label_path: Path, folder: str) -> None:
# Process each tile
for tile_idx, (x1, y1, x2, y2) in enumerate(self._calculate_tile_positions((effective_width,
effective_height))):

# Report progress if callback is provided
if self.callback:
progress = TileProgress(
current_tile=tile_idx + 1,
total_tiles=total_tiles,
current_set=folder.rstrip('/'),
current_image=image_path.name
)
self.callback(progress)

window = Window(x1 + x_min, y1 + y_min, x2 - x1, y2 - y1)
tile_data = src.read(window=window)
tile_polygon = Polygon([(x1 + x_min, y1 + y_min),
Expand Down Expand Up @@ -477,14 +516,32 @@ def split_data(self) -> None:

valid_set = combined[num_train:num_train + num_valid]
test_set = combined[num_train + num_valid:]
num_test = len(test_set)

# Move files to valid folder
for image_path, label_path in valid_set:
for tile_idx, (image_path, label_path) in enumerate(valid_set):
self._move_split_data(image_path, label_path, 'valid')

if self.callback:
progress = TileProgress(
current_tile=tile_idx+1,
total_tiles=num_valid,
current_set='valid',
current_image=image_path.name
)
self.callback(progress)

# Move files to test folder
for image_path, label_path in test_set:
for tile_idx, (image_path, label_path) in enumerate(test_set):
self._move_split_data(image_path, label_path, 'test')
if self.callback:
progress = TileProgress(
current_tile=tile_idx+1,
total_tiles=num_test,
current_set='test',
current_image=image_path.name
)
self.callback(progress)

def _move_split_data(self, image_path: Path, label_path: Path, folder: str) -> None:
"""
Expand Down

0 comments on commit dd18482

Please sign in to comment.