
Add proper support for LoRAs when using torch.compile.
Finally.
Thanks to comfyanonymous/ComfyUI#6638, which served as a guide for the add_patches function.
Panchovix committed Feb 9, 2025
1 parent 9bd652f commit a6de2b6
Showing 4 changed files with 81 additions and 8 deletions.
26 changes: 23 additions & 3 deletions ldm_patched/modules/model_patcher.py
@@ -231,12 +231,18 @@ def model_dtype(self):
 
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         p = set()
         model_sd = self.model.state_dict()
         for k in patches:
             if k in self.model_keys:
                 p.add(k)
-                current_patches = self.patches.get(k, [])
+                # Check if key needs to be modified for compiled model
+                patch_key = k
+                if k.startswith("diffusion_model.") and hasattr(self.model, "compile_settings"):
+                    patch_key = k.replace("diffusion_model.", "diffusion_model._orig_mod.")
+
+                current_patches = self.patches.get(patch_key, [])
                 current_patches.append((strength_patch, patches[k], strength_model))
-                self.patches[k] = current_patches
+                self.patches[patch_key] = current_patches
 
         self.patches_uuid = uuid.uuid4()
         return list(p)
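
Why the key remapping above is needed: torch.compile wraps a module in an OptimizedModule that keeps the original module under _orig_mod, so the compiled model's state dict keys gain a "_orig_mod." prefix and no longer match patch keys computed against the uncompiled model. A minimal sketch, assuming PyTorch >= 2.0:

import torch

net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)

print(hasattr(compiled, "_orig_mod"))  # True: the wrapped original module
print(list(net.state_dict()))          # ['weight', 'bias']
print(list(compiled.state_dict()))     # ['_orig_mod.weight', '_orig_mod.bias']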
@@ -287,7 +293,21 @@ def patch_weight_to_device(self, key, device_to=None):
 
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
-            old = ldm_patched.modules.utils.set_attr(self.model, k, self.object_patches[k])
+            value = self.object_patches[k]
+            if k == 'diffusion_model':
+                # Special handling for the main diffusion model:
+                # back up the original attribute before replacing it
+                if hasattr(self.model, k):
+                    if k not in self.object_patches_backup:
+                        self.object_patches_backup[k] = getattr(self.model, k)
+                    setattr(self.model, k, value)
+                continue
+
+            # Handle other compiled models and function objects
+            if hasattr(value, '_orig_mod') or callable(value):
+                old = ldm_patched.modules.utils.set_attr_raw(self.model, k, value)
+            else:
+                old = ldm_patched.modules.utils.set_attr(self.model, k, value)
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
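
For illustration, a hypothetical sketch of how an object patch flows through this method (add_object_patch is assumed to merely record the object in self.object_patches; patch_model is what applies it):

# Recorded now, applied later by patch_model:
patcher.add_object_patch('diffusion_model', compiled_model)
# On patch_model, the original attribute is backed up in
# object_patches_backup and then replaced with compiled_model.
patcher.patch_model(patch_weights=False)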
34 changes: 34 additions & 0 deletions ldm_patched/modules/sd.py
@@ -59,6 +59,14 @@ def load_clip_weights(model, sd):
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
     model_flag = type(model.model).__name__ if model is not None else 'default'
 
+    # Store compilation state
+    was_compiled = model is not None and hasattr(model.model, "compile_settings")
+    if was_compiled:
+        compile_settings = model.model.compile_settings
+        patch_keys = list(model.object_patches_backup.keys())
+        for k in patch_keys:
+            ldm_patched.modules.utils.set_attr(model.model, k, model.object_patches_backup[k])
+
     # Only build key maps for components we'll actually use
     key_map = {}
     if model is not None and strength_model != 0:
@@ -88,6 +96,32 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
     new_clip = clip
     loaded_keys_clip = set()
 
+    # Recompile model if it was compiled
+    if was_compiled and new_modelpatcher is not None:
+        if not new_modelpatcher.is_model_compiled():
+            print("Recompiling model with LoRA patches...")
+            try:
+                if hasattr(new_modelpatcher.model, "diffusion_model"):
+                    real_model = new_modelpatcher.model.diffusion_model
+                else:
+                    real_model = new_modelpatcher.model
+
+                compiled_model = torch.compile(
+                    model=real_model,
+                    **compile_settings
+                )
+
+                if hasattr(new_modelpatcher.model, "diffusion_model"):
+                    new_modelpatcher.add_object_patch('diffusion_model', compiled_model)
+                else:
+                    new_modelpatcher.model = compiled_model
+
+                new_modelpatcher.model.compile_settings = compile_settings
+                print("Model recompilation successful")
+            except Exception as e:
+                print(f"Warning: Failed to recompile model with error: {str(e)}")
+                print("Falling back to uncompiled model")
+
     # Only log if we actually loaded something
     if loaded_keys_unet or loaded_keys_clip:
         total_loaded_keys = len(loaded_keys_unet) + len(loaded_keys_clip)
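
The compile_settings dict captured above is reused verbatim for recompilation. A hypothetical example of its shape (the key names are torch.compile's actual keyword arguments; the specific values are assumptions for illustration):

import torch

compile_settings = {
    "backend": "inductor",  # default TorchInductor backend
    "mode": "default",      # or "reduce-overhead" / "max-autotune"
    "dynamic": True,        # tolerate varying latent resolutions without recompiling
    "fullgraph": False,     # allow graph breaks in the UNet forward
}
real_model = torch.nn.Linear(8, 8)  # stand-in for the UNet's diffusion_model
compiled_model = torch.compile(real_model, **compile_settings)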
8 changes: 7 additions & 1 deletion ldm_patched/modules/utils.py
@@ -283,7 +283,13 @@ def set_attr(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)
     prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
+
+    # Handle compiled models
+    if hasattr(value, '_orig_mod'):
+        setattr(obj, attrs[-1], value)
+    else:
+        setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
+
     del prev
 
 def set_attr_param(obj, attr, value):
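
patch_model above also calls set_attr_raw, whose body lies outside this diff. A plausible sketch, mirroring set_attr but skipping the nn.Parameter wrapping (an assumption, not the file's actual code):

def set_attr_raw(obj, attr, value):
    # Walk the dotted path and replace the final attribute as-is,
    # suitable for modules and callables that must not be wrapped.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], value)
    return prev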
21 changes: 17 additions & 4 deletions modules_forge/unet_patcher.py
@@ -69,13 +69,18 @@ def compile_model(self, backend="inductor"):
         """Compile the model using torch.compile"""
         if not hasattr(torch, 'compile'):
             print("torch.compile not available - requires PyTorch 2.0 or newer")
-            return
+            return False
+
+        # Check if already compiled
+        if self.is_model_compiled():
+            print("Model is already compiled, skipping compilation")
+            return True
 
         try:
             torch_version = torch.__version__.split('.')
             if int(torch_version[0]) < 2:
                 print(f"torch.compile requires PyTorch 2.0 or newer. Current version: {torch.__version__}")
-                return
+                return False
 
             import torch._dynamo as dynamo
             dynamo.config.suppress_errors = True
@@ -139,23 +144,31 @@ def compile_model(self, backend="inductor"):
             print(f"Compiling model using torch.compile with settings: {compile_settings}")
 
             # Store settings for later recompilation if needed
-            real_model.compile_settings = compile_settings
+            self.model.compile_settings = compile_settings
 
             try:
                 compiled_model = torch.compile(real_model, **compile_settings)
+                # Store the compiled model using object patch
                 if hasattr(self.model, 'diffusion_model'):
-                    self.model.diffusion_model = compiled_model
+                    self.add_object_patch('diffusion_model', compiled_model)
                 else:
                     self.model = compiled_model
                 print("Model compilation successful with dynamic shapes support")
                 self.compiled = True
+                return True
             except Exception as e:
                 print(f"Warning: torch.compile failed with error: {str(e)}")
                 print("Falling back to uncompiled model")
+                return False
 
         except Exception as e:
             print(f"Error during model compilation: {str(e)}")
             return False
 
+    def is_model_compiled(self):
+        if not hasattr(self.model, 'diffusion_model'):
+            return hasattr(self.model, '_orig_mod')
+        return hasattr(self.model.diffusion_model, '_orig_mod')
+
     def add_patched_controlnet(self, cnet):
         cnet.set_previous_controlnet(self.controlnet_linked_list)
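
Hypothetical end-to-end usage tying the pieces together (the variable names and the assumption that load_lora_for_models returns the patched model/clip pair are illustrative, not confirmed by this diff):

# Compile once up front; compile_model now reports success via its return value.
if unet_patcher.compile_model(backend="inductor"):
    assert unet_patcher.is_model_compiled()

# LoRA loading now survives compilation: add_patches remaps keys under
# diffusion_model._orig_mod.*, and load_lora_for_models recompiles afterwards.
unet_patcher, clip = load_lora_for_models(
    unet_patcher, clip, lora_sd, strength_model=0.8, strength_clip=0.8
)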
