Skip to content

Commit

Permalink
Rework torch version check.
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Feb 17, 2025
1 parent 8820ba5 commit 7ba05b8
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions shortfin/python/shortfin_apps/sd/components/exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def export_sdxl_model(
) -> ExportOutput:
import torch

def check_torch_version(begin: tuple, end: tuple):
version = torch.__version__.split("+")[0] # Remove any suffix like '+cu118'
major, minor, patch = map(int, version.split("."))
if not (begin <= (major, minor, patch) < end):
raise Warning(
f"Torch version is outside the supported range for some exports: {begin}-{end}"
)

decomp_list = [torch.ops.aten.logspace]
if decomp_attn == True:
decomp_list = [
Expand Down Expand Up @@ -66,11 +74,7 @@ def encode_prompts(
return module.forward(**inputs)

elif component in ["unet", "punet", "scheduled_unet"]:
t_ver = torch.__version__
if any(key in t_ver for key in ["2.6.", "2.3."]):
print(
"You have a torch version that is unstable for this export and may encounter export or compile-time issues: {t_ver}. The reccommended versions are 2.4.1 - 2.5.1"
)
check_torch_version((2, 4, 1), (2, 6, 0))
from sharktank.torch_exports.sdxl.unet import (
get_scheduled_unet_model_and_inputs,
get_punet_model_and_inputs,
Expand Down

0 comments on commit 7ba05b8

Please sign in to comment.