Skip to content

Commit

Permalink
Rework export procedure and program load around batch sizing, configs
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Feb 16, 2025
1 parent c23517c commit 8c59caf
Show file tree
Hide file tree
Showing 8 changed files with 422 additions and 343 deletions.
10 changes: 1 addition & 9 deletions sharktank/sharktank/torch_exports/sdxl/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,7 @@ def get_scheduled_unet_model_and_inputs(
torch.rand(100, dtype=torch.float32),
torch.rand(100, dtype=torch.float32),
)
standalone_unet_inputs = {
"sample": torch.rand(sample, dtype=dtype),
"timestep": torch.zeros(1, dtype=dtype),
"encoder_hidden_states": torch.rand(prompt_embeds_shape, dtype=dtype),
"text_embeds": torch.rand(text_embeds_shape, dtype=dtype),
"time_ids": torch.zeros(time_ids_shape, dtype=dtype),
"guidance_scale": torch.tensor([7.5], dtype=dtype),
}
return model, init_inputs, forward_inputs, standalone_unet_inputs
return model, init_inputs, forward_inputs


@torch.no_grad()
Expand Down
Loading

0 comments on commit 8c59caf

Please sign in to comment.