Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Feb 13, 2025
1 parent c7c01f0 commit 92b9e39
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 33 deletions.
11 changes: 0 additions & 11 deletions shortfin/python/shortfin_apps/sd/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ def __init__(
| list[sfnp.device_array]
| None = None,
sample: sfnp.device_array | None = None,
prompt_embeds: sfnp.device_array | None = None,
text_embeds: sfnp.device_array | None = None,
timesteps: sfnp.device_array | None = None,
time_ids: sfnp.device_array | None = None,
denoised_latents: sfnp.device_array | None = None,
image_array: sfnp.device_array | None = None,
):
super().__init__()
Expand All @@ -83,14 +78,8 @@ def __init__(
self.input_ids = input_ids

# Denoise phase.
self.prompt_embeds = prompt_embeds
self.text_embeds = text_embeds
self.sample = sample
self.steps = steps
self.steps_arr = None
self.timesteps = timesteps
self.time_ids = time_ids
self.guidance_scale = guidance_scale

# Decode phase.
self.denoised_latents = denoised_latents
Expand Down
44 changes: 22 additions & 22 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def __init__(
worker_idx = idx * workers_per_device + i % workers_per_device
tgt_worker = self.workers[worker_idx]
raw_fiber = sysman.ls.create_fiber(tgt_worker, devices=[device])
fiber = self.equip_fiber(raw_fiber, len(self.fibers), worker_idx)
self.fibers.append(fiber)
self.idle_fibers.append(fiber)
meta_fiber = self.equip_fiber(raw_fiber, len(self.fibers), worker_idx)
self.meta_fibers.append(meta_fiber)
self.idle_meta_fibers.append(meta_fiber)
for idx in range(len(self.workers)):
self.inference_programs[idx] = {}
self.inference_functions[idx] = {}
Expand All @@ -99,8 +99,8 @@ def __init__(
self.batcher = BatcherProcess(self)

def equip_fiber(self, fiber, idx, worker_idx):
EquippedFiber = namedtuple(
"EquippedFiber", ["fiber", "idx", "worker_idx", "device", "command_buffers"]
MetaFiber = namedtuple(
"MetaFiber", ["fiber", "idx", "worker_idx", "device", "command_buffers"]
)
cbs_per_fiber = 1
cbs = []
Expand All @@ -110,7 +110,7 @@ def equip_fiber(self, fiber, idx, worker_idx):
initialize_command_buffer(fiber, self.model_params, batch_size)
)

return EquippedFiber(fiber, idx, worker_idx, fiber.device(0), cbs)
return MetaFiber(fiber, idx, worker_idx, fiber.device(0), cbs)

def get_worker_index(self, fiber):
if fiber not in self.fibers:
Expand Down Expand Up @@ -244,14 +244,14 @@ class BatcherProcess(sf.Process):
STROBE_LONG_DELAY = 1

def __init__(self, service: GenerateService):
super().__init__(fiber=service.fibers[0].fiber)
super().__init__(fiber=service.meta_fibers[0].fiber)
self.service = service
self.batcher_infeed = self.system.create_queue()
self.pending_requests: set[InferenceExecRequest] = set()
self.strobe_enabled = True
self.strobes: int = 0
self.ideal_batch_size: int = max(service.model_params.all_batch_sizes)
self.num_fibers = len(service.fibers)
self.num_fibers = len(service.meta_fibers)

def shutdown(self):
self.batcher_infeed.close()
Expand Down Expand Up @@ -297,16 +297,16 @@ async def board_flights(self):
batches = self.sort_batches()
for batch in batches.values():
# Assign the batch to the next idle fiber.
if len(self.service.idle_fibers) == 0:
if len(self.service.idle_meta_fibers) == 0:
logger.debug("Waiting for an idle fiber...")
return
fiber = self.service.idle_fibers.pop(0)
fiber = self.service.idle_meta_fibers.pop(0)
logger.debug(
f"Sending batch to fiber {fiber.idx} (worker {fiber.worker_idx})"
)
await self.board(batch["reqs"][0], fiber=fiber)
await self.board(batch["reqs"][0], meta_fiber=fiber)
if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.append(fiber)
self.service.idle_meta_fibers.append(fiber)

def sort_batches(self):
"""Files pending requests into sorted batches suitable for program invocations."""
Expand Down Expand Up @@ -339,8 +339,8 @@ def sort_batches(self):
}
return batches

async def board(self, request, fiber):
exec_process = InferenceExecutorProcess(self.service, fiber)
async def board(self, request, meta_fiber):
exec_process = InferenceExecutorProcess(self.service, meta_fiber)
exec_process.exec_request = request
self.pending_requests.remove(request)
exec_process.launch()
Expand All @@ -357,12 +357,12 @@ class InferenceExecutorProcess(sf.Process):
def __init__(
self,
service: GenerateService,
fiber,
meta_fiber,
):
super().__init__(fiber=fiber.fiber)
super().__init__(fiber=meta_fiber.fiber)
self.service = service
self.meta_fiber = fiber
self.worker_index = fiber.worker_idx
self.meta_fiber = meta_fiber
self.worker_index = meta_fiber.worker_idx
self.exec_request: InferenceExecRequest = None

def assign_command_buffer(self, request: InferenceExecRequest):
Expand Down Expand Up @@ -396,16 +396,16 @@ async def run(self):
await self._decode(device=device)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device)
await device
self.exec_request.done.set_success()
self.meta_fiber.command_buffers.append(self.exec_request.command_buffer)
if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.append(self.meta_fiber)

except Exception:
logger.exception("Fatal error in image generation")
# TODO: Cancel and set error correctly
self.exec_request.done.set_success()

self.meta_fiber.command_buffers.append(self.exec_request.command_buffer)
if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
self.service.idle_meta_fibers.append(self.meta_fiber)

async def _prepare(self, device):
# Tokenize prompts and negative prompts. We tokenize at batch size 1 for now and join the results later.
Expand Down

0 comments on commit 92b9e39

Please sign in to comment.