Skip to content

Commit

Permalink
Update TileProgress class and related code
Browse files Browse the repository at this point in the history
Update the `TileProgress` class and related code to include new attributes.

* **TileProgress Class:**
  - Add attributes: `current_set_name`, `current_image_name`, `current_image_idx`, `total_images`, `current_tile_idx`.
  - Remove attributes: `current_tile`, `current_set`, `current_image`.

* **YoloTiler Class:**
  - Update `_tqdm_callback` method to handle new attributes.
  - Modify progress updates to use `current_set_name` and `current_tile_idx`.
  - Update progress callback invocations to include new attributes with placeholders for `current_image_idx` and `total_images`.

* **Tests:**
  - Update `progress_callback` function in `tests/test_yolo_tiler.py` to handle new attributes.
  • Loading branch information
Jordan-Pierce committed Jan 14, 2025
1 parent 9e4ee58 commit 24b8d98
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
4 changes: 2 additions & 2 deletions tests/test_yolo_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


def progress_callback(progress: TileProgress):
print(f"Processing {progress.current_image} in {progress.current_set} set: "
f"tile {progress.current_tile}/{progress.total_tiles}")
print(f"Processing {progress.current_image_name} in {progress.current_set_name} set: "
f"tile {progress.current_tile_idx}/{progress.total_tiles}")


src = "./tests/segmentation"
Expand Down
56 changes: 33 additions & 23 deletions yolo_tiler/yolo_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ def get_effective_area(self, image_width: int, image_height: int) -> Tuple[int,
@dataclass
class TileProgress:
"""Data class to track tiling progress"""
current_tile: int
current_set_name: str
current_image_name: str
current_image_idx: int
total_images: int
current_tile_idx: int
total_tiles: int
current_set: str
current_image: str


class YoloTiler:
Expand Down Expand Up @@ -164,21 +166,21 @@ def _tqdm_callback(self, progress: TileProgress):
progress: TileProgress object containing current progress
"""
if progress.current_set not in self._progress_bars:
self._progress_bars[progress.current_set] = tqdm(
if progress.current_set_name not in self._progress_bars:
self._progress_bars[progress.current_set_name] = tqdm(
total=progress.total_tiles,
desc=progress.current_set,
desc=progress.current_set_name,
unit='items'
)

# Update progress
self._progress_bars[progress.current_set].n = progress.current_tile
self._progress_bars[progress.current_set].refresh()
self._progress_bars[progress.current_set_name].n = progress.current_tile_idx
self._progress_bars[progress.current_set_name].refresh()

# Close and cleanup if task is complete
if progress.current_tile >= progress.total_tiles:
self._progress_bars[progress.current_set].close()
del self._progress_bars[progress.current_set]
if progress.current_tile_idx >= progress.total_tiles:
self._progress_bars[progress.current_set_name].close()
del self._progress_bars[progress.current_set_name]

def _setup_logger(self) -> logging.Logger:
"""Configure logging for the tiler"""
Expand Down Expand Up @@ -524,10 +526,12 @@ def clean_geometry(geom: Polygon) -> Polygon:

if self.progress_callback:
progress = TileProgress(
current_tile=tile_idx + 1,
current_tile_idx=tile_idx + 1,
total_tiles=total_tiles,
current_set=folder.rstrip('/'),
current_image=image_path.name
current_set_name=folder.rstrip('/'),
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
)
self.progress_callback(progress)

Expand Down Expand Up @@ -668,10 +672,12 @@ def split_data(self) -> None:

if self.progress_callback:
progress = TileProgress(
current_tile=tile_idx + 1,
current_tile_idx=tile_idx + 1,
total_tiles=num_valid,
current_set='valid',
current_image=image_path.name
current_set_name='valid',
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
)
self.progress_callback(progress)

Expand All @@ -680,10 +686,12 @@ def split_data(self) -> None:
self._move_split_data(image_path, label_path, 'test')
if self.progress_callback:
progress = TileProgress(
current_tile=tile_idx + 1,
current_tile_idx=tile_idx + 1,
total_tiles=num_test,
current_set='test',
current_image=image_path.name
current_set_name='test',
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
)
self.progress_callback(progress)

Expand Down Expand Up @@ -778,10 +786,12 @@ def visualize_random_samples(self) -> None:

if self.progress_callback:
progress = TileProgress(
current_tile=tile_idx + 1,
current_tile_idx=tile_idx + 1,
total_tiles=num_samples,
current_set='rendered',
current_image=image_path.name
current_set_name='rendered',
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
)
self.progress_callback(progress)

Expand Down

0 comments on commit 24b8d98

Please sign in to comment.