From d3c40f1f580c3e587efa02289cabf87c1159daf1 Mon Sep 17 00:00:00 2001 From: Jordan Pierce <115024024+Jordan-Pierce@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:49:02 -0500 Subject: [PATCH] Clamp coords, clip polygons, render samples --- requirements.txt | 1 + tests/test_yolo_tiler.py | 7 +- yolo_tiler/yolo_tiler.py | 320 +++++++++++++++++++++++++++++++-------- 3 files changed, 258 insertions(+), 70 deletions(-) diff --git a/requirements.txt b/requirements.txt index d0a3384..682d86d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy pandas Shapely rasterio +opencv-python diff --git a/tests/test_yolo_tiler.py b/tests/test_yolo_tiler.py index c343c02..30d0ce7 100644 --- a/tests/test_yolo_tiler.py +++ b/tests/test_yolo_tiler.py @@ -15,13 +15,13 @@ def progress_callback(progress: TileProgress): config = TileConfig( slice_wh=(640, 480), # Slice width and height - overlap_wh=(0.1, 0.1), # Overlap width and height (10% overlap in this example, or 64x48 pixels) + overlap_wh=(0.0, 0.0), # Overlap width and height (10% overlap in this example, or 64x48 pixels) ext=".png", annotation_type="instance_segmentation", train_ratio=0.7, valid_ratio=0.2, test_ratio=0.1, - margins=(10, 10, 10, 10), # Left, top, right, bottom + margins=(0, 0, 0, 0), # Left, top, right, bottom ) @@ -30,8 +30,9 @@ def progress_callback(progress: TileProgress): source=src, target=dst, config=config, + num_viz_samples=15, callback=progress_callback ) # Run tiling process -tiler.run() \ No newline at end of file +tiler.run() diff --git a/yolo_tiler/yolo_tiler.py b/yolo_tiler/yolo_tiler.py index 6556625..cf35b35 100644 --- a/yolo_tiler/yolo_tiler.py +++ b/yolo_tiler/yolo_tiler.py @@ -1,16 +1,19 @@ -import warnings import logging import math import random +import warnings +from dataclasses import dataclass from pathlib import Path from shutil import copyfile - -from dataclasses import dataclass from typing import List, Tuple, Optional, Union, Generator, Callable +import cv2 +import matplotlib.patches as patches +import matplotlib.pyplot as plt import numpy as np import pandas as pd import rasterio +from matplotlib.patches import Polygon as MplPolygon from rasterio.windows import Window from shapely.geometry import Polygon, MultiPolygon @@ -106,6 +109,7 @@ def __init__(self, source: Union[str, Path], target: Union[str, Path], config: TileConfig, + num_viz_samples: int = 0, callback: Optional[Callable[[TileProgress], None]] = None): """ Initialize YoloTiler with source and target directories. @@ -114,6 +118,7 @@ def __init__(self, source: Source directory containing YOLO dataset target: Target directory for sliced dataset config: TileConfig object containing tiling parameters + num_viz_samples: Number of random samples to visualize from train set callback: Optional callback function to report progress """ self.source = Path(source) @@ -121,10 +126,13 @@ def __init__(self, self.config = config self.callback = callback self.logger = self._setup_logger() + self.num_viz_samples = num_viz_samples + self.subfolders = ['train/', 'valid/', 'test/'] - self.subfolders = ['train/', - 'valid/', - 'test/'] + # Create rendered directory if visualization is requested + if self.num_viz_samples > 0: + self.render_dir = self.target / 'rendered' + self.render_dir.mkdir(parents=True, exist_ok=True) def _setup_logger(self) -> logging.Logger: """Configure logging for the tiler""" @@ -191,13 +199,13 @@ def _generate_tile_positions(self, img_size: int, step_size: int) -> np.ndarray: def _calculate_tile_positions(self, image_size: Tuple[int, int]) -> Generator[Tuple[int, int, int, int], None, None]: """ - Calculate tile positions with overlap. + Calculate tile positions with overlap, respecting margins. Args: - image_size: (width, height) of the image + image_size: (width, height) of the image after margins applied Yields: - Tuples of (x1, y1, x2, y2) for each tile + Tuples of (x1, y1, x2, y2) for each tile within effective area """ img_w, img_h = image_size slice_w, slice_h = self.config.slice_wh @@ -208,6 +216,7 @@ def _calculate_tile_positions(self, step_h = self._calculate_step_size(slice_h, overlap_h) # Generate tile positions using numpy for faster calculations + # Use effective dimensions (after margins) x_coords = self._generate_tile_positions(img_w, step_w) y_coords = self._generate_tile_positions(img_h, step_h) @@ -272,51 +281,52 @@ def _process_polygon(self, poly: Polygon) -> List[List[Tuple[float, float]]]: return result def _process_intersection(self, intersection: Union[Polygon, MultiPolygon]) -> List[List[Tuple[float, float]]]: - """ - Process intersection geometry with improved quality. + """Process intersection geometry with proper polygon closure.""" + def process_single_polygon(poly: Polygon) -> List[List[Tuple[float, float]]]: + # Ensure proper closure of exterior ring + exterior_coords = list(poly.exterior.coords) + if exterior_coords[0] != exterior_coords[-1]: + exterior_coords.append(exterior_coords[0]) - Args: - intersection: Shapely geometry object + result = [exterior_coords[:-1]] # Remove duplicate closing point + + # Process holes with proper closure + for interior in poly.interiors: + interior_coords = list(interior.coords) + if interior_coords[0] != interior_coords[-1]: + interior_coords.append(interior_coords[0]) + result.append(interior_coords[:-1]) + + return result - Returns: - List of coordinate lists (exterior + holes) - """ if isinstance(intersection, Polygon): - return self._process_polygon(intersection) + return process_single_polygon(intersection) else: # MultiPolygon all_coords = [] - # Process all polygons, not just the largest for poly in intersection.geoms: - all_coords.extend(self._process_polygon(poly)) + all_coords.extend(process_single_polygon(poly)) return all_coords - def _normalize_coordinates(self, - coord_lists: List[List[Tuple[float, float]]], - tile_bounds: Tuple[int, int, int, int]) -> str: - """ - Normalize coordinates to [0,1] range relative to tile bounds. - - Args: - coord_lists: List of coordinate lists (exterior + holes) - tile_bounds: (x1, y1, x2, y2) of tile bounds - - Returns: - Space-separated string of normalized coordinates - """ + def _normalize_coordinates(self, coord_lists: List[List[Tuple[float, float]]], + tile_bounds: Tuple[int, int, int, int]) -> str: + """Normalize coordinates with proper polygon closure.""" x1, y1, x2, y2 = tile_bounds tile_width = x2 - x1 tile_height = y2 - y1 normalized_parts = [] for coords in coord_lists: + # Ensure proper closure + if coords[0] != coords[-1]: + coords = coords + [coords[0]] + normalized = [] for x, y in coords: - norm_x = (x - x1) / tile_width - norm_y = (y - y1) / tile_height + norm_x = max(0, min(1, (x - x1) / tile_width)) # Clamp to [0,1] + norm_y = max(0, min(1, (y - y1) / tile_height)) normalized.append(f"{norm_x:.6f} {norm_y:.6f}") normalized_parts.append(normalized) - # Join all parts with special separator return " ".join([" ".join(part) for part in normalized_parts]) def _save_labels(self, labels: List, path: Path, is_segmentation: bool) -> None: @@ -338,22 +348,25 @@ def _save_labels(self, labels: List, path: Path, is_segmentation: bool) -> None: def tile_image(self, image_path: Path, label_path: Path, folder: str) -> None: """ - Tile an image and its corresponding labels. - - Args: - image_path: Path to image file - label_path: Path to label file - folder: Subfolder name (train, valid, test) + Tile an image and its corresponding labels, properly handling margins. """ # Read image and labels with rasterio.open(image_path) as src: width, height = src.width, src.height - # Get effective area + # Get effective area (area after margins applied) x_min, y_min, x_max, y_max = self.config.get_effective_area(width, height) effective_width = x_max - x_min effective_height = y_max - y_min + # Create polygon representing effective area (excludes margins) + effective_area = Polygon([ + (x_min, y_min), + (x_max, y_min), + (x_max, y_max), + (x_min, y_max) + ]) + # Calculate total tiles for progress tracking total_tiles = self._count_total_tiles((effective_width, effective_height)) @@ -365,31 +378,57 @@ def tile_image(self, image_path: Path, label_path: Path, folder: str) -> None: class_id = int(parts[0]) if self.config.annotation_type == "object_detection": - # Parse fixed format: class x y w h - x_center = float(parts[1]) * width - y_center = float(parts[2]) * height - box_w = float(parts[3]) * width - box_h = float(parts[4]) * height + # Parse normalized coordinates + x_center_norm = float(parts[1]) + y_center_norm = float(parts[2]) + box_w_norm = float(parts[3]) + box_h_norm = float(parts[4]) + + # Convert to absolute coordinates + x_center = x_center_norm * width + y_center = y_center_norm * height + box_w = box_w_norm * width + box_h = box_h_norm * height x1 = x_center - box_w / 2 y1 = y_center - box_h / 2 x2 = x_center + box_w / 2 y2 = y_center + box_h / 2 - boxes.append((class_id, Polygon([(x1, y1), (x2, y1), (x2, y2), (x1, y2)]))) + box_polygon = Polygon([(x1, y1), (x2, y1), (x2, y2), (x1, y2)]) + + # Only include if box intersects with effective area + if box_polygon.intersects(effective_area): + # Clip box to effective area + clipped_box = box_polygon.intersection(effective_area) + if not clipped_box.is_empty: + boxes.append((class_id, clipped_box)) else: - # Parse variable length format: class x1 y1 x2 y2 ... + # Instance segmentation points = [] for i in range(1, len(parts), 2): - x = float(parts[i]) * width - y = float(parts[i + 1]) * height + x_norm = float(parts[i]) + y_norm = float(parts[i + 1]) + x = x_norm * width + y = y_norm * height points.append((x, y)) - boxes.append((class_id, Polygon(points))) - # Process each tile - for tile_idx, (x1, y1, x2, y2) in enumerate(self._calculate_tile_positions((effective_width, - effective_height))): + polygon = Polygon(points) + # Only include if polygon intersects with effective area + if polygon.intersects(effective_area): + # Clip polygon to effective area + clipped_polygon = polygon.intersection(effective_area) + if not clipped_polygon.is_empty: + boxes.append((class_id, clipped_polygon)) + + # Process each tile within effective area + for tile_idx, (x1, y1, x2, y2) in enumerate( + self._calculate_tile_positions((effective_width, effective_height))): + # Convert tile coordinates to absolute image coordinates + abs_x1 = x1 + x_min + abs_y1 = y1 + y_min + abs_x2 = x2 + x_min + abs_y2 = y2 + y_min - # Report progress if callback is provided if self.callback: progress = TileProgress( current_tile=tile_idx + 1, @@ -399,12 +438,18 @@ def tile_image(self, image_path: Path, label_path: Path, folder: str) -> None: ) self.callback(progress) - window = Window(x1 + x_min, y1 + y_min, x2 - x1, y2 - y1) + # Extract tile data + window = Window(abs_x1, abs_y1, abs_x2 - abs_x1, abs_y2 - abs_y1) tile_data = src.read(window=window) - tile_polygon = Polygon([(x1 + x_min, y1 + y_min), - (x2 + x_min, y1 + y_min), - (x2 + x_min, y2 + y_min), - (x1 + x_min, y2 + y_min)]) + + # Create polygon for current tile + tile_polygon = Polygon([ + (abs_x1, abs_y1), + (abs_x2, abs_y1), + (abs_x2, abs_y2), + (abs_x1, abs_y2) + ]) + tile_labels = [] # Process annotations for this tile @@ -417,15 +462,21 @@ def tile_image(self, image_path: Path, label_path: Path, folder: str) -> None: bbox = intersection.envelope center = bbox.centroid bbox_coords = bbox.exterior.coords.xy - new_width = (max(bbox_coords[0]) - min(bbox_coords[0])) / (x2 - x1) - new_height = (max(bbox_coords[1]) - min(bbox_coords[1])) / (y2 - y1) - new_x = (center.x - x1) / (x2 - x1) - new_y = (center.y - y1) / (y2 - y1) + + # Normalize relative to tile dimensions + tile_width = abs_x2 - abs_x1 + tile_height = abs_y2 - abs_y1 + + new_width = (max(bbox_coords[0]) - min(bbox_coords[0])) / tile_width + new_height = (max(bbox_coords[1]) - min(bbox_coords[1])) / tile_height + new_x = (center.x - abs_x1) / tile_width + new_y = (center.y - abs_y1) / tile_height + tile_labels.append([box_class, new_x, new_y, new_width, new_height]) else: - # Handle instance segmentation with improved processing + # Handle instance segmentation coord_lists = self._process_intersection(intersection) - normalized = self._normalize_coordinates(coord_lists, (x1, y1, x2, y2)) + normalized = self._normalize_coordinates(coord_lists, (abs_x1, abs_y1, abs_x2, abs_y2)) tile_labels.append([box_class, normalized]) # Save tile image and labels @@ -603,6 +654,135 @@ def _copy_data_yaml(self) -> None: else: self.logger.warning('data.yaml not found in source directory') + def visualize_random_samples(self) -> None: + """ + Visualize random samples from the train folder with their annotations. + """ + if self.num_viz_samples <= 0: + return + + # Get all image paths from train folder + train_image_dir = self.target / 'train' / 'images' + train_label_dir = self.target / 'train' / 'labels' + + image_paths = list(train_image_dir.glob(f'*{self.config.ext}')) + + if not image_paths: + self.logger.warning("No images found in train folder for visualization") + return + + # Select random samples + num_samples = min(self.num_viz_samples, len(image_paths)) + selected_images = random.sample(image_paths, num_samples) + + # Process each selected image + for tile_idx, image_path in enumerate(selected_images): + label_path = train_label_dir / f"{image_path.stem}.txt" + + if not label_path.exists(): + self.logger.warning(f"Label file not found for {image_path.name}") + continue + + if self.callback: + progress = TileProgress( + current_tile=tile_idx+1, + total_tiles=num_samples, + current_set='rendered', + current_image=image_path.name + ) + self.callback(progress) + + self._render_single_sample(image_path, label_path, tile_idx+1) + + def _render_single_sample(self, image_path: Path, label_path: Path, idx: int) -> None: + """ + Render a single sample with its annotations. + + Args: + image_path: Path to the image file + label_path: Path to the label file + idx: Index for the output filename + """ + # Read image using OpenCV + img = cv2.imread(str(image_path)) + if img is None: + self.logger.warning(f"Could not read image: {image_path}") + return + + # Convert BGR to RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + height, width = img.shape[:2] + + # Create figure and axis + fig, ax = plt.subplots(1) + ax.imshow(img) + + # Random colors for different classes + np.random.seed(42) # For consistent colors + colors = np.random.rand(100, 3) # Support up to 100 classes + + # Read and parse labels + with open(label_path) as f: + for line in f: + parts = line.strip().split() + class_id = int(parts[0]) + color = colors[class_id % len(colors)] + + if self.config.annotation_type == "object_detection": + # Parse bounding box + x_center = float(parts[1]) * width + y_center = float(parts[2]) * height + box_w = float(parts[3]) * width + box_h = float(parts[4]) * height + + # Calculate box coordinates + x = x_center - box_w / 2 + y = y_center - box_h / 2 + + # Create rectangle patch with transparency + rect = patches.Rectangle( + (x, y), + box_w, + box_h, + linewidth=2, + edgecolor=color, + facecolor=color, + alpha=0.3 # Add transparency + ) + ax.add_patch(rect) + + else: # instance segmentation + # Parse polygon coordinates + coords = [] + for i in range(1, len(parts), 2): + x = float(parts[i]) * width + y = float(parts[i + 1]) * height + coords.append([x, y]) + + # Create polygon patch with transparency + polygon = MplPolygon( + coords, + facecolor=color, + edgecolor=color, + linewidth=2, + alpha=0.3 # Add transparency + ) + ax.add_patch(polygon) + + # Remove axes + ax.axis('off') + + # Adjust layout + plt.tight_layout() + + # Save the visualization + output_path = self.render_dir / f"sample_{idx:03d}.jpg" + plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=300) + plt.close() + + self.logger.info(f"Saved visualization to {output_path}") + def run(self) -> None: """Run the complete tiling process""" try: @@ -621,6 +801,12 @@ def run(self) -> None: # Copy data.yaml self._copy_data_yaml() + # Generate visualizations if requested + if self.num_viz_samples > 0: + self.logger.info(f'Generating {self.num_viz_samples} visualization samples...') + self.visualize_random_samples() + self.logger.info('Visualization generation completed') + except Exception as e: self.logger.error(f'Error during tiling process: {str(e)}') raise