diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py
index 2b954c18b..f460d1247 100644
--- a/shortfin/python/shortfin_apps/sd/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -94,7 +94,12 @@ def max_vae_batch_size(self) -> int:
 
     @property
     def all_batch_sizes(self) -> list:
-        return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes]
+        intersection = list(
+            set(self.clip_batch_sizes)
+            & set(self.unet_batch_sizes)
+            & set(self.vae_batch_sizes)
+        )
+        return intersection
 
     @property
     def max_batch_size(self):
diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py
index 62ac5e855..ab3f2ff08 100644
--- a/shortfin/python/shortfin_apps/sd/components/generate.py
+++ b/shortfin/python/shortfin_apps/sd/components/generate.py
@@ -75,7 +75,7 @@ def __init__(
         gen_req: GenerateReqInput,
         responder: FastAPIResponder,
     ):
-        super().__init__(fiber=service.fibers[0])
+        super().__init__(fiber=service.meta_fibers[0].fiber)
         self.gen_req = gen_req
         self.responder = responder
         self.batcher = service.batcher
diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py
index 1f275b4e5..6b0117ccd 100644
--- a/shortfin/python/shortfin_apps/sd/components/messages.py
+++ b/shortfin/python/shortfin_apps/sd/components/messages.py
@@ -10,6 +10,7 @@
 
 import shortfin as sf
 import shortfin.array as sfnp
+import numpy as np
 
 from .io_struct import GenerateReqInput
 
@@ -41,25 +42,24 @@ class InferenceExecRequest(sf.Message):
 
     def __init__(
         self,
-        prompt: str | None = None,
-        neg_prompt: str | None = None,
+        prompt: str | list[str] | None = None,
+        neg_prompt: str | list[str] | None = None,
         height: int | None = None,
         width: int | None = None,
         steps: int | None = None,
-        guidance_scale: float | sfnp.device_array | None = None,
-        seed: int | None = None,
-        input_ids: list[list[int]] | None = None,
+        guidance_scale: float | list[float] | sfnp.device_array | None = None,
+        seed: int | list[int] | None = None,
+        input_ids: list[list[int]]
+        | list[list[list[int]]]
+        | list[sfnp.device_array]
+        | None = None,
         sample: sfnp.device_array | None = None,
-        prompt_embeds: sfnp.device_array | None = None,
-        text_embeds: sfnp.device_array | None = None,
-        timesteps: sfnp.device_array | None = None,
-        time_ids: sfnp.device_array | None = None,
-        denoised_latents: sfnp.device_array | None = None,
         image_array: sfnp.device_array | None = None,
     ):
         super().__init__()
+        self.command_buffer = None
         self.print_debug = True
-
+        self.batch_size = 1
         self.phases = {}
         self.phase = None
         self.height = height
@@ -74,22 +74,15 @@ def __init__(
         self.seed = seed
 
         # Encode phase.
-        # This is a list of sequenced positive and negative token ids and pooler token ids.
+        # This is a list of sequenced positive and negative token ids and pooler token ids (tokenizer outputs)
         self.input_ids = input_ids
 
         # Denoise phase.
-        self.prompt_embeds = prompt_embeds
-        self.text_embeds = text_embeds
         self.sample = sample
-        self.steps = steps
-        self.timesteps = timesteps
-        self.time_ids = time_ids
         self.guidance_scale = guidance_scale
+        self.steps = steps
 
         # Decode phase.
-        self.denoised_latents = denoised_latents
-
-        # Postprocess.
         self.image_array = image_array
         self.result_image = None
@@ -104,35 +97,49 @@ def __init__(
 
         self.post_init()
 
-    @staticmethod
-    def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
-        gen_inputs = [
-            "prompt",
-            "neg_prompt",
-            "height",
-            "width",
-            "steps",
-            "guidance_scale",
-            "seed",
-            "input_ids",
-        ]
-        rec_inputs = {}
-        for item in gen_inputs:
-            received = getattr(gen_req, item, None)
-            if isinstance(received, list):
-                if index >= (len(received)):
-                    if len(received) == 1:
-                        rec_input = received[0]
-                    else:
-                        logging.error(
-                            "Inputs in request must be singular or as many as the list of prompts."
-                        )
-                else:
-                    rec_input = received[index]
-            else:
-                rec_input = received
-            rec_inputs[item] = rec_input
-        return InferenceExecRequest(**rec_inputs)
+    def set_command_buffer(self, cb):
+        # Input IDs for CLIP if they are used as inputs instead of prompts.
+        if self.input_ids is not None:
+            # Take a batch of sets of input ids as ndarrays and fill cb.input_ids
+            host_arrs = [None] * len(cb.input_ids)
+            for idx, arr in enumerate(cb.input_ids):
+                host_arrs[idx] = arr.for_transfer()
+                for i in range(cb.sample.shape[0]):
+                    with host_arrs[idx].view(i).map(write=True, discard=True) as m:
+
+                        # TODO: fix this attr redundancy
+                        np_arr = self.input_ids[i][idx]
+
+                        m.fill(np_arr)
+                cb.input_ids[idx].copy_from(host_arrs[idx])
+
+        # Same for noisy latents if they are explicitly provided as a numpy array.
+        if self.sample is not None:
+            sample_host = cb.sample.for_transfer()
+            with sample_host.map(discard=True) as m:
+                m.fill(self.sample.tobytes())
+
+            cb.sample.copy_from(sample_host)
+
+        # Copy other inference parameters for denoise to device arrays.
+        steps_arr = list(range(0, self.steps))
+        steps_host = cb.steps_arr.for_transfer()
+        steps_host.items = steps_arr
+        cb.steps_arr.copy_from(steps_host)
+
+        num_step_host = cb.num_steps.for_transfer()
+        num_step_host.items = [self.steps]
+        cb.num_steps.copy_from(num_step_host)
+
+        guidance_host = cb.guidance_scale.for_transfer()
+        with guidance_host.map(discard=True) as m:
+            # TODO: do this without numpy
+            np_arr = np.asarray(self.guidance_scale, dtype="float16")
+
+            m.fill(np_arr)
+        cb.guidance_scale.copy_from(guidance_host)
+
+        self.command_buffer = cb
+        return
 
     def post_init(self):
         """Determines necessary inference phases and tags them with static program parameters."""
@@ -157,15 +164,11 @@ def check_phase(self, phase: InferencePhase):
                 meta = [self.width, self.height]
                 return required, meta
             case InferencePhase.DENOISE:
-                required = not self.denoised_latents
+                required = True
                 meta = [self.width, self.height, self.steps]
                 return required, meta
             case InferencePhase.ENCODE:
-                p_results = [
-                    self.prompt_embeds,
-                    self.text_embeds,
-                ]
-                required = any([inp is None for inp in p_results])
+                required = True
                 return required, None
             case InferencePhase.PREPARE:
                 p_results = [self.sample, self.input_ids]
@@ -179,6 +182,37 @@ def reset(self, phase: InferencePhase):
         self.done = sf.VoidFuture()
         self.return_host_array = True
 
+    @staticmethod
+    def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
+        gen_inputs = [
+            "prompt",
+            "neg_prompt",
+            "height",
+            "width",
+            "steps",
+            "guidance_scale",
+            "seed",
+            "input_ids",
+        ]
+        rec_inputs = {}
+        for item in gen_inputs:
+            received = getattr(gen_req, item, None)
+            if isinstance(received, list):
+                if index >= (len(received)):
+                    if len(received) == 1:
+                        rec_input = received[0]
+                    else:
+                        logging.error(
+                            "Inputs in request must be singular or as many as the list of prompts."
+                        )
+                else:
+                    rec_input = received[index]
+            else:
+                rec_input = received
+            rec_inputs[item] = rec_input
+        req = InferenceExecRequest(**rec_inputs)
+        return req
+
 
 class StrobeMessage(sf.Message):
     """Sent to strobe a queue with fake activity (generate a wakeup)."""
diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py
index d8398d0c0..be0e0bf9e 100644
--- a/shortfin/python/shortfin_apps/sd/components/service.py
+++ b/shortfin/python/shortfin_apps/sd/components/service.py
@@ -11,6 +11,7 @@
 from tqdm.auto import tqdm
 from pathlib import Path
 from PIL import Image
+from collections import namedtuple
 import base64
 
 import shortfin as sf
@@ -75,8 +76,8 @@ def __init__(
         self.fibers_per_worker = int(fibers_per_device / workers_per_device)
 
         self.workers = []
-        self.fibers = []
-        self.idle_fibers = set()
+        self.meta_fibers = []
+        self.idle_meta_fibers = []
         # For each worker index we create one on each device, and add their fibers to the idle set.
         # This roughly ensures that the first picked fibers are distributed across available devices.
         for idx, device in enumerate(self.sysman.ls.devices):
@@ -84,26 +85,34 @@ def __init__(
                 worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
                 self.workers.append(worker)
             for i in range(self.fibers_per_device):
-                tgt_worker = self.workers[
-                    idx * workers_per_device + i % workers_per_device
-                ]
-                fiber = sysman.ls.create_fiber(tgt_worker, devices=[device])
-                self.fibers.append(fiber)
-                self.idle_fibers.add(fiber)
+                worker_idx = idx * workers_per_device + i % workers_per_device
+                tgt_worker = self.workers[worker_idx]
+                raw_fiber = sysman.ls.create_fiber(tgt_worker, devices=[device])
+                meta_fiber = self.equip_fiber(
+                    raw_fiber, len(self.meta_fibers), worker_idx
+                )
+                self.meta_fibers.append(meta_fiber)
+                self.idle_meta_fibers.append(meta_fiber)
         for idx in range(len(self.workers)):
             self.inference_programs[idx] = {}
             self.inference_functions[idx] = {}
+        # Scope dependent objects.
         self.batcher = BatcherProcess(self)
 
-    def get_worker_index(self, fiber):
-        if fiber not in self.fibers:
-            raise ValueError("A worker was requested from a rogue fiber.")
-        fiber_idx = self.fibers.index(fiber)
-        worker_idx = int(
-            (fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker
+    def equip_fiber(self, fiber, idx, worker_idx):
+        MetaFiber = namedtuple(
+            "MetaFiber", ["fiber", "idx", "worker_idx", "device", "command_buffers"]
         )
-        return worker_idx
+        cbs_per_fiber = 1
+        cbs = []
+        for _ in range(cbs_per_fiber):
+            for batch_size in self.model_params.all_batch_sizes:
+                cbs.append(
+                    initialize_command_buffer(fiber, self.model_params, batch_size)
+                )
+
+        return MetaFiber(fiber, idx, worker_idx, fiber.device(0), cbs)
 
     def load_inference_module(self, vmfb_path: Path, component: str = None):
         if not self.inference_modules.get(component):
@@ -139,9 +148,9 @@ def start(self):
             ]
 
             for worker_idx, worker in enumerate(self.workers):
-                worker_devices = self.fibers[
+                worker_devices = self.meta_fibers[
                     worker_idx * (self.fibers_per_worker)
-                ].raw_devices
+                ].fiber.raw_devices
                 logger.info(
                     f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}"
                 )
@@ -228,14 +237,14 @@ class BatcherProcess(sf.Process):
     STROBE_LONG_DELAY = 1
 
     def __init__(self, service: GenerateService):
-        super().__init__(fiber=service.fibers[0])
+        super().__init__(fiber=service.meta_fibers[0].fiber)
         self.service = service
         self.batcher_infeed = self.system.create_queue()
         self.pending_requests: set[InferenceExecRequest] = set()
         self.strobe_enabled = True
        self.strobes: int = 0
-        self.ideal_batch_size: int = max(service.model_params.max_batch_size)
-        self.num_fibers = len(service.fibers)
+        self.ideal_batch_size: int = max(service.model_params.all_batch_sizes)
+        self.num_fibers = len(service.meta_fibers)
 
     def shutdown(self):
         self.batcher_infeed.close()
@@ -265,12 +274,12 @@ async def run(self):
             else:
                 logger.error("Illegal message received by batcher: %r", item)
 
-            self.board_flights()
+            await self.board_flights()
 
         self.strobe_enabled = True
         await strober_task
 
-    def board_flights(self):
+    async def board_flights(self):
         waiting_count = len(self.pending_requests)
         if waiting_count == 0:
             return
@@ -281,15 +290,16 @@ def board_flights(self):
         batches = self.sort_batches()
         for batch in batches.values():
             # Assign the batch to the next idle fiber.
-            if len(self.service.idle_fibers) == 0:
+            if len(self.service.idle_meta_fibers) == 0:
+                logger.debug("Waiting for an idle fiber...")
                 return
-            fiber = self.service.idle_fibers.pop()
-            fiber_idx = self.service.fibers.index(fiber)
-            worker_idx = self.service.get_worker_index(fiber)
-            logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})")
-            self.board(batch["reqs"], fiber=fiber)
+            meta_fiber = self.service.idle_meta_fibers.pop(0)
+            logger.debug(
+                f"Sending batch to fiber {meta_fiber.idx} (worker {meta_fiber.worker_idx})"
+            )
+            await self.board(batch["reqs"][0], meta_fiber=meta_fiber)
             if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
-                self.service.idle_fibers.add(fiber)
+                self.service.idle_meta_fibers.append(meta_fiber)
 
     def sort_batches(self):
         """Files pending requests into sorted batches suitable for program invocations."""
@@ -322,19 +332,11 @@ def sort_batches(self):
             }
         return batches
 
-    def board(self, request_bundle, fiber):
-        pending = request_bundle
-        if len(pending) == 0:
-            return
-        exec_process = InferenceExecutorProcess(self.service, fiber)
-        for req in pending:
-            if len(exec_process.exec_requests) >= self.ideal_batch_size:
-                break
-            exec_process.exec_requests.append(req)
-        if exec_process.exec_requests:
-            for flighted_request in exec_process.exec_requests:
-                self.pending_requests.remove(flighted_request)
-            exec_process.launch()
+    async def board(self, request, meta_fiber):
+        exec_process = InferenceExecutorProcess(self.service, meta_fiber)
+        exec_process.exec_request = request
+        self.pending_requests.remove(request)
+        exec_process.launch()
 
 
 ########################################################################################
@@ -348,368 +350,296 @@ class InferenceExecutorProcess(sf.Process):
 
     def __init__(
        self,
         service: GenerateService,
-        fiber,
+        meta_fiber,
     ):
-        super().__init__(fiber=fiber)
+        super().__init__(fiber=meta_fiber.fiber)
         self.service = service
-        self.worker_index = self.service.get_worker_index(fiber)
-        self.exec_requests: list[InferenceExecRequest] = []
+        self.meta_fiber = meta_fiber
+        self.worker_index = meta_fiber.worker_idx
+        self.exec_request: InferenceExecRequest = None
+
+    def assign_command_buffer(self, request: InferenceExecRequest):
+        for cb in self.meta_fiber.command_buffers:
+            if cb.sample.shape[0] == self.exec_request.batch_size:
+                self.exec_request.set_command_buffer(cb)
+                self.meta_fiber.command_buffers.remove(cb)
+                return
+        cb = initialize_command_buffer(
+            self.fiber, self.service.model_params, request.batch_size
+        )
+        self.exec_request.set_command_buffer(cb)
+        return
 
     @measure(type="exec", task="inference process")
     async def run(self):
         try:
-            phase = None
-            for req in self.exec_requests:
-                if phase:
-                    if phase != req.phase:
-                        logger.error("Executor process recieved disjoint batch.")
-                phase = req.phase
-            phases = self.exec_requests[0].phases
-            req_count = len(self.exec_requests)
-            device0 = self.fiber.device(0)
+            if not self.exec_request.command_buffer:
+                self.assign_command_buffer(self.exec_request)
+
+            device = self.fiber.device(0)
+            phases = self.exec_request.phases
             if phases[InferencePhase.PREPARE]["required"]:
-                await self._prepare(device=device0, requests=self.exec_requests)
+                await self._prepare(device=device)
             if phases[InferencePhase.ENCODE]["required"]:
-                await self._encode(device=device0, requests=self.exec_requests)
+                await self._encode(device=device)
             if phases[InferencePhase.DENOISE]["required"]:
-                await self._denoise(device=device0, requests=self.exec_requests)
+                await self._denoise(device=device)
             if phases[InferencePhase.DECODE]["required"]:
-                await self._decode(device=device0, requests=self.exec_requests)
+                await self._decode(device=device)
+            # Postprocessing needs the output data to be on the host. Even
+            # without postprocessing, we're done with the GPU, so we wait for
+            # it to finish here.
+            await device
             if phases[InferencePhase.POSTPROCESS]["required"]:
-                await self._postprocess(device=device0, requests=self.exec_requests)
-            await device0
-            for i in range(req_count):
-                req = self.exec_requests[i]
-                req.done.set_success()
-            if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
-                self.service.idle_fibers.add(self.fiber)
+                await self._postprocess(device=device)
+            self.exec_request.done.set_success()
         except Exception:
             logger.exception("Fatal error in image generation")
             # TODO: Cancel and set error correctly
-            for req in self.exec_requests:
-                req.done.set_success()
-
-    async def _prepare(self, device, requests):
-        for request in requests:
-            # Tokenize prompts and negative prompts. We tokenize in bs1 for now and join later.
+            self.exec_request.done.set_success()
+
+        self.meta_fiber.command_buffers.append(self.exec_request.command_buffer)
+        if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
+            self.service.idle_meta_fibers.append(self.meta_fiber)
+
+    async def _prepare(self, device):
+        # Tokenize prompts and negative prompts. We tokenize in bs1 for now and join later.
+        # Tokenize the prompts if the request does not hold input_ids.
+        batch_ids_lists = []
+        cb = self.exec_request.command_buffer
+        if isinstance(self.exec_request.prompt, str):
+            self.exec_request.prompt = [self.exec_request.prompt]
+        if isinstance(self.exec_request.neg_prompt, str):
+            self.exec_request.neg_prompt = [self.exec_request.neg_prompt]
+        for i in range(self.exec_request.batch_size):
             input_ids_list = []
             neg_ids_list = []
-            ids_list = request.input_ids
-            # Tokenize the prompts if the request does not hold input_ids.
-            if ids_list is None:
-                for tokenizer in self.service.tokenizers:
-                    input_ids = tokenizer.encode(request.prompt).input_ids
-                    input_ids_list.append(input_ids)
-                    neg_ids = tokenizer.encode(request.neg_prompt).input_ids
-                    neg_ids_list.append(neg_ids)
-                ids_list = [*input_ids_list, *neg_ids_list]
-
-            request.input_ids = ids_list
-
-            # Generate random sample latents.
-            seed = request.seed
-            channels = self.service.model_params.num_latents_channels
-            unet_dtype = self.service.model_params.unet_dtype
-            latents_shape = (
-                1,
-                channels,
-                request.height // 8,
-                request.width // 8,
-            )
+            for tokenizer in self.service.tokenizers:
+                input_ids = tokenizer.encode(self.exec_request.prompt[i]).input_ids
+                input_ids_list.append(input_ids)
+                neg_ids = tokenizer.encode(self.exec_request.neg_prompt[i]).input_ids
+                neg_ids_list.append(neg_ids)
+            ids_list = [*input_ids_list, *neg_ids_list]
+            batch_ids_lists.append(ids_list)
 
-            # Create and populate sample device array.
-            generator = sfnp.RandomGenerator(seed)
-            request.sample = sfnp.device_array.for_device(
-                device, latents_shape, unet_dtype
-            )
+        # Prepare tokenized input ids for CLIP inference
+        host_arrs = [None] * len(cb.input_ids)
+        for idx, arr in enumerate(cb.input_ids):
+            host_arrs[idx] = arr.for_transfer()
+            for i in range(self.exec_request.batch_size):
+                with host_arrs[idx].view(i).map(write=True, discard=True) as m:
+
+                    # TODO: fix this attr redundancy
+                    np_arr = batch_ids_lists[i][idx]
+
+                    m.fill(np_arr)
+            cb.input_ids[idx].copy_from(host_arrs[idx])
 
-            sample_host = request.sample.for_transfer()
-            with sample_host.map(discard=True) as m:
-                m.fill(bytes(1))
+        # Generate random sample latents.
+        seed = self.exec_request.seed
 
-            sfnp.fill_randn(sample_host, generator=generator)
+        # Create and populate sample device array.
+        generator = sfnp.RandomGenerator(seed)
 
-            request.sample.copy_from(sample_host)
+        sample_host = cb.sample.for_transfer()
+        with sample_host.map(discard=True) as m:
+            m.fill(bytes(1))
+
+        sfnp.fill_randn(sample_host, generator=generator)
+
+        cb.sample.copy_from(sample_host)
         return
 
-    @measure(type="exec", task="encode (CLIP)")
-    async def _encode(self, device, requests):
-        req_bs = len(requests)
+    async def _encode(self, device):
+        req_bs = self.exec_request.batch_size
         entrypoints = self.service.inference_functions[self.worker_index]["encode"]
-        if req_bs not in list(entrypoints.keys()):
-            for request in requests:
-                await self._encode(device, [request])
-            return
+        assert req_bs in list(entrypoints.keys())
         for bs, fn in entrypoints.items():
             if bs == req_bs:
                 break
-
-        # Prepare tokenized input ids for CLIP inference
-
-        clip_inputs = [
-            sfnp.device_array.for_device(
-                device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64
-            ),
-            sfnp.device_array.for_device(
-                device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64
-            ),
-            sfnp.device_array.for_device(
-                device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64
-            ),
-            sfnp.device_array.for_device(
-                device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64
-            ),
-        ]
-        host_arrs = [None] * 4
-        for idx, arr in enumerate(clip_inputs):
-            host_arrs[idx] = arr.for_transfer()
-            for i in range(req_bs):
-                with host_arrs[idx].view(i).map(write=True, discard=True) as m:
-
-                    # TODO: fix this attr redundancy
-                    np_arr = requests[i].input_ids[idx]
-
-                    m.fill(np_arr)
-            clip_inputs[idx].copy_from(host_arrs[idx])
-
+        cb = self.exec_request.command_buffer
         # Encode tokenized inputs.
         logger.debug(
             "INVOKE %r: %s",
             fn,
-            "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
+            "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(cb.input_ids)]),
         )
-        await device
-        pe, te = await fn(*clip_inputs, fiber=self.fiber)
-
-        for i in range(req_bs):
-            cfg_mult = 2
-            requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult))
-            requests[i].text_embeds = te.view(slice(i * cfg_mult, (i + 1) * cfg_mult))
-
+        cb.prompt_embeds, cb.text_embeds = await fn(*cb.input_ids, fiber=self.fiber)
         return
 
-    @measure(type="exec", task="denoise (UNet)")
-    async def _denoise(self, device, requests):
-        req_bs = len(requests)
-        step_count = requests[0].steps
-        cfg_mult = 2 if self.service.model_params.cfg_mode else 1
-        # Produce denoised latents
+    async def _denoise(self, device):
+        req_bs = self.exec_request.batch_size
         entrypoints = self.service.inference_functions[self.worker_index]["denoise"]
-        if req_bs not in list(entrypoints.keys()):
-            for request in requests:
-                await self._denoise(device, [request])
-            return
+        assert req_bs in list(entrypoints.keys())
         for bs, fns in entrypoints.items():
             if bs == req_bs:
                 break
-        # Get shape of batched latents.
-        # This assumes all requests are dense at this point.
-        latents_shape = [
-            req_bs,
-            self.service.model_params.num_latents_channels,
-            requests[0].height // 8,
-            requests[0].width // 8,
-        ]
-        # Assume we are doing classifier-free guidance
-        hidden_states_shape = [
-            req_bs * cfg_mult,
-            self.service.model_params.max_seq_len,
-            2048,
-        ]
-        text_embeds_shape = [
-            req_bs * cfg_mult,
-            1280,
-        ]
-        denoise_inputs = {
-            "sample": sfnp.device_array.for_device(
-                device, latents_shape, self.service.model_params.unet_dtype
-            ),
-            "encoder_hidden_states": sfnp.device_array.for_device(
-                device, hidden_states_shape, self.service.model_params.unet_dtype
-            ),
-            "text_embeds": sfnp.device_array.for_device(
-                device, text_embeds_shape, self.service.model_params.unet_dtype
-            ),
-            "guidance_scale": sfnp.device_array.for_device(
-                device, [req_bs], self.service.model_params.unet_dtype
-            ),
-        }
-
-        # Send guidance scale to device.
-        gs_host = denoise_inputs["guidance_scale"].for_transfer()
-        for i in range(req_bs):
-            cfg_dim = i * cfg_mult
-            with gs_host.view(i).map(write=True, discard=True) as m:
-                # TODO: do this without numpy
-                np_arr = np.asarray(requests[i].guidance_scale, dtype="float16")
-
-                m.fill(np_arr)
-            # Batch sample latent inputs on device.
-            req_samp = requests[i].sample
-            denoise_inputs["sample"].view(i).copy_from(req_samp)
-
-            # Batch CLIP hidden states.
-            enc = requests[i].prompt_embeds
-            denoise_inputs["encoder_hidden_states"].view(
-                slice(cfg_dim, cfg_dim + cfg_mult)
-            ).copy_from(enc)
-
-            # Batch CLIP text embeds.
-            temb = requests[i].text_embeds
-            denoise_inputs["text_embeds"].view(
-                slice(cfg_dim, cfg_dim + cfg_mult)
-            ).copy_from(temb)
-
-        denoise_inputs["guidance_scale"].copy_from(gs_host)
-
-        num_steps = sfnp.device_array.for_device(device, [1], sfnp.sint64)
-        ns_host = num_steps.for_transfer()
-        with ns_host.map(write=True) as m:
-            ns_host.items = [step_count]
-        num_steps.copy_from(ns_host)
-
-        init_inputs = [
-            denoise_inputs["sample"],
-            num_steps,
-        ]
+        cb = self.exec_request.command_buffer
 
-        # Initialize scheduler.
         logger.debug(
             "INVOKE %r",
             fns["init"],
         )
-        (latents, time_ids, timesteps, sigmas) = await fns["init"](
-            *init_inputs, fiber=self.fiber
-        )
-
+        (cb.latents, cb.time_ids, cb.timesteps, cb.sigmas,) = await fns[
+            "init"
+        ](cb.sample, cb.num_steps, fiber=self.fiber)
         accum_step_duration = 0  # Accumulated duration for all steps
         for i, t in tqdm(
-            enumerate(range(step_count)),
+            enumerate(range(self.exec_request.steps)),
             disable=(not self.service.show_progress),
             desc=f"DENOISE (bs{req_bs})",
         ):
             start = time.time()
-            step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
-            s_host = step.for_transfer()
-            with s_host.map(write=True) as m:
-                s_host.items = [i]
-            step.copy_from(s_host)
-            scale_inputs = [latents, step, timesteps, sigmas]
+            step = cb.steps_arr.view(i)
             logger.debug(
                 "INVOKE %r",
                 fns["scale"],
             )
-            latent_model_input, t, sigma, next_sigma = await fns["scale"](
-                *scale_inputs, fiber=self.fiber
-            )
-            await device
-
-            unet_inputs = [
-                latent_model_input,
-                t,
-                denoise_inputs["encoder_hidden_states"],
-                denoise_inputs["text_embeds"],
-                time_ids,
-                denoise_inputs["guidance_scale"],
-            ]
+            (cb.latent_model_input, cb.t, cb.sigma, cb.next_sigma,) = await fns[
+                "scale"
+            ](cb.latents, step, cb.timesteps, cb.sigmas, fiber=self.fiber)
             logger.debug(
                 "INVOKE %r",
                 fns["unet"],
             )
-            (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber)
-
-            step_inputs = [noise_pred, latents, sigma, next_sigma]
+            (cb.noise_pred,) = await fns["unet"](
+                cb.latent_model_input,
+                cb.t,
+                cb.prompt_embeds,
+                cb.text_embeds,
+                cb.time_ids,
+                cb.guidance_scale,
+                fiber=self.fiber,
+            )
             logger.debug(
                 "INVOKE %r",
                 fns["step"],
             )
-            (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber)
-            latents.copy_from(latent_model_output)
-
+            (cb.latents,) = await fns["step"](
+                cb.noise_pred, cb.latents, cb.sigma, cb.next_sigma, fiber=self.fiber
+            )
             duration = time.time() - start
             accum_step_duration += duration
-        average_step_duration = accum_step_duration / step_count
+        average_step_duration = accum_step_duration / self.exec_request.steps
         log_duration_str(
             average_step_duration, "denoise (UNet) single step average", req_bs
         )
-
-        for idx, req in enumerate(requests):
-            req.denoised_latents = sfnp.device_array.for_device(
-                device, latents_shape, self.service.model_params.vae_dtype
-            )
-            req.denoised_latents.copy_from(latents.view(idx))
         return
 
-    @measure(type="exec", task="decode (VAE)")
-    async def _decode(self, device, requests):
-        req_bs = len(requests)
+    async def _decode(self, device):
+        req_bs = self.exec_request.batch_size
+        cb = self.exec_request.command_buffer
         # Decode latents to images
         entrypoints = self.service.inference_functions[self.worker_index]["decode"]
-        if req_bs not in list(entrypoints.keys()):
-            for request in requests:
-                await self._decode(device, [request])
-            return
+        assert req_bs in list(entrypoints.keys())
         for bs, fn in entrypoints.items():
             if bs == req_bs:
                 break
-        latents_shape = [
-            req_bs,
-            self.service.model_params.num_latents_channels,
-            requests[0].height // 8,
-            requests[0].width // 8,
-        ]
-        latents = sfnp.device_array.for_device(
-            device, latents_shape, self.service.model_params.vae_dtype
-        )
-        for i in range(req_bs):
-            latents.view(i).copy_from(requests[i].denoised_latents)
-
-        await device
         # Decode the denoised latents.
         logger.debug(
             "INVOKE %r: %s",
             fn,
-            "".join([f"\n 0: {latents.shape}"]),
+            "".join([f"\n 0: {cb.latents.shape}"]),
         )
-        (image,) = await fn(latents, fiber=self.fiber)
-
-        await device
-        images_shape = [
+        (cb.images,) = await fn(cb.latents, fiber=self.fiber)
+        cb.images_host.copy_from(cb.images)
+        image_array = cb.images_host.items
+        dtype = image_array.typecode
+        if cb.images_host.dtype == sfnp.float16:
+            dtype = np.float16
+        self.exec_request.image_array = np.frombuffer(image_array, dtype=dtype).reshape(
             req_bs,
             3,
-            requests[0].height,
-            requests[0].width,
-        ]
-        image_shape = [
-            req_bs,
-            3,
-            requests[0].height,
-            requests[0].width,
-        ]
-        images_host = sfnp.device_array.for_host(device, images_shape, sfnp.float16)
-        images_host.copy_from(image)
-        await device
-        for idx, req in enumerate(requests):
-            image_array = images_host.view(idx).items
-            dtype = image_array.typecode
-            if images_host.dtype == sfnp.float16:
-                dtype = np.float16
-            req.image_array = np.frombuffer(image_array, dtype=dtype).reshape(
-                *image_shape
-            )
+            self.exec_request.height,
+            self.exec_request.width,
+        )
         return
 
-    async def _postprocess(self, device, requests):
+    async def _postprocess(self, device):
         # Process output images
-        for req in requests:
-            # TODO: reimpl with sfnp
-            permuted = np.transpose(req.image_array, (0, 2, 3, 1))[0]
-            cast_image = (permuted * 255).round().astype("uint8")
-            image_bytes = Image.fromarray(cast_image).tobytes()
-
-            image = base64.b64encode(image_bytes).decode("utf-8")
-            req.result_image = image
+        # TODO: reimpl with sfnp
+        permuted = np.transpose(self.exec_request.image_array, (0, 2, 3, 1))[0]
+        cast_image = (permuted * 255).round().astype("uint8")
+        image_bytes = Image.fromarray(cast_image).tobytes()
+
+        image = base64.b64encode(image_bytes).decode("utf-8")
+        self.exec_request.result_image = image
         return
+
+
+def initialize_command_buffer(fiber, model_params: ModelParams, batch_size: int = 1):
+    bs = batch_size
+    cfg_bs = batch_size * 2
+    h = model_params.dims[0][0]
+    w = model_params.dims[0][1]
+    c = model_params.num_latents_channels
+    device = fiber.device(0)
+
+    datas = {
+        # CLIP
+        "input_ids": [
+            sfnp.device_array.for_device(
+                device, [bs, model_params.max_seq_len], sfnp.sint64
+            ),
+            sfnp.device_array.for_device(
+                device, [bs, model_params.max_seq_len], sfnp.sint64
+            ),
+            sfnp.device_array.for_device(
+                device, [bs, model_params.max_seq_len], sfnp.sint64
+            ),
+            sfnp.device_array.for_device(
+                device, [bs, model_params.max_seq_len], sfnp.sint64
+            ),
+        ],
+        # DENOISE
+        "prompt_embeds": sfnp.device_array.for_device(
+            device, [cfg_bs, model_params.max_seq_len, 2048], model_params.unet_dtype
+        ),
+        "text_embeds": sfnp.device_array.for_device(
+            device, [cfg_bs, 1280], model_params.unet_dtype
+        ),
+        "sample": sfnp.device_array.for_device(
+            device, [bs, c, h // 8, w // 8], model_params.unet_dtype
+        ),
+        "latents": sfnp.device_array.for_device(
+            device, [bs, c, h // 8, w // 8], model_params.unet_dtype
+        ),
+        "noise_pred": sfnp.device_array.for_device(
+            device, [bs, c, h // 8, w // 8], model_params.unet_dtype
+        ),
+        "num_steps": sfnp.device_array.for_device(device, [1], sfnp.sint64),
+        "steps_arr": sfnp.device_array.for_device(device, [100], sfnp.sint64),
+        "timesteps": sfnp.device_array.for_device(device, [100], sfnp.float32),
+        "sigmas": sfnp.device_array.for_device(device, [100], sfnp.float32),
+        "latent_model_input": sfnp.device_array.for_device(
+            device, [bs, c, h // 8, w // 8], model_params.unet_dtype
+        ),
+        "t": sfnp.device_array.for_device(device, [1], model_params.unet_dtype),
+        "sigma": sfnp.device_array.for_device(device, [1], model_params.unet_dtype),
+        "next_sigma": sfnp.device_array.for_device(
+            device, [1], model_params.unet_dtype
+        ),
+        "time_ids": sfnp.device_array.for_device(
+            device, [bs, 6], model_params.unet_dtype
+        ),
+        "guidance_scale": sfnp.device_array.for_device(
+            device, [bs], model_params.unet_dtype
+        ),
+        # VAE
+        "images": sfnp.device_array.for_device(
+            device, [bs, 3, h, w], model_params.vae_dtype
+        ),
+        "images_host": sfnp.device_array.for_host(
+            device, [bs, 3, h, w], model_params.vae_dtype
+        ),
+    }
+
+    class ServiceCmdBuffer:
+        def __init__(self, d):
+            self.__dict__ = d
+
+    cb = ServiceCmdBuffer(datas)
+    return cb
diff --git a/shortfin/python/shortfin_apps/sd/python_pipe.py b/shortfin/python/shortfin_apps/sd/python_pipe.py
new file mode 100644
index 000000000..c87ff831f
--- /dev/null
+++ b/shortfin/python/shortfin_apps/sd/python_pipe.py
@@ -0,0 +1,561 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any
+import argparse
+import logging
+import asyncio
+from pathlib import Path
+import numpy as np
+import math
+import sys
+import time
+import os
+import copy
+import subprocess
+from contextlib import asynccontextmanager
+import uvicorn
+
+# Import first as it does dep checking and reporting.
+from shortfin.interop.fastapi import FastAPIResponder
+from shortfin.support.logging_setup import native_handler
+import shortfin as sf
+
+from fastapi import FastAPI, Request, Response
+
+from .components.generate import GenerateImageProcess
+from .components.messages import InferenceExecRequest, InferencePhase
+from .components.config_struct import ModelParams
+from .components.io_struct import GenerateReqInput
+from .components.manager import SystemManager
+from .components.service import GenerateService, InferenceExecutorProcess
+from .components.tokenizer import Tokenizer
+
+
+logger = logging.getLogger("shortfin-sd")
+logger.addHandler(native_handler)
+logger.propagate = False
+
+THIS_DIR = Path(__file__).parent
+
+
+def get_configs(
+    model_config,
+    flagfile,
+    target,
+    artifacts_dir,
+    use_tuned=True,
+):
+    # Returns one set of config artifacts.
+ modelname = "sdxl" + tuning_spec = None + cfg_builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "config_artifacts.py"), + f"--target={target}", + f"--output-dir={artifacts_dir}", + f"--model={modelname}", + ] + outs = subprocess.check_output(cfg_builder_args).decode() + outs_paths = outs.splitlines() + for i in outs_paths: + if "sdxl_config" in i and not model_config: + model_config = i + elif "flagfile" in i and not flagfile: + flagfile = i + elif "attention_and_matmul_spec" in i and use_tuned: + tuning_spec = i + + if use_tuned and tuning_spec: + tuning_spec = os.path.abspath(tuning_spec) + + return model_config, flagfile, tuning_spec + + +def get_modules( + target, + device, + model_config, + flagfile=None, + td_spec=None, + extra_compile_flags=[], + artifacts_dir=None, +): + # TODO: Move this out of server entrypoint + vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []} + params = {"clip": [], "unet": [], "vae": []} + model_flags = copy.deepcopy(vmfbs) + model_flags["all"] = extra_compile_flags + + if flagfile: + with open(flagfile, "r") as f: + contents = [line.rstrip() for line in f] + flagged_model = "all" + for elem in contents: + match = [keyw in elem for keyw in model_flags.keys()] + if any(match): + flagged_model = elem + else: + model_flags[flagged_model].extend([elem]) + if td_spec: + model_flags["unet"].extend( + [f"--iree-codegen-transform-dialect-library={td_spec}"] + ) + + filenames = [] + for modelname in vmfbs.keys(): + ireec_args = model_flags["all"] + model_flags[modelname] + ireec_extra_args = " ".join(ireec_args) + builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "builders.py"), + f"--model-json={model_config}", + f"--target={target}", + f"--splat=False", + f"--build-preference=precompiled", + f"--output-dir={artifacts_dir}", + f"--model={modelname}", + f"--iree-hal-target-device={device}", + f"--iree-hip-target={target}", + f"--iree-compile-extra-args={ireec_extra_args}", + ] + logger.info(f"Preparing runtime artifacts for {modelname}...") + logger.debug( + "COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]) + ) + output = subprocess.check_output(builder_args).decode() + + output_paths = output.splitlines() + filenames.extend(output_paths) + for name in filenames: + for key in vmfbs.keys(): + if key in name.lower(): + if any(x in name for x in [".irpa", ".safetensors", ".gguf"]): + params[key].extend([name]) + elif "vmfb" in name: + vmfbs[key].extend([name]) + return vmfbs, params + + +class MicroSDXLExecutor(sf.Process): + def __init__(self, args, service): + super().__init__(fiber=service.meta_fibers[0].fiber) + self.service = service + + self.args = args + self.exec = None + self.imgs = None + + async def run(self): + args = self.args + + # self.exec = InferenceExecRequest( + # args.prompt, + # args.neg_prompt, + # 1024, + # 1024, + # args.steps, + # args.guidance_scale, + # args.seed, + # ) + input_ids = [ + [ + np.ones([1, 64], dtype=np.int64), + np.ones([1, 64], dtype=np.int64), + np.ones([1, 64], dtype=np.int64), + np.ones([1, 64], dtype=np.int64), + ] + ] + sample = np.ones([1, 4, 128, 128], dtype=np.float16) + self.exec = InferenceExecRequest( + prompt=None, + neg_prompt=None, + input_ids=input_ids, + height=1024, + width=1024, + steps=args.steps, + guidance_scale=args.guidance_scale, + sample=sample, + ) + + self.exec.phases[InferencePhase.POSTPROCESS]["required"] = False + while len(self.service.idle_meta_fibers) == 0: + 
+            time.sleep(0.5)
+            print("All fibers busy...")
+        fiber = self.service.idle_meta_fibers.pop()
+        exec_process = InferenceExecutorProcess(self.service, fiber)
+        if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
+            self.service.idle_meta_fibers.append(fiber)
+        exec_process.exec_request = self.exec
+        exec_process.launch()
+        await asyncio.gather(exec_process)
+        imgs = []
+        await self.exec.done
+
+        imgs.append(exec_process.exec_request.image_array)
+
+        self.imgs = imgs
+        return
+
+
+class SDXLSampleProcessor:
+    def __init__(self, service):
+        self.service = service
+        self.max_procs = 2
+        self.num_procs = 0
+        self.imgs = []
+        self.procs = set()
+
+    def process(self, args):
+        proc = MicroSDXLExecutor(args, self.service)
+        self.num_procs += 1
+        proc.launch()
+        self.procs.add(proc)
+        return
+
+    def read(self):
+        items = set()
+        for proc in self.procs:
+            if proc.imgs is not None:
+                img = proc.imgs
+                self.procs.remove(proc)
+                self.num_procs -= 1
+                return img
+        return None
+
+
+def create_service(
+    model_params,
+    device,
+    tokenizers,
+    vmfbs,
+    params,
+    device_idx=None,
+    device_ids=[],
+    fibers_per_device=1,
+    isolation="per_call",
+    trace_execution=False,
+    amdgpu_async_allocations=False,
+):
+    if device_idx is not None:
+        sysman = SystemManager(device, [device_idx], amdgpu_async_allocations)
+    else:
+        sysman = SystemManager(device, device_ids, amdgpu_async_allocations)
+
+    sdxl_service = GenerateService(
+        name="sd",
+        sysman=sysman,
+        tokenizers=tokenizers,
+        model_params=model_params,
+        fibers_per_device=fibers_per_device,
+        workers_per_device=1,
+        prog_isolation=isolation,
+        show_progress=False,
+        trace_execution=trace_execution,
+    )
+    for key, vmfblist in vmfbs.items():
+        for vmfb in vmfblist:
+            sdxl_service.load_inference_module(vmfb, component=key)
+    for key, datasets in params.items():
+        sdxl_service.load_inference_parameters(
+            *datasets, parameter_scope="model", component=key
+        )
+    sdxl_service.start()
+    return sdxl_service
+
+
+def prepare_service(args):
+    tokenizers = []
+    for idx, tok_name in enumerate(args.tokenizers):
+        subfolder = f"tokenizer_{idx + 1}" if idx > 0 else "tokenizer"
+        tokenizers.append(Tokenizer.from_pretrained(tok_name, subfolder))
+    model_config, flagfile, tuning_spec = get_configs(
+        args.model_config,
+        args.flagfile,
+        args.target,
+        args.artifacts_dir,
+        args.use_tuned,
+    )
+    model_params = ModelParams.load_json(model_config)
+    vmfbs, params = get_modules(
+        args.target,
+        args.device,
+        model_config,
+        flagfile,
+        tuning_spec,
+        artifacts_dir=args.artifacts_dir,
+    )
+    return model_params, tokenizers, vmfbs, params
+
+
+class Main:
+    def __init__(self, sysman):
+        self.sysman = sysman
+
+    def main(self, args):  # queue
+        model_params, tokenizers, vmfbs, params = prepare_service(args)
+        shared_service = False
+        services = set()
+        if shared_service:
+            services.add(
+                create_service(
+                    model_params,
+                    args.device,
+                    tokenizers,
+                    vmfbs,
+                    params,
+                    trace_execution=args.trace_execution,
+                    amdgpu_async_allocations=args.amdgpu_async_allocations,
+                )
+            )
+        else:
+            for idx, device in enumerate(self.sysman.ls.device_names):
+                services.add(
+                    create_service(
+                        model_params,
+                        args.device,
+                        tokenizers,
+                        vmfbs,
+                        params,
+                        device_idx=idx,
+                        trace_execution=args.trace_execution,
+                        amdgpu_async_allocations=args.amdgpu_async_allocations,
+                    )
+                )
+        procs = set()
+        procs_per_service = 2
+        for service in services:
+            for i in range(procs_per_service):
+                sample_processor = SDXLSampleProcessor(service)
+                procs.add(sample_processor)
+
+        samples = args.samples
+        queue = set()
+        # n sets of arguments into a queue
+
+        for i in range(samples):
+            # Run until told to stop or queue exhaustion
+            # OR multiple dequeue threads pulling from queue
+            # read, instantiate, launch
+            # knob : concurrency control
+            queue.add(i)
+
+        start = time.time()
+        imgs = []
+        # Fire off jobs
+        while len(queue) > 0:
+            # round robin pop items from queue into executors
+            this_processor = procs.pop()
+            while this_processor.num_procs >= this_processor.max_procs:
+                procs.add(this_processor)
+                this_processor = procs.pop()
+                # Try reading and clearing out processes before checking again.
+                for proc in procs:
+                    results = proc.read()
+                    if results:
+                        imgs.append(results)
+                        print(f"{len(imgs)} samples received, of a total {samples}")
+            # Pop item from queue and initiate process.
+            queue.pop()
+            this_processor.process(args)
+            procs.add(this_processor)
+
+        # Read responses
+        while len(imgs) < samples:
+            for proc in procs:
+                results = proc.read()
+                if results:
+                    imgs.append(results)
+                    print(f"{len(imgs)} samples received, of a total {samples}")
+
+        print(f"Completed {samples} samples in {time.time() - start} seconds.")
+        return
+
+
+def run_cli(argv):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default=None)
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument(
+        "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=True,
+        choices=["local-task", "hip", "amdgpu"],
+        help="Primary inferencing device",
+    )
+    parser.add_argument(
+        "--target",
+        type=str,
+        required=False,
+        default="gfx942",
+        choices=["gfx942", "gfx1100", "gfx90a"],
+        help="Primary inferencing device LLVM target arch.",
+    )
+    parser.add_argument(
+        "--device_ids",
+        type=str,
+        nargs="*",
+        default=None,
+        help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0",
+    )
+    parser.add_argument(
+        "--tokenizers",
+        type=Path,
+        nargs="*",
+        default=[
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            "stabilityai/stable-diffusion-xl-base-1.0",
+        ],
+        help="HF repo from which to load tokenizer(s).",
+    )
+    parser.add_argument(
+        "--model_config",
+        type=Path,
+        help="Path to the model config file. If None, defaults to i8 punet, batch size 1",
+    )
+    parser.add_argument(
+        "--workers_per_device",
+        type=int,
+        default=1,
+        help="Concurrency control -- how many fibers are created per device to run inference.",
+    )
+    parser.add_argument(
+        "--fibers_per_device",
+        type=int,
+        default=1,
+        help="Concurrency control -- how many fibers are created per device to run inference.",
+    )
+    parser.add_argument(
+        "--isolation",
+        type=str,
+        default="per_call",
+        choices=["per_fiber", "per_call", "none"],
+        help="Concurrency control -- How to isolate programs.",
+    )
+    parser.add_argument(
+        "--show_progress",
+        action="store_true",
+        help="enable tqdm progress for unet iterations.",
+    )
+    parser.add_argument(
+        "--trace_execution",
+        action="store_true",
+        help="Enable tracing of program modules.",
+    )
+    parser.add_argument(
+        "--amdgpu_async_allocations",
+        action="store_true",
+        help="Enable asynchronous allocations for amdgpu device contexts.",
+    )
+    parser.add_argument(
+        "--splat",
+        action="store_true",
+        help="Use splat (empty) parameter files, usually for testing.",
+    )
+    parser.add_argument(
+        "--build_preference",
+        type=str,
+        choices=["compile", "precompiled"],
+        default="precompiled",
+        help="Specify preference for builder artifact generation.",
+    )
+    parser.add_argument(
+        "--compile_flags",
+        type=str,
+        nargs="*",
+        default=[],
+        help="extra compile flags for all compile actions. For fine-grained control, use flagfiles.",
+    )
+    parser.add_argument(
+        "--flagfile",
+        type=Path,
+        help="Path to a flagfile to use for SDXL. If not specified, will use latest flagfile from azure.",
+    )
+    parser.add_argument(
+        "--artifacts_dir",
+        type=Path,
+        default=None,
+        help="Path to local artifacts cache.",
+    )
+    parser.add_argument(
+        "--tuning_spec",
+        type=str,
+        default=None,
+        help="Path to transform dialect spec if compiling an executable with tunings.",
+    )
+    parser.add_argument(
+        "--use_tuned",
+        type=int,
+        default=1,
+        help="Use tunings for attention and matmul ops. 0 to disable.",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+        help="Image generation prompt",
+    )
+    parser.add_argument(
+        "--neg_prompt",
+        type=str,
+        default="Watermark, blurry, oversaturated, low resolution, pollution",
+        help="Image generation negative prompt",
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default="20",
+        help="Number of inference steps. More steps usually means a better image. Interactive only.",
+    )
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default="0.7",
+        help="Guidance scale for denoising.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=0,
+        help="RNG seed for image latents.",
+    )
+    parser.add_argument(
+        "--samples",
+        type=int,
+        default=1,
+        help="Benchmark samples.",
+    )
+    parser.add_argument(
+        "--max_concurrent_procs",
+        type=int,
+        default=16,
+        help="Maximum number of executor threads to run at any given time.",
+    )
+    args = parser.parse_args(argv)
+    if not args.artifacts_dir:
+        home = Path.home()
+        artdir = home / ".cache" / "shark"
+        args.artifacts_dir = str(artdir)
+    else:
+        args.artifacts_dir = os.path.abspath(args.artifacts_dir)
+
+    sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
+    main = Main(sysman)
+    main.main(args)
+
+
+if __name__ == "__main__":
+    logging.root.setLevel(logging.INFO)
+    run_cli(
+        sys.argv[1:],
+    )