Commit 8820ba5

Add separate scheduler exports and backwards compat, simplify builder utils

eagarvey-amd committed Feb 17, 2025
1 parent 8c59caf commit 8820ba5
Showing 7 changed files with 126 additions and 50 deletions.
27 changes: 27 additions & 0 deletions sharktank/sharktank/torch_exports/sdxl/scheduler.py
@@ -153,6 +153,33 @@ def get_scheduler(model_id, scheduler_id):
     "DPMSolverSDEKarras": DPMSolverSDEScheduler,
 }
 
+torch_dtypes = {
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+    "float16": torch.float16,
+    "float32": torch.float32,
+}
+
+
+def get_scheduler_model_and_inputs(
+    hf_model_name,
+    batch_size,
+    height,
+    width,
+    precision,
+    scheduler_id="EulerDiscrete",
+):
+    dtype = torch_dtypes[precision]
+    raw_scheduler = get_scheduler(hf_model_name, scheduler_id)
+    scheduler = SchedulingModel(
+        hf_model_name, raw_scheduler, height, width, batch_size, dtype
+    )
+    init_in, prep_in, step_in = get_sample_sched_inputs(
+        batch_size, height, width, dtype
+    )
+    return scheduler, init_in, prep_in, step_in
+
 
 def get_sample_sched_inputs(batch_size, height, width, dtype):
     sample = (
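
A minimal usage sketch of the new entry point added above; the model id, image size, and batch size below are assumptions for illustration, not values taken from this commit:

from sharktank.torch_exports.sdxl.scheduler import get_scheduler_model_and_inputs

# Returns the wrapped SchedulingModel plus sample input tuples for its three
# functions: initialize(), scale_model_input(), and step().
model, init_in, prep_in, step_in = get_scheduler_model_and_inputs(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed hf_model_name
    batch_size=1,
    height=1024,
    width=1024,
    precision="fp16",  # mapped to torch.float16 via the torch_dtypes table
)
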
12 changes: 6 additions & 6 deletions sharktank/sharktank/torch_exports/sdxl/vae.py
@@ -90,16 +90,16 @@ def get_vae_model_and_inputs(
     vae_model = VaeModel(hf_model_name, custom_vae=custom_vae).to(dtype=dtype)
     input_image_shape = (batch_size, 3, height, width)
     input_latents_shape = (batch_size, num_channels, height // 8, width // 8)
-    encode_args = {
-        "image": torch.rand(
+    encode_args = [
+        torch.rand(
             input_image_shape,
             dtype=dtype,
         )
-    }
-    decode_args = {
-        "latents": torch.empty(
+    ]
+    decode_args = [
+        torch.rand(
             input_latents_shape,
             dtype=dtype,
         ),
-    }
+    ]
     return vae_model, encode_args, decode_args
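
With encode_args and decode_args now plain lists instead of keyword dicts, the sample inputs are meant to be splatted positionally; a small sketch under that assumption (the model id and sizes are illustrative only, and the variable names are invented):

from sharktank.torch_exports.sdxl.vae import get_vae_model_and_inputs

vae_model, encode_args, decode_args = get_vae_model_and_inputs(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed hf_model_name
    1024, 1024, precision="fp16", batch_size=1,
)
# Previously the sample args were unpacked as keywords (decode(**decode_args));
# they now unpack positionally, matching the exporter change in exports.py below.
images = vae_model.decode(*decode_args)
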
34 changes: 22 additions & 12 deletions shortfin/python/shortfin_apps/sd/components/builders.py
@@ -150,7 +150,7 @@ def get_file_stems(model_params: ModelParams) -> list[str]:
             [modname],
         ]
         bsizes = []
-        for bs in model_params.batch_sizes[modname]:
+        for bs in model_params.batch_sizes[mod]:
             bsizes.extend([f"bs{bs}"])
         ord_params.extend([bsizes])
         if mod in ["unet", "clip"]:
@@ -225,7 +225,13 @@ def needs_compile(filename, target, ctx) -> bool:
 
 
 def get_cached(filename, ctx, namespace) -> BuildFile:
-    return ctx.allocate_file(filename, namespace=namespace)
+    if filename is None:
+        return None
+    try:
+        cached_file = ctx.allocate_file(filename, namespace=namespace)
+    except RuntimeError:
+        cached_file = ctx.file(filename)
+    return cached_file
 
 
 def is_valid_size(file_path, url) -> bool:
@@ -299,7 +305,7 @@ def parse_mlir_name(mlir_path):
     if dims_match:
         height = int(dims_match.group(1))
         width = int(dims_match.group(2))
-        decomp_attn = False
+        decomp_attn = False if "unet" in mlir_path else True
     else:
         height = None
         width = None
@@ -358,19 +364,21 @@ def sdxl(
     if build_preference == "export":
         for idx, mlir_path in enumerate(mlir_filenames):
            # If generating multiple MLIR, we only save the weights the first time.
-            if idx == 0 and not os.path.exists(params_filepath):
+            needs_gen_params = False
+            if not params_filepath:
+                weights_path = None
+            elif idx == 0 and not os.path.exists(params_filepath):
                 weights_path = params_filepath
-                safe_params_access = True
+                needs_gen_params = True
             elif "punet_dataset" in params_filename:
                 # We need the path for punet export.
                 weights_path = params_filepath
-                safe_params_access = False
             else:
                 weights_path = None
-                safe_params_access = True
 
             if (
                 needs_file(mlir_path, ctx)
-                or not os.path.exists(params_filepath)
+                or needs_gen_params
                 or force_update in [True, "True"]
             ):
                 (
@@ -394,7 +402,7 @@
                     external_weights_file=weights_path,
                     decomp_attn=decomp_attn,
                     name=mlir_path.split(".mlir")[0],
-                    out_of_process=False,
+                    out_of_process=True,
                 )
             else:
                 get_cached(mlir_path, ctx, FileNamespace.GEN)
@@ -405,13 +413,13 @@
         if update or needs_file(f, ctx, url):
             fetch_http(name=f, url=url)
         else:
-            get_cached(f, ctx)
+            get_cached(f, ctx, FileNamespace.GEN)
     params_urls = get_url_map([params_filename], SDXL_WEIGHTS_BUCKET)
     for f, url in params_urls.items():
         if needs_file(f, ctx, url):
             fetch_http_check_size(name=f, url=url)
         else:
-            get_cached(f, ctx)
+            get_cached(f, ctx, FileNamespace.GEN)
     if build_preference != "precompiled":
         for idx, f in enumerate(copy.deepcopy(vmfb_filenames)):
             # We return .vmfb file stems for the compile builder.
@@ -435,7 +443,9 @@
             else:
                 get_cached(f, ctx, FileNamespace.GEN)
 
-    filenames = [*vmfb_filenames, params_filename, *mlir_filenames]
+    filenames = [*vmfb_filenames, *mlir_filenames]
+    if params_filename:
+        filenames.append(params_filename)
     return filenames


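The per-MLIR weight handling above can be read as the following standalone sketch; pick_weights_path is a hypothetical helper written for illustration, not a function in this commit:

import os

def pick_weights_path(idx: int, params_filepath: str | None, params_filename: str):
    """Return (weights_path, needs_gen_params) for the idx-th MLIR export."""
    if not params_filepath:
        # No parameter archive requested: export without external weights.
        return None, False
    if idx == 0 and not os.path.exists(params_filepath):
        # The first export also generates the external weight file.
        return params_filepath, True
    if "punet_dataset" in params_filename:
        # punet export needs the weights path even when the file already exists.
        return params_filepath, False
    return None, False
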
6 changes: 6 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -68,6 +68,12 @@ class ModelParams:
     # ABI of the module.
     module_abi_version: int = 1
 
+    @property
+    def all_batch_sizes(self) -> list:
+        bs_lists = list(self.batch_sizes.values())
+        union = set.union(*[set(list) for list in bs_lists])
+        return union
+
     @staticmethod
     def load_json(path: Path | str):
         with open(path, "rt") as f:
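
A small standalone sketch of what the new all_batch_sizes property computes, using a plain dict in place of a real ModelParams instance (the batch-size values are illustrative; note that the union comes back as a set):

batch_sizes = {"clip": [1, 2], "unet": [1], "vae": [1, 2, 4]}  # assumed example values
bs_lists = list(batch_sizes.values())
all_bs = set.union(*[set(bs) for bs in bs_lists])
assert all_bs == {1, 2, 4}
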
35 changes: 33 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/exports.py
@@ -120,6 +120,38 @@ def run_forward(
         return export(
             model, kwargs=sample_forward_inputs, module_name="compiled_punet"
         )
+    elif component == "scheduler":
+        module_name = "compiled_scheduler"
+        from sharktank.torch_exports.sdxl.scheduler import (
+            get_scheduler_model_and_inputs,
+        )
+
+        model, init_args, prep_args, step_args = get_scheduler_model_and_inputs(
+            hf_model_name,
+            batch_size,
+            height,
+            width,
+            precision,
+        )
+        fxb = FxProgramsBuilder(model)
+
+        @fxb.export_program(
+            args=(init_args,),
+        )
+        def run_initialize(module, sample):
+            return module.initialize(*sample)
+
+        @fxb.export_program(
+            args=(prep_args,),
+        )
+        def run_scale(module, inputs):
+            return module.scale_model_input(*inputs)
+
+        @fxb.export_program(
+            args=(step_args,),
+        )
+        def run_step(module, inputs):
+            return module.step(*inputs)
+
     elif component == "vae":
         from sharktank.torch_exports.sdxl.vae import get_vae_model_and_inputs
@@ -128,7 +160,6 @@ def run_forward(
         model, encode_args, decode_args = get_vae_model_and_inputs(
             hf_model_name, height, width, precision=precision, batch_size=batch_size
         )
-        model.to("cpu")
         fxb = FxProgramsBuilder(model)
 
         @fxb.export_program(
@@ -138,7 +169,7 @@ def decode(
             module,
             inputs,
         ):
-            return module.decode(**inputs)
+            return module.decode(*inputs)
 
     else:
         raise ValueError("Unimplemented: ", component)
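
As a reading of this commit (not a documented contract), the three exported scheduler programs correspond to the entrypoints invoked from service.py below; this hypothetical driver loop mirrors that call order, with the fns mapping and the UNet callable as placeholders rather than APIs from this commit:

def denoise(fns, sample, num_steps):
    # initialize -> per-step scale_model_input -> UNet ("main") -> step
    latents, time_ids, timesteps, sigmas = fns["run_initialize"](sample, num_steps)
    for step in range(num_steps):
        latent_model_input, t, sigma, next_sigma = fns["run_scale"](
            latents, step, timesteps, sigmas
        )
        noise_pred = fns["main"](latent_model_input, t)  # UNet; other inputs omitted
        latents = fns["run_step"](noise_pred, latents, sigma, next_sigma)
    return latents
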
60 changes: 31 additions & 29 deletions shortfin/python/shortfin_apps/sd/components/service.py
@@ -109,13 +109,11 @@ def equip_fiber(self, fiber, idx, worker_idx):
         MetaFiber = namedtuple(
             "MetaFiber", ["fiber", "idx", "worker_idx", "device", "command_buffers"]
         )
-        cbs_per_fiber = 1
+        cb_sets_per_fiber = 1
         cbs = []
-        for _ in range(cbs_per_fiber):
-            for batch_size in self.model_params.all_batch_sizes:
-                cbs.append(
-                    initialize_command_buffer(fiber, self.model_params, batch_size)
-                )
+        for _ in range(cb_sets_per_fiber):
+            for bs in self.model_params.all_batch_sizes:
+                cbs.append(initialize_command_buffer(fiber, self.model_params, bs=bs))
 
         return MetaFiber(fiber, idx, worker_idx, fiber.device(0), cbs)

@@ -195,7 +193,8 @@ def start(self):
                     self.inference_functions[worker_idx]["encode"][bs] = {}
                     fn_dest = self.inference_functions[worker_idx]["encode"][bs]
                 elif submodel in ["unet", "scheduled_unet", "scheduler"]:
-                    self.inference_functions[worker_idx]["denoise"][bs] = {}
+                    if not self.inference_functions[worker_idx]["denoise"].get(bs):
+                        self.inference_functions[worker_idx]["denoise"][bs] = {}
                     fn_dest = self.inference_functions[worker_idx]["denoise"][bs]
                 elif submodel == "vae":
                     self.inference_functions[worker_idx]["decode"][bs] = {}
@@ -213,8 +212,8 @@ def shutdown(self):
         if self.use_batcher:
             self.batcher.shutdown()
             del self.batcher
-        del self.idle_fibers
-        del self.fibers
+        del self.idle_meta_fibers
+        del self.meta_fibers
         del self.workers
         gc.collect()

@@ -381,7 +380,7 @@ def __init__(
 
     def assign_command_buffer(self, request: InferenceExecRequest):
         for cb in self.meta_fiber.command_buffers:
-            if cb.sample.shape[0] == self.exec_request.batch_size:
+            if cb.batch_size == self.exec_request.batch_size:
                 self.exec_request.set_command_buffer(cb)
                 self.meta_fiber.command_buffers.remove(cb)
                 return
@@ -476,17 +475,19 @@ async def _encode(self, device):
         req_bs = self.exec_request.batch_size
         entrypoints = self.service.inference_functions[self.worker_index]["encode"]
         assert req_bs in list(entrypoints.keys())
-        for bs, fn in entrypoints.items():
+        for bs, fns in entrypoints.items():
            if bs == req_bs:
                break
         cb = self.exec_request.command_buffer
         # Encode tokenized inputs.
         logger.debug(
             "INVOKE %r: %s",
-            fn,
+            fns["encode_prompts"],
             "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(cb.input_ids)]),
         )
-        cb.prompt_embeds, cb.text_embeds = await fn(*cb.input_ids, fiber=self.fiber)
+        cb.prompt_embeds, cb.text_embeds = await fns["encode_prompts"](
+            *cb.input_ids, fiber=self.fiber
+        )
         return
 
     async def _denoise(self, device):
@@ -501,12 +502,13 @@ async def _denoise(self, device):
 
         logger.debug(
             "INVOKE %r",
-            fns["init"],
+            fns["run_initialize"],
         )
         (cb.latents, cb.time_ids, cb.timesteps, cb.sigmas,) = await fns[
-            "init"
+            "run_initialize"
         ](cb.sample, cb.num_steps, fiber=self.fiber)
         accum_step_duration = 0  # Accumulated duration for all steps
+
         for i, t in tqdm(
             enumerate(range(self.exec_request.steps)),
             disable=(not self.service.show_progress),
@@ -516,16 +518,16 @@
             step = cb.steps_arr.view(i)
             logger.debug(
                 "INVOKE %r",
-                fns["scale"],
+                fns["run_scale"],
             )
             (cb.latent_model_input, cb.t, cb.sigma, cb.next_sigma,) = await fns[
-                "scale"
+                "run_scale"
             ](cb.latents, step, cb.timesteps, cb.sigmas, fiber=self.fiber)
             logger.debug(
                 "INVOKE %r",
-                fns[self.denoise_mod],
+                fns["main"],
             )
-            (cb.noise_pred,) = await fns["unet"](
+            (cb.noise_pred,) = await fns["main"](
                 cb.latent_model_input,
                 cb.t,
                 cb.prompt_embeds,
@@ -536,9 +538,9 @@
             )
             logger.debug(
                 "INVOKE %r",
-                fns["step"],
+                fns["run_step"],
             )
-            (cb.latents,) = await fns["step"](
+            (cb.latents,) = await fns["run_step"](
                 cb.noise_pred, cb.latents, cb.sigma, cb.next_sigma, fiber=self.fiber
             )
             duration = time.time() - start
@@ -547,6 +549,7 @@
         log_duration_str(
             average_step_duration, "denoise (UNet) single step average", req_bs
         )
+
         return
 
     async def _decode(self, device):
@@ -555,17 +558,17 @@
         # Decode latents to images
         entrypoints = self.service.inference_functions[self.worker_index]["decode"]
         assert req_bs in list(entrypoints.keys())
-        for bs, fn in entrypoints.items():
+        for bs, fns in entrypoints.items():
            if bs == req_bs:
                break
 
         # Decode the denoised latents.
         logger.debug(
             "INVOKE %r: %s",
-            fn,
+            fns["decode"],
             "".join([f"\n 0: {cb.latents.shape}"]),
         )
-        (cb.images,) = await fn(cb.latents, fiber=self.fiber)
+        (cb.images,) = await fns["decode"](cb.latents, fiber=self.fiber)
         cb.images_host.copy_from(cb.images)
         image_array = cb.images_host.items
         dtype = image_array.typecode
@@ -591,14 +594,12 @@ async def _postprocess(self, device):
         return
 
 
-def initialize_command_buffer(fiber, model_params: ModelParams, batch_size: int = 1):
-    bs = batch_size
-    cfg_bs = batch_size * 2
+def initialize_command_buffer(fiber, model_params: ModelParams, bs: int = 1):
+    device = fiber.device(0)
     h = model_params.dims[0][0]
     w = model_params.dims[0][1]
     c = model_params.num_latents_channels
-    device = fiber.device(0)
-
+    cfg_bs = bs * 2
     datas = {
         # CLIP
         "input_ids": [
@@ -661,6 +662,7 @@ def initialize_command_buffer(fiber, model_params: ModelParams, batch_size: int
     class ServiceCmdBuffer:
         def __init__(self, d):
             self.__dict__ = d
+            self.batch_size = bs
 
     cb = ServiceCmdBuffer(datas)
     return cb
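
A toy illustration of the command-buffer matching change above: buffers now carry their batch size directly, so selection no longer inspects cb.sample.shape[0] (the class and values here are invented for the example):

class Cb:
    def __init__(self, batch_size):
        self.batch_size = batch_size

command_buffers = [Cb(1), Cb(2), Cb(4)]  # one buffer per supported batch size
request_bs = 2
match = next(cb for cb in command_buffers if cb.batch_size == request_bs)
assert match.batch_size == request_bs
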