# SD_rebasin_merge.py
# Merge two stable diffusion models with git re-basin.
import argparse
import os
import sys

import torch

from weight_matching import sdunet_permutation_spec, weight_matching, apply_permutation
parser = argparse.ArgumentParser(description="Merge two stable diffusion models with git re-basin")
parser.add_argument("--model_a", type=str, help="Path to model a")
parser.add_argument("--model_b", type=str, help="Path to model b")
parser.add_argument("--device", type=str, help="Device to use, defaults to cpu", default="cpu", required=False)
parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False)
# NOTE: this was originally declared as type=str with default=True, so any value
# passed on the command line (including "False") was truthy; parse a real boolean instead.
parser.add_argument("--usefp16", type=lambda s: str(s).lower() not in ("false", "0", "no"),
                    help="Whether to use half precision", default=True, required=False)
parser.add_argument("--alpha", type=float, help="Ratio of model A to B", default=0.5, required=False)
parser.add_argument("--iterations", type=int, help="Number of steps to take before reaching alpha", default=10, required=False)
args = parser.parse_args()
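
# Example invocation (the checkpoint paths here are placeholders, not from the
# original script):
#   python SD_rebasin_merge.py --model_a modelA.ckpt --model_b modelB.ckpt \
#       --alpha 0.5 --iterations 10 --output merged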
device = args.device
def flatten_params(model):
    # Checkpoints store the actual weights under the "state_dict" key.
    return model["state_dict"]

model_a = torch.load(args.model_a, map_location=device)
model_b = torch.load(args.model_b, map_location=device)
theta_0 = model_a["state_dict"]
theta_1 = model_b["state_dict"]
alpha = args.alpha
iterations = args.iterations
step = alpha / iterations
permutation_spec = sdunet_permutation_spec()
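# The permutation spec encodes which weight axes of the SD U-Net may be permuted
# against one another during weight matching (a reading of the re-basin setup;
# the details live in weight_matching.py, not in this file).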
# Keys that are merged by a plain weighted sum after the permutation passes
# (see the end of the loop below).
special_keys = [
    "first_stage_model.decoder.norm_out.weight",
    "first_stage_model.decoder.norm_out.bias",
    "first_stage_model.encoder.norm_out.weight",
    "first_stage_model.encoder.norm_out.bias",
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
]
if args.usefp16:
    print("Using half precision")
else:
    print("Using full precision")
for x in range(iterations):
    print(f"""
---------------------
ITERATION {x+1}
---------------------
""")

    # To reach a given alpha value in a fixed number of steps, an alpha has to
    # be calculated for each individual iteration.
    if x > 0:
        new_alpha = 1 - (1 - step * (1 + x)) / (1 - step * x)
    else:
        new_alpha = step
    print(f"new alpha = {new_alpha}\n")
    # Interpolate the model keys both checkpoints share. The original condition
    # checked "key in theta_1", which is always true when iterating
    # theta_1.items(); the intended membership test is against theta_0.
    theta_0 = {key: (1 - new_alpha) * theta_0[key] + new_alpha * value
               for key, value in theta_1.items() if "model" in key and key in theta_0}

    # On the first pass, copy over any model keys that only exist in model B.
    if x == 0:
        for key in theta_1.keys():
            if "model" in key and key not in theta_0:
                theta_0[key] = theta_1[key]

    print("FINDING PERMUTATIONS")
    # Replace theta_0 with a permuted version using model A and B
    first_permutation, y = weight_matching(permutation_spec, flatten_params(model_a), theta_0, usefp16=args.usefp16)
    theta_0 = apply_permutation(permutation_spec, first_permutation, theta_0)
    second_permutation, z = weight_matching(permutation_spec, flatten_params(model_b), theta_0, usefp16=args.usefp16)
    theta_3 = apply_permutation(permutation_spec, second_permutation, theta_0)

    new_alpha = torch.nn.functional.normalize(torch.sigmoid(torch.Tensor([y, z])), p=1, dim=0).tolist()[0]
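    # y and z come back from weight_matching (presumably the matching objective
    # for each direction); a sigmoid squash followed by L1 normalization turns
    # them into two non-negative weights summing to 1, the first of which is
    # used as the blend factor below.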
    # Weighted sum of the two permuted versions for the special keys
    for key in special_keys:
        theta_0[key] = (1 - new_alpha) * theta_0[key] + new_alpha * theta_3[key]
output_file = f"{args.output}.ckpt"

# Check whether the output file already exists and ask before overwriting.
if os.path.isfile(output_file):
    print("Output file already exists. Overwrite? (y/n)")
    while True:
        overwrite = input()
        if overwrite == "y":
            break
        elif overwrite == "n":
            print("Exiting...")
            sys.exit()
        else:
            print("Please enter y or n")
print("\nSaving...")
torch.save({
"state_dict": theta_0
}, output_file)
print("Done!")