diff --git a/shortfin/python/shortfin_apps/sd/components/exports.py b/shortfin/python/shortfin_apps/sd/components/exports.py index 7051b327d..c2eae4304 100644 --- a/shortfin/python/shortfin_apps/sd/components/exports.py +++ b/shortfin/python/shortfin_apps/sd/components/exports.py @@ -23,6 +23,14 @@ def export_sdxl_model( ) -> ExportOutput: import torch + def check_torch_version(begin: tuple, end: tuple): + version = torch.__version__.split("+")[0] # Remove any suffix like '+cu118' + major, minor, patch = map(int, version.split(".")) + if not (begin <= (major, minor, patch) < end): + raise Warning( + f"Torch version is outside the supported range for some exports: {begin}-{end}" + ) + decomp_list = [torch.ops.aten.logspace] if decomp_attn == True: decomp_list = [ @@ -66,11 +74,7 @@ def encode_prompts( return module.forward(**inputs) elif component in ["unet", "punet", "scheduled_unet"]: - t_ver = torch.__version__ - if any(key in t_ver for key in ["2.6.", "2.3."]): - print( - "You have a torch version that is unstable for this export and may encounter export or compile-time issues: {t_ver}. The reccommended versions are 2.4.1 - 2.5.1" - ) + check_torch_version((2, 4, 1), (2, 6, 0)) from sharktank.torch_exports.sdxl.unet import ( get_scheduled_unet_model_and_inputs, get_punet_model_and_inputs,