-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdenoise.py
123 lines (101 loc) · 4.16 KB
/
denoise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from config import *
from diffusion import *
from dataset import tensor_to_pil
from dit import Dit
from PIL import Image
import os
def save_combined_image(tensors, path, grid_size, img_size):
"""
将一组张量图像汇总到一个大图像中并保存。
Args:
tensors (list of torch.Tensor): 要汇总的张量列表。
path (str): 保存图像的路径。
grid_size (tuple): (行数, 列数)。
img_size (tuple): 每个小图像的大小 (宽, 高)。
"""
rows, cols = grid_size
width, height = img_size
combined_image = Image.new('RGB', (cols * width, rows * height))
for idx, tensor in enumerate(tensors):
img = tensor_to_pil(tensor)
row = idx // cols
col = idx % cols
combined_image.paste(img, (col * width, row * height))
combined_image.save(path)
def backward_denoise(model, batch_x_t, batch_cls):
steps=[batch_x_t,]
global alpha,alpha_cum_product,variance
model=model.to(device)
batch_x_t=batch_x_t.to(device)
batch_cls = batch_cls.to(device)
alpha=alpha.to(device)
alpha_cum_product=alpha_cum_product.to(device)
variance=variance.to(device)
# 应该是由于BN层mean,std导致的eval效果不好,先不开启了
#model.eval()
with torch.no_grad():
for t in range(timestep-1,-1,-1):
batch_t=torch.full((batch_x_t.size(0),),t).to(device) #[999,999,....]
# 预测x_t时刻的噪音
batch_predict_noise_t=model(batch_x_t,batch_t, batch_cls)
# 生成t-1时刻的图像
shape=(batch_x_t.size(0),1,1,1)
batch_mean_t=1/torch.sqrt(alpha[batch_t].view(*shape))* \
(
batch_x_t-
(1-alpha[batch_t].view(*shape))/torch.sqrt(1-alpha_cum_product[batch_t].view(*shape))*batch_predict_noise_t
)
if t!=0:
batch_x_t=batch_mean_t+ \
torch.randn_like(batch_x_t)* \
torch.sqrt(variance[batch_t].view(*shape))
else:
batch_x_t=batch_mean_t
batch_x_t=torch.clamp(batch_x_t, -1.0, 1.0).detach()
steps.append(batch_x_t)
return steps
if __name__ == '__main__':
# 加载状态字典
state_dict = torch.load("/mnt/diffusion_transformer/models/model_epoch_490.pt")
model = model = Dit(image_size = 48, patch_size = 4, input_channel = 1, embedding_size = 128, dit_block_num = 4, num_heads = 4, label_num = 10, mlp_ratio = 4).to(device)
# 加载状态字典到模型
model.load_state_dict(state_dict)
# 生成噪音图
batch_size = 10
image_size = 48 # 确保定义 image_size
# xt
batch_x_t = torch.randn(size=(batch_size, 1, image_size, image_size)) # (10, 1, 48, 48)
# 引导词
batch_cls = torch.arange(start = 0, end = 10, dtype = torch.long)
# 逐步去噪得到原图
steps = backward_denoise(model, batch_x_t, batch_cls)
# 创建输出目录
output_dir = "output_images"
os.makedirs(output_dir, exist_ok=True)
# 绘制数量
num_imgs = 20
# 保存每个批次的汇总图像
combined_images = []
for b in range(batch_size):
img_list = []
for i in range(0, num_imgs):
idx = int(timestep / num_imgs) * (i + 1)
# 像素值还原到 [0, 1]
final_img = (steps[idx][b].to('cpu') + 1) / 2
img_list.append(final_img)
# 保存汇总图像
grid_size = (1, num_imgs)
img_path = os.path.join(output_dir, f"batch_{b}_combined.png")
save_combined_image(img_list, img_path, grid_size, (image_size, image_size))
combined_images.append(img_path)
# 拼接所有汇总图像
final_image_width = image_size * num_imgs
final_image_height = image_size * batch_size
final_image = Image.new('RGB', (final_image_width, final_image_height))
for i, img_path in enumerate(combined_images):
img = Image.open(img_path)
final_image.paste(img, (0, i * image_size))
# 保存最终的大图像
final_image_path = os.path.join(output_dir, "final_combined_image.png")
final_image.save(final_image_path)