[shortfin-sd] Reusable service command buffers (nod-ai#963)
This PR non-trivially restructures the SDXL inference process to
interface with reusable command buffers.

At service startup, we attach to each fiber a set of reusable command
buffers (empty device arrays that are reused throughout the inference
process). When an exec request is assigned a fiber (now a "meta_fiber"
equipped with command buffers), we pop an appropriate command buffer from
the meta_fiber, mount it on the request, and return it to the meta_fiber
for reuse once the request is finished with it.

Each meta_fiber is equipped with an explicit number of command buffers. If
the number of InferenceExecProcesses assigned to the fiber exceeds this
count, a new command buffer is initialized ad hoc. This _shouldn't_ happen
if the service is managed properly, since we shouldn't oversubscribe fibers
too heavily (thundering herd).
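
As an illustration of the pooling described above, here is a minimal,
self-contained sketch; the names (`CommandBuffer`, `MetaFiber`,
`pop_command_buffer`, `return_command_buffer`) are hypothetical stand-ins,
not the actual shortfin-sd classes:

```python
from collections import deque


class CommandBuffer:
    """Stand-in for the set of pre-allocated device arrays reused per request."""

    def __init__(self, index: int):
        self.index = index


class MetaFiber:
    """Hypothetical pairing of a fiber with its pool of reusable command buffers."""

    def __init__(self, fiber, pool_size: int):
        self.fiber = fiber
        self._pool = deque(CommandBuffer(i) for i in range(pool_size))

    def pop_command_buffer(self) -> CommandBuffer:
        # Reuse a pooled buffer; allocate ad hoc only if more exec processes
        # are assigned to this fiber than there are pooled buffers.
        if self._pool:
            return self._pool.popleft()
        return CommandBuffer(index=-1)

    def return_command_buffer(self, cb: CommandBuffer) -> None:
        self._pool.append(cb)


# Per-request lifecycle: pop, mount on the request, run, return for reuse.
meta_fiber = MetaFiber(fiber="fiber0", pool_size=2)
cb = meta_fiber.pop_command_buffer()
# exec_request.set_command_buffer(cb)  # mount on the InferenceExecRequest
# ... run the inference phases against cb's device arrays ...
meta_fiber.return_command_buffer(cb)
```

The ad-hoc branch in `pop_command_buffer` corresponds to the oversubscription
case above and should stay on the cold path.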

These changes also simplify the request structure throughout the inference
processor, which now operates on a _single_ request as opposed to a bundle.
This greatly simplifies the work required to enable larger batch sizes and
reduces how much program I/O we need to handle during execution.

We also remove all `await device` usage where it isn't required. It should
only be used after VAE decode, where we are done with the GPU and need to
read the result contents of a device->host transfer. fyi
@daveliddell
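
A hedged sketch of that synchronization rule, using stand-in stubs rather
than the real shortfin objects: the only host-side wait happens after VAE
decode, once the device->host transfer has been issued and its contents are
about to be read.

```python
import asyncio


class FakeDevice:
    """Stand-in for a device handle; awaiting it models draining queued work."""

    def __await__(self):
        # Pretend all previously submitted device work completes here.
        return asyncio.sleep(0).__await__()


async def run_pipeline(device: FakeDevice) -> bytes:
    # Encode, denoise, and VAE decode are enqueued without host syncs: each
    # later dispatch consumes the previous results directly on the device.
    decoded_on_device = b"\x00" * 16           # placeholder for device memory
    host_image = bytearray(len(decoded_on_device))

    # Enqueue the device -> host transfer of the decoded image (placeholder).
    host_image[:] = decoded_on_device

    # The single required `await device`: the host is about to read the
    # transferred contents, so all queued device work must have completed.
    await device
    return bytes(host_image)


print(asyncio.run(run_pipeline(FakeDevice())))
```

Everything before that point stays asynchronous, which is what dropping the
extra `await device` calls buys.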

---------

Co-authored-by: Ean Garvey <ean.garvey@amd.com>
2 people authored and renxida committed Feb 19, 2025
1 parent 182bb9b commit f9fcc08
Showing 5 changed files with 908 additions and 378 deletions.
7 changes: 6 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -94,7 +94,12 @@ def max_vae_batch_size(self) -> int:

@property
def all_batch_sizes(self) -> list:
return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes]
intersection = list(
set(self.clip_batch_sizes)
& set(self.unet_batch_sizes)
& set(self.vae_batch_sizes)
)
return intersection

@property
def max_batch_size(self):
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/generate.py
@@ -75,7 +75,7 @@ def __init__(
gen_req: GenerateReqInput,
responder: FastAPIResponder,
):
super().__init__(fiber=service.fibers[0])
super().__init__(fiber=service.meta_fibers[0].fiber)
self.gen_req = gen_req
self.responder = responder
self.batcher = service.batcher
144 changes: 89 additions & 55 deletions shortfin/python/shortfin_apps/sd/components/messages.py
@@ -10,6 +10,7 @@

import shortfin as sf
import shortfin.array as sfnp
import numpy as np

from .io_struct import GenerateReqInput

@@ -41,25 +42,24 @@ class InferenceExecRequest(sf.Message):

def __init__(
self,
prompt: str | None = None,
neg_prompt: str | None = None,
prompt: str | list[str] | None = None,
neg_prompt: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
steps: int | None = None,
guidance_scale: float | sfnp.device_array | None = None,
seed: int | None = None,
input_ids: list[list[int]] | None = None,
guidance_scale: float | list[float] | sfnp.device_array | None = None,
seed: int | list[int] | None = None,
input_ids: list[list[int]]
| list[list[list[int]]]
| list[sfnp.device_array]
| None = None,
sample: sfnp.device_array | None = None,
prompt_embeds: sfnp.device_array | None = None,
text_embeds: sfnp.device_array | None = None,
timesteps: sfnp.device_array | None = None,
time_ids: sfnp.device_array | None = None,
denoised_latents: sfnp.device_array | None = None,
image_array: sfnp.device_array | None = None,
):
super().__init__()
self.command_buffer = None
self.print_debug = True

self.batch_size = 1
self.phases = {}
self.phase = None
self.height = height
@@ -74,22 +74,15 @@ def __init__(
self.seed = seed

# Encode phase.
# This is a list of sequenced positive and negative token ids and pooler token ids.
# This is a list of sequenced positive and negative token ids and pooler token ids (tokenizer outputs)
self.input_ids = input_ids

# Denoise phase.
self.prompt_embeds = prompt_embeds
self.text_embeds = text_embeds
self.sample = sample
self.steps = steps
self.timesteps = timesteps
self.time_ids = time_ids
self.guidance_scale = guidance_scale
self.steps = steps

# Decode phase.
self.denoised_latents = denoised_latents

# Postprocess.
self.image_array = image_array

self.result_image = None
@@ -104,35 +97,49 @@ def __init__(

self.post_init()

@staticmethod
def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
gen_inputs = [
"prompt",
"neg_prompt",
"height",
"width",
"steps",
"guidance_scale",
"seed",
"input_ids",
]
rec_inputs = {}
for item in gen_inputs:
received = getattr(gen_req, item, None)
if isinstance(received, list):
if index >= (len(received)):
if len(received) == 1:
rec_input = received[0]
else:
logging.error(
"Inputs in request must be singular or as many as the list of prompts."
)
else:
rec_input = received[index]
else:
rec_input = received
rec_inputs[item] = rec_input
return InferenceExecRequest(**rec_inputs)
def set_command_buffer(self, cb):
# Input IDs for CLIP if they are used as inputs instead of prompts.
if self.input_ids is not None:
# Take a batch of sets of input ids as ndarrays and fill cb.input_ids
host_arrs = [None] * len(cb.input_ids)
for idx, arr in enumerate(cb.input_ids):
host_arrs[idx] = arr.for_transfer()
for i in range(cb.sample.shape[0]):
with host_arrs[idx].view(i).map(write=True, discard=True) as m:

# TODO: fix this attr redundancy
np_arr = self.input_ids[i][idx]

m.fill(np_arr)
cb.input_ids[idx].copy_from(host_arrs[idx])

# Same for noisy latents if they are explicitly provided as a numpy array.
if self.sample is not None:
sample_host = cb.sample.for_transfer()
with sample_host.map(discard=True) as m:
m.fill(self.sample.tobytes())
cb.sample.copy_from(sample_host)

# Copy other inference parameters for denoise to device arrays.
steps_arr = list(range(0, self.steps))
steps_host = cb.steps_arr.for_transfer()
steps_host.items = steps_arr
cb.steps_arr.copy_from(steps_host)

num_step_host = cb.num_steps.for_transfer()
num_step_host.items = [self.steps]
cb.num_steps.copy_from(num_step_host)

guidance_host = cb.guidance_scale.for_transfer()
with guidance_host.map(discard=True) as m:
# TODO: do this without numpy
np_arr = np.asarray(self.guidance_scale, dtype="float16")

m.fill(np_arr)
cb.guidance_scale.copy_from(guidance_host)

self.command_buffer = cb
return

def post_init(self):
"""Determines necessary inference phases and tags them with static program parameters."""
@@ -157,15 +164,11 @@ def check_phase(self, phase: InferencePhase):
meta = [self.width, self.height]
return required, meta
case InferencePhase.DENOISE:
required = not self.denoised_latents
required = True
meta = [self.width, self.height, self.steps]
return required, meta
case InferencePhase.ENCODE:
p_results = [
self.prompt_embeds,
self.text_embeds,
]
required = any([inp is None for inp in p_results])
required = True
return required, None
case InferencePhase.PREPARE:
p_results = [self.sample, self.input_ids]
@@ -179,6 +182,37 @@ def reset(self, phase: InferencePhase):
self.done = sf.VoidFuture()
self.return_host_array = True

@staticmethod
def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
gen_inputs = [
"prompt",
"neg_prompt",
"height",
"width",
"steps",
"guidance_scale",
"seed",
"input_ids",
]
rec_inputs = {}
for item in gen_inputs:
received = getattr(gen_req, item, None)
if isinstance(received, list):
if index >= (len(received)):
if len(received) == 1:
rec_input = received[0]
else:
logging.error(
"Inputs in request must be singular or as many as the list of prompts."
)
else:
rec_input = received[index]
else:
rec_input = received
rec_inputs[item] = rec_input
req = InferenceExecRequest(**rec_inputs)
return req


class StrobeMessage(sf.Message):
"""Sent to strobe a queue with fake activity (generate a wakeup)."""