Skip to content

Commit

Permalink
Run pre-commit
Browse files (browse the repository at this point in the history)
  • Loading branch information
eagarvey-amd committed Feb 13, 2025
1 parent cfc4f69 commit f9b206a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 55 deletions.
4 changes: 3 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/config_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def max_vae_batch_size(self) -> int:
@property
def all_batch_sizes(self) -> list:
intersection = list(
set(self.clip_batch_sizes) & set(self.unet_batch_sizes) & set(self.vae_batch_sizes)
set(self.clip_batch_sizes)
& set(self.unet_batch_sizes)
& set(self.vae_batch_sizes)
)
return intersection

Expand Down
7 changes: 5 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def __init__(
steps: int | None = None,
guidance_scale: float | list[float] | sfnp.device_array | None = None,
seed: int | list[int] | None = None,
input_ids: list[list[int]] | list[list[list[int]]] | list[sfnp.device_array] | None = None,
input_ids: list[list[int]]
| list[list[list[int]]]
| list[sfnp.device_array]
| None = None,
sample: sfnp.device_array | None = None,
prompt_embeds: sfnp.device_array | None = None,
text_embeds: sfnp.device_array | None = None,
Expand Down
Expand Up @@ -150,7 +153,7 @@ def set_command_buffer(self, cb):

self.command_buffer = cb
return

def post_init(self):
"""Determines necessary inference phases and tags them with static program parameters."""
for p in reversed(list(InferencePhase)):
Expand Down
102 changes: 50 additions & 52 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,29 @@ def __init__(
self.workers.append(worker)
for i in range(self.fibers_per_device):
worker_idx = idx * workers_per_device + i % workers_per_device
tgt_worker = self.workers[
worker_idx
]
tgt_worker = self.workers[worker_idx]
raw_fiber = sysman.ls.create_fiber(tgt_worker, devices=[device])
fiber = self.equip_fiber(raw_fiber, len(self.fibers), worker_idx)
self.fibers.append(fiber)
self.idle_fibers.append(fiber)
for idx in range(len(self.workers)):
self.inference_programs[idx] = {}
self.inference_functions[idx] = {}

# Scope dependent objects.
self.batcher = BatcherProcess(self)

def equip_fiber(self, fiber, idx, worker_idx):
EquippedFiber = namedtuple('EquippedFiber', ['fiber', 'idx', 'worker_idx', 'device', 'command_buffers'])
EquippedFiber = namedtuple(
"EquippedFiber", ["fiber", "idx", "worker_idx", "device", "command_buffers"]
)
cbs_per_fiber = 1
cbs = []
for _ in range(cbs_per_fiber):
for batch_size in self.model_params.all_batch_sizes:
cbs.append(initialize_command_buffer(fiber, self.model_params, batch_size))
cbs.append(
initialize_command_buffer(fiber, self.model_params, batch_size)
)

return EquippedFiber(fiber, idx, worker_idx, fiber.device(0), cbs)

Expand All @@ -118,7 +120,7 @@ def get_worker_index(self, fiber):
(fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker
)
return worker_idx

def load_inference_module(self, vmfb_path: Path, component: str = None):
if not self.inference_modules.get(component):
self.inference_modules[component] = []
Expand Down
Expand Up @@ -227,6 +229,7 @@ def __repr__(self):
f")"
)


########################################################################################
# Batcher
########################################################################################
Expand Down Expand Up @@ -298,7 +301,9 @@ async def board_flights(self):
logger.debug("Waiting for an idle fiber...")
return
fiber = self.service.idle_fibers.pop(0)
logger.debug(f"Sending batch to fiber {fiber.idx} (worker {fiber.worker_idx})")
logger.debug(
f"Sending batch to fiber {fiber.idx} (worker {fiber.worker_idx})"
)
await self.board(batch["reqs"][0], fiber=fiber)
if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.append(fiber)
Expand Down Expand Up @@ -366,10 +371,12 @@ def assign_command_buffer(self, request: InferenceExecRequest):
self.exec_request.set_command_buffer(cb)
self.meta_fiber.command_buffers.remove(cb)
return
cb = initialize_command_buffer(self.fiber, self.service.model_params, request.batch_size)
cb = initialize_command_buffer(
self.fiber, self.service.model_params, request.batch_size
)
self.exec_request.set_command_buffer(cb)
return

@measure(type="exec", task="inference process")
async def run(self):
try:
Expand Down Expand Up @@ -480,16 +487,9 @@ async def _denoise(self, device):
"INVOKE %r",
fns["init"],
)
(
cb.latents,
cb.time_ids,
cb.timesteps,
cb.sigmas,
) = await fns["init"](
cb.sample,
cb.num_steps,
fiber=self.fiber
)
(cb.latents, cb.time_ids, cb.timesteps, cb.sigmas,) = await fns[
"init"
](cb.sample, cb.num_steps, fiber=self.fiber)
accum_step_duration = 0 # Accumulated duration for all steps
for i, t in tqdm(
enumerate(range(self.exec_request.steps)),
Expand All @@ -502,18 +502,9 @@ async def _denoise(self, device):
"INVOKE %r",
fns["scale"],
)
(
cb.latent_model_input,
cb.t,
cb.sigma,
cb.next_sigma,
) = await fns["scale"](
cb.latents,
step,
cb.timesteps,
cb.sigmas,
fiber=self.fiber
)
(cb.latent_model_input, cb.t, cb.sigma, cb.next_sigma,) = await fns[
"scale"
](cb.latents, step, cb.timesteps, cb.sigmas, fiber=self.fiber)
logger.debug(
"INVOKE %r",
fns["unet"],
Expand All @@ -524,19 +515,15 @@ async def _denoise(self, device):
cb.prompt_embeds,
cb.text_embeds,
cb.time_ids,
cb.guidance_scale,
fiber=self.fiber
cb.guidance_scale,
fiber=self.fiber,
)
logger.debug(
"INVOKE %r",
fns["step"],
)
(cb.latents,) = await fns["step"](
cb.noise_pred,
cb.latents,
cb.sigma,
cb.next_sigma,
fiber=self.fiber
cb.noise_pred, cb.latents, cb.sigma, cb.next_sigma, fiber=self.fiber
)
duration = time.time() - start
accum_step_duration += duration
Expand Down Expand Up @@ -587,8 +574,9 @@ async def _postprocess(self, device):
image = base64.b64encode(image_bytes).decode("utf-8")
self.exec_request.result_image = image
return

def initialize_command_buffer(fiber, model_params: ModelParams, batch_size:int=1):


def initialize_command_buffer(fiber, model_params: ModelParams, batch_size: int = 1):
bs = batch_size
cfg_bs = batch_size * 2
h = model_params.dims[0][0]
Expand All @@ -614,40 +602,50 @@ def initialize_command_buffer(fiber, model_params: ModelParams, batch_size:int=1
],
# DENOISE
"prompt_embeds": sfnp.device_array.for_device(
device, [cfg_bs, model_params.max_seq_len, 2048] , model_params.unet_dtype
device, [cfg_bs, model_params.max_seq_len, 2048], model_params.unet_dtype
),
"text_embeds": sfnp.device_array.for_device(
device, [cfg_bs, 1280], model_params.unet_dtype
),
"sample": sfnp.device_array.for_device(
device, [bs, c, h//8, w//8], model_params.unet_dtype
device, [bs, c, h // 8, w // 8], model_params.unet_dtype
),
"latents": sfnp.device_array.for_device(
device, [bs, c, h//8, w//8], model_params.unet_dtype
device, [bs, c, h // 8, w // 8], model_params.unet_dtype
),
"noise_pred": sfnp.device_array.for_device(
device, [bs, c, h//8, w//8], model_params.unet_dtype
device, [bs, c, h // 8, w // 8], model_params.unet_dtype
),
"num_steps": sfnp.device_array.for_device(device, [1], sfnp.sint64),
"steps_arr": sfnp.device_array.for_device(device, [100], sfnp.sint64),
"timesteps": sfnp.device_array.for_device(device, [100], sfnp.float32),
"sigmas": sfnp.device_array.for_device(device, [100], sfnp.float32),
"latent_model_input": sfnp.device_array.for_device(
device, [bs, c, h//8, w//8], model_params.unet_dtype
device, [bs, c, h // 8, w // 8], model_params.unet_dtype
),
"t": sfnp.device_array.for_device(device, [1], model_params.unet_dtype),
"sigma": sfnp.device_array.for_device(device, [1], model_params.unet_dtype),
"next_sigma": sfnp.device_array.for_device(device, [1], model_params.unet_dtype),
"time_ids": sfnp.device_array.for_device(device, [bs, 6], model_params.unet_dtype),
"guidance_scale": sfnp.device_array.for_device(device, [bs], model_params.unet_dtype),
"next_sigma": sfnp.device_array.for_device(
device, [1], model_params.unet_dtype
),
"time_ids": sfnp.device_array.for_device(
device, [bs, 6], model_params.unet_dtype
),
"guidance_scale": sfnp.device_array.for_device(
device, [bs], model_params.unet_dtype
),
# VAE
"images": sfnp.device_array.for_device(device, [bs, 3, h, w], model_params.vae_dtype),
"images_host": sfnp.device_array.for_host(device, [bs, 3, h, w], model_params.vae_dtype),
"images": sfnp.device_array.for_device(
device, [bs, 3, h, w], model_params.vae_dtype
),
"images_host": sfnp.device_array.for_host(
device, [bs, 3, h, w], model_params.vae_dtype
),
}

class ServiceCmdBuffer:
def __init__(self, d):
self.__dict__ = d

cb = ServiceCmdBuffer(datas)
return cb

0 comments on commit f9b206a

Please sign in to comment.