Skip to content

Commit

Permalink
only start satlas jobs that weren't already completed
Browse files Browse the repository at this point in the history
  • Loading branch information
favyen2 committed Jan 17, 2025
1 parent 31576ab commit 298de6a
Showing 1 changed file with 82 additions and 39 deletions.
121 changes: 82 additions & 39 deletions rslp/satlas/write_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import random
from collections.abc import Generator
from datetime import datetime, timedelta, timezone

import shapely
Expand All @@ -11,10 +12,11 @@
from rslearn.const import WGS84_PROJECTION
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
from rslearn.utils.get_utm_ups_crs import get_proj_bounds
from upath import UPath

from rslp.log_utils import get_logger

from .predict_pipeline import Application, PredictTaskArgs
from .predict_pipeline import Application, PredictTaskArgs, get_output_fname

logger = get_logger(__name__)

Expand Down Expand Up @@ -54,6 +56,43 @@ def __init__(
self.time_range = time_range
self.out_path = out_path

def get_output_fname(self) -> UPath:
    """Return the output filename that this task will use.

    The filename format itself is not decided here; it is defined by
    get_output_fname in predict_pipeline.py, which this method calls with the
    task's application, output path, projection and bounds.
    """
    output_fname = get_output_fname(
        self.application,
        self.out_path,
        self.projection,
        self.bounds,
    )
    return output_fname


def enumerate_tiles_in_zone(utm_zone: CRS) -> Generator[tuple[int, int], None, None]:
    """Enumerate the tiles in a UTM zone where outputs should be computed.

    All tiles are TILE_SIZE x TILE_SIZE, so each tile is identified only by its
    column and row along that grid.

    Args:
        utm_zone: the CRS which must correspond to a UTM EPSG.

    Returns:
        generator of (column, row) of the tiles that are needed. All rows of a
        column are produced before moving on to the next column.
    """
    # get_proj_bounds gives the bounds of the UTM zone in CRS units, so the box is
    # first wrapped in a geometry using a projection with x/y resolution of 1.
    unit_projection = Projection(utm_zone, 1, 1)
    zone_box = shapely.box(*get_proj_bounds(utm_zone))
    zone_geom_in_crs_units = STGeometry(unit_projection, zone_box, None)

    # Re-project to pixel units so that the needed tiles can be determined.
    pixel_projection = Projection(utm_zone, RESOLUTION, -RESOLUTION)
    zone_geom_in_pixels = zone_geom_in_crs_units.to_projection(pixel_projection)

    # Convert the resulting shape to an integer bbox.
    # NOTE(review): int() truncates toward zero rather than flooring, so negative
    # fractional bounds move toward zero -- confirm that is acceptable at zone edges.
    min_x, min_y, max_x, max_y = (
        int(value) for value in zone_geom_in_pixels.shp.bounds
    )

    # Both ends of each range are inclusive tile indices, hence the + 1.
    first_col = min_x // TILE_SIZE
    last_col = max_x // TILE_SIZE
    first_row = min_y // TILE_SIZE
    last_row = max_y // TILE_SIZE

    for col in range(first_col, last_col + 1):
        for row in range(first_row, last_row + 1):
            yield (col, row)


def get_jobs(
application: Application,
Expand All @@ -66,6 +105,8 @@ def get_jobs(
) -> list[list[str]]:
"""Get batches of tasks for Satlas prediction.
Tasks where outputs have already been computed are excluded.
Args:
application: which application to run.
time_range: the time range to run within. Must have timezone.
Expand All @@ -91,17 +132,10 @@ def get_jobs(

tasks: list[Task] = []
for utm_zone in tqdm.tqdm(utm_zones, desc="Enumerating tasks across UTM zones"):
# get_proj_bounds returns bounds in CRS units so we need to convert to pixel
# units.
crs_bbox = STGeometry(
Projection(utm_zone, 1, 1),
shapely.box(*get_proj_bounds(utm_zone)),
None,
)
projection = Projection(utm_zone, RESOLUTION, -RESOLUTION)
pixel_bbox = crs_bbox.to_projection(projection)
zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds)

# If the user provided WGS84 bounds, then we convert it to pixel coordinates so
# we can check each tile easily.
user_bounds_in_proj: PixelBounds | None = None
if wgs84_bounds is not None:
dst_geom = STGeometry(
Expand All @@ -114,42 +148,51 @@ def get_jobs(
int(dst_geom.shp.bounds[3]),
)

for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1):
for row in range(
zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1
):
if user_bounds_in_proj is not None:
# Check if this task intersects the bounds specified by the user.
if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]:
continue
if col * TILE_SIZE >= user_bounds_in_proj[2]:
continue
if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]:
continue
if row * TILE_SIZE >= user_bounds_in_proj[3]:
continue

tasks.append(
Task(
application=application,
projection=projection,
bounds=(
col * TILE_SIZE,
row * TILE_SIZE,
(col + 1) * TILE_SIZE,
(row + 1) * TILE_SIZE,
),
time_range=time_range,
out_path=out_path,
)
for col, row in enumerate_tiles_in_zone(utm_zone):
if user_bounds_in_proj is not None:
# Check if this task intersects the bounds specified by the user.
if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]:
continue
if col * TILE_SIZE >= user_bounds_in_proj[2]:
continue
if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]:
continue
if row * TILE_SIZE >= user_bounds_in_proj[3]:
continue

tasks.append(
Task(
application=application,
projection=projection,
bounds=(
col * TILE_SIZE,
row * TILE_SIZE,
(col + 1) * TILE_SIZE,
(row + 1) * TILE_SIZE,
),
time_range=time_range,
out_path=out_path,
)
)

logger.info("Got %d total tasks", len(tasks))

print(f"Got {len(tasks)} total tasks")
# Remove tasks where outputs are already computed.
existing_output_fnames = {out_fname.name for out_fname in UPath(out_path).iterdir()}
tasks = [
task
for task in tasks
if task.get_output_fname().name not in existing_output_fnames
]
logger.info("Got %d tasks that are uncompleted", len(tasks))

# Sample tasks down to user-provided count (max # tasks to run), if provided.
if count is not None and len(tasks) > count:
tasks = random.sample(tasks, count)
logger.info("Randomly sampled %d tasks", len(tasks))

# Convert tasks to jobs for use with rslp.common.worker.
# This is what will be written to the Pub/Sub topic.
jobs = []
for i in range(0, len(tasks), batch_size):
cur_tasks = tasks[i : i + batch_size]
Expand Down

0 comments on commit 298de6a

Please sign in to comment.