[shortfin-sd] Add exports and support for scheduled unet, batch sizes. #972
Draft
monorimet wants to merge 17 commits into main from sdxl_exports (base: main)
Commits (17, changes from all commits):
faf7103 eagarvey-amd: Add torch exports for SDXL.
846ee59 eagarvey-amd: Finish builder for sdxl exports.
9171346 eagarvey-amd: (WIP) Integrate exports with sdxl builder.
b430247 eagarvey-amd: Rework export procedure and program load around batch sizing, configs
24047cb eagarvey-amd: Add separate scheduler exports and backwards compat, simplify builder…
628825c eagarvey-amd: Rework torch version check.
cfb64f9 monorimet: Delete sharktank/sharktank/torch_exports/sdxl/README.md
75d3500 eagarvey-amd: Reduce scheduler code, pipe through an inference option for scheduled…
088e19d eagarvey-amd: Fix scheduled unet export.
a7bda08 eagarvey-amd: Fixup builder force update, batch size >1 inference
f5c9204 eagarvey-amd: Update example configs.
b033d4e eagarvey-amd: Fixups for harness.
759b645 eagarvey-amd: Enable FP8 attention punet variant.
179301e eagarvey-amd: Resolve conflict
b4408da eagarvey-amd: Add urllib import
986ba97 eagarvey-amd: Revert some changes to builder.
314e04b eagarvey-amd: Bump IREE to 3.3.0rc20250219
@@ -0,0 +1,19 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import importlib.util

# Check the optional dependencies up front so a missing package fails with a
# clear error instead of an ImportError from the submodule imports below.
if importlib.util.find_spec("diffusers") is None:
    raise ModuleNotFoundError("Diffusers not found.")

if importlib.util.find_spec("transformers") is None:
    raise ModuleNotFoundError("Transformers not found.")

from .clip import *
from .vae import *
from .scheduler import *
from .unet import *
@@ -0,0 +1,164 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys

from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from iree.turbine.aot import *

import torch
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPTextConfig,
)


class PromptEncoderModel(torch.nn.Module):
    def __init__(
        self,
        hf_model_name,
        precision,
        batch_size=1,
        batch_input=False,
    ):
        super().__init__()
        self.torch_dtype = torch.float16 if precision == "fp16" else torch.float32
        config_1 = CLIPTextConfig.from_pretrained(
            hf_model_name,
            subfolder="text_encoder",
        )
        config_1._attn_implementation = "eager"
        config_2 = CLIPTextConfig.from_pretrained(
            hf_model_name,
            subfolder="text_encoder_2",
        )
        config_2._attn_implementation = "eager"
        self.text_encoder_model_1 = CLIPTextModel.from_pretrained(
            hf_model_name,
            config=config_1,
            subfolder="text_encoder",
        )
        self.text_encoder_model_2 = CLIPTextModelWithProjection.from_pretrained(
            hf_model_name,
            config=config_2,
            subfolder="text_encoder_2",
        )
        self.do_classifier_free_guidance = True
        self.batch_size = batch_size
        self.batch_input = batch_input

    def forward(
        self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
    ):
        with torch.no_grad():
            prompt_embeds_1 = self.text_encoder_model_1(
                text_input_ids_1,
                output_hidden_states=True,
            )
            prompt_embeds_2 = self.text_encoder_model_2(
                text_input_ids_2,
                output_hidden_states=True,
            )
            neg_prompt_embeds_1 = self.text_encoder_model_1(
                uncond_input_ids_1,
                output_hidden_states=True,
            )
            neg_prompt_embeds_2 = self.text_encoder_model_2(
                uncond_input_ids_2,
                output_hidden_states=True,
            )
            # We are always interested only in the pooled output of the final text encoder.
            pooled_prompt_embeds = prompt_embeds_2[0]
            neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]

            prompt_embeds_list = [
                prompt_embeds_1.hidden_states[-2],
                prompt_embeds_2.hidden_states[-2],
            ]
            neg_prompt_embeds_list = [
                neg_prompt_embeds_1.hidden_states[-2],
                neg_prompt_embeds_2.hidden_states[-2],
            ]

            prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
            neg_prompt_embeds = torch.cat(neg_prompt_embeds_list, dim=-1)

            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
                bs_embed * 1, -1
            )
            if not self.batch_input:
                prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
            add_text_embeds = pooled_prompt_embeds
            if not self.batch_input:
                add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
            if self.do_classifier_free_guidance:
                if not self.batch_input:
                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
                        1, 1
                    ).view(1, -1)
                neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
                neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
                if not self.batch_input:
                    neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
                prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
                if not self.batch_input:
                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
                        self.batch_size, 1
                    )
                add_text_embeds = torch.cat(
                    [neg_pooled_prompt_embeds, add_text_embeds], dim=0
                )
            add_text_embeds = add_text_embeds.to(self.torch_dtype)
            prompt_embeds = prompt_embeds.to(self.torch_dtype)
            return prompt_embeds, add_text_embeds


@torch.no_grad()
def get_clip_model_and_inputs(
    hf_model_name,
    max_length=64,
    precision="fp16",
    batch_size=1,
    batch_input=True,
):
    prompt_encoder_module = PromptEncoderModel(
        hf_model_name,
        precision,
        batch_size=batch_size,
        batch_input=batch_input,
    )

    input_batchsize = 1
    if batch_input:
        input_batchsize = batch_size

    if precision == "fp16":
        prompt_encoder_module = prompt_encoder_module.half()

    example_inputs = {
        "text_input_ids_1": torch.ones(
            (input_batchsize, max_length), dtype=torch.int64
        ),
        "text_input_ids_2": torch.ones(
            (input_batchsize, max_length), dtype=torch.int64
        ),
        "uncond_input_ids_1": torch.ones(
            (input_batchsize, max_length), dtype=torch.int64
        ),
        "uncond_input_ids_2": torch.ones(
            (input_batchsize, max_length), dtype=torch.int64
        ),
    }
    return prompt_encoder_module, example_inputs

Review comment on the CLIPTextConfig.from_pretrained call above: Are we grabbing the HF implementation for all these (except punet) models? We should have everything covered in our sharktank models that we should switch to.
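For reference, a minimal smoke test of the prompt-encoder helper. The module path `sharktank.torch_exports.sdxl.clip` and the SDXL base checkpoint name are assumptions for illustration, not confirmed by this diff; fp32 is used so the sketch runs on CPU:

```python
import torch

# Hypothetical module path, inferred from the package __init__ above.
from sharktank.torch_exports.sdxl.clip import get_clip_model_and_inputs

model, example_inputs = get_clip_model_and_inputs(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint for SDXL
    max_length=64,
    precision="fp32",
    batch_size=2,
    batch_input=True,
)
with torch.no_grad():
    prompt_embeds, add_text_embeds = model(**example_inputs)

# With CFG enabled, negative and positive embeddings are concatenated on dim 0,
# so the output batch is 2 * batch_size.
print(prompt_embeds.shape, add_text_embeds.shape)  # (4, 64, 2048), (4, 1280)
```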
@@ -0,0 +1,175 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from iree.turbine.aot import *
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)


class SchedulingModel(torch.nn.Module):
    def __init__(
        self,
        hf_model_name,
        scheduler,
        height,
        width,
        batch_size,
        dtype,
    ):
        super().__init__()
        # For now, assumes SDXL implementation. May not need parametrization for other models,
        # but keeping hf_model_name in case.
        self.model = scheduler
        self.height = height
        self.width = width
        self.is_sd3 = False
        if "stable-diffusion-3" in hf_model_name:
            self.is_sd3 = True
        self.batch_size = batch_size
        # Whether this will be used with CFG-enabled pipeline.
        self.do_classifier_free_guidance = True
        timesteps = [torch.empty((100), dtype=dtype, requires_grad=False)] * 100
        sigmas = [torch.empty((100), dtype=torch.float32, requires_grad=False)] * 100
        for i in range(1, 100):
            self.model.set_timesteps(i)
            timesteps[i] = torch.nn.functional.pad(
                self.model.timesteps.clone().detach(), (0, 100 - i), "constant", 0
            )
            sigmas[i] = torch.nn.functional.pad(
                self.model.sigmas.clone().detach(), (0, 100 - (i + 1)), "constant", 0
            )
        self.timesteps = torch.stack(timesteps, dim=0).clone().detach()
        self.sigmas = torch.stack(sigmas, dim=0).clone().detach()
        self.model.is_scale_input_called = True
        self.dtype = dtype

    def initialize(self, sample, num_inference_steps):
        height = self.height
        width = self.width
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=self.dtype)
        if self.do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
        add_time_ids = add_time_ids.repeat(self.batch_size, 1).type(self.dtype)
        max_sigma = self.sigmas[num_inference_steps].max()
        init_noise_sigma = (max_sigma**2 + 1) ** 0.5
        sample = sample * init_noise_sigma
        return (
            sample.type(self.dtype),
            add_time_ids,
            self.timesteps[num_inference_steps].squeeze(0),
            self.sigmas[num_inference_steps].squeeze(0),
        )

    def scale_model_input(self, sample, i, timesteps, sigmas):
        sigma = sigmas[i]
        next_sigma = sigmas[i + 1]
        t = timesteps[i]
        latent_model_input = sample / ((sigma**2 + 1) ** 0.5)
        self.model.is_scale_input_called = True
        return (
            latent_model_input.type(self.dtype),
            t.type(self.dtype),
            sigma.type(self.dtype),
            next_sigma.type(self.dtype),
        )

    def step(self, noise_pred, sample, sigma, next_sigma):
        sample = sample.to(torch.float32)
        gamma = 0.0
        noise_pred = noise_pred.to(torch.float32)
        sigma_hat = sigma * (gamma + 1)
        pred_original_sample = sample - sigma_hat * noise_pred
        deriv = (sample - pred_original_sample) / sigma_hat
        dt = next_sigma - sigma_hat
        prev_sample = sample + deriv * dt
        return prev_sample.type(self.dtype)


def get_scheduler(model_id, scheduler_id):
    if scheduler_id in SCHEDULER_MAP.keys():
        scheduler = SCHEDULER_MAP[scheduler_id].from_pretrained(
            model_id, subfolder="scheduler"
        )
    elif all(x in scheduler_id for x in ["DPMSolverMultistep", "++"]):
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
        )
    else:
        raise ValueError(f"Scheduler {scheduler_id} not found.")
    if "Karras" in scheduler_id:
        scheduler.config.use_karras_sigmas = True

    return scheduler


SCHEDULER_MAP = {
    "EulerDiscrete": EulerDiscreteScheduler,
}

torch_dtypes = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


def get_scheduler_model_and_inputs(
    hf_model_name,
    batch_size,
    height,
    width,
    precision,
    scheduler_id="EulerDiscrete",
):
    dtype = torch_dtypes[precision]
    raw_scheduler = get_scheduler(hf_model_name, scheduler_id)
    scheduler = SchedulingModel(
        hf_model_name, raw_scheduler, height, width, batch_size, dtype
    )
    init_in, prep_in, step_in = get_sample_sched_inputs(
        batch_size, height, width, dtype
    )
    return scheduler, init_in, prep_in, step_in


def get_sample_sched_inputs(batch_size, height, width, dtype):
    sample = (
        batch_size,
        4,
        height // 8,
        width // 8,
    )
    noise_pred_shape = (
        batch_size,
        4,
        height // 8,
        width // 8,
    )
    init_args = (
        torch.rand(sample, dtype=dtype),
        torch.tensor([10], dtype=torch.int64),
    )
    prep_args = (
        torch.rand(sample, dtype=dtype),
        torch.tensor([10], dtype=torch.int64),
        torch.rand(100, dtype=torch.float32),
        torch.rand(100, dtype=torch.float32),
    )
    step_args = [
        torch.rand(noise_pred_shape, dtype=dtype),
        torch.rand(sample, dtype=dtype),
        torch.rand(1, dtype=dtype),
        torch.rand(1, dtype=dtype),
    ]
    return init_args, prep_args, step_args
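For orientation, a hedged sketch of the scheduled flow this module is exported for: initialize scales the latents and builds add_time_ids, then each denoising step pairs scale_model_input with step. The module path and checkpoint name are assumptions for illustration, and the zero noise prediction stands in for the UNet call:

```python
import torch

# Hypothetical module path; the checkpoint name is likewise an assumption.
from sharktank.torch_exports.sdxl.scheduler import get_scheduler_model_and_inputs

sched, init_in, prep_in, step_in = get_scheduler_model_and_inputs(
    "stabilityai/stable-diffusion-xl-base-1.0",
    batch_size=1,
    height=1024,
    width=1024,
    precision="fp32",
)

num_steps = 20
latents = torch.randn(1, 4, 1024 // 8, 1024 // 8)
sample, add_time_ids, timesteps, sigmas = sched.initialize(latents, num_steps)
for i in range(num_steps):
    latent_input, t, sigma, next_sigma = sched.scale_model_input(
        sample, i, timesteps, sigmas
    )
    # A real pipeline would run the (scheduled) UNet here with latent_input, t,
    # the prompt embeddings, and add_time_ids; zeros keep the sketch self-contained.
    noise_pred = torch.zeros_like(latent_input)
    sample = sched.step(noise_pred, sample, sigma, next_sigma)

print(sample.shape)  # (1, 4, 128, 128)
```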
Review comment: Is there any reason for this change?