-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathdiffpure_guided.py
89 lines (68 loc) · 3.47 KB
/
diffpure_guided.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for DiffPure. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import os
import random
import torch
import torchvision.utils as tvu
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
class GuidedDiffusion(torch.nn.Module):
    """Adversarial purification via an unconditional guided-diffusion model (DiffPure).

    An input image is diffused forward to timestep ``args.t`` in closed form and
    then denoised back with the pretrained reverse process; the result is the
    purified image.
    """

    def __init__(self, args, config, device=None, model_dir='pretrained/guided_diffusion'):
        """Load the pretrained 256x256 unconditional guided-diffusion model.

        Args:
            args: namespace providing at least ``log_dir``, ``sample_step`` and ``t``.
            config: object whose ``model`` attribute holds guided-diffusion
                hyperparameter overrides (merged into the library defaults).
            device: torch device; defaults to CUDA when available, else CPU.
            model_dir: directory containing ``256x256_diffusion_uncond.pt``.
        """
        super().__init__()
        self.args = args
        self.config = config
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device

        # Build model + diffusion from library defaults updated with user config.
        model_config = model_and_diffusion_defaults()
        model_config.update(vars(self.config.model))
        print(f'model_config: {model_config}')
        model, diffusion = create_model_and_diffusion(**model_config)
        model.load_state_dict(torch.load(f'{model_dir}/256x256_diffusion_uncond.pt', map_location='cpu'))
        # Inference only: freeze weights and switch to eval mode.
        model.requires_grad_(False).eval().to(self.device)

        if model_config['use_fp16']:
            model.convert_to_fp16()

        self.model = model
        self.diffusion = diffusion
        # Betas of the noise schedule; image_editing_sample derives the
        # cumulative alphas from these.
        self.betas = torch.from_numpy(diffusion.betas).float().to(self.device)

    def image_editing_sample(self, img, bs_id=0, tag=None):
        """Purify a batch of images by diffusing then denoising.

        Args:
            img: tensor of shape (B, C, H, W); values assumed in [-1, 1]
                (the saved visualizations map them back with (x + 1) * 0.5).
            bs_id: batch index; intermediate images are dumped to disk only
                for ``bs_id < 2`` to limit I/O.
            tag: subdirectory tag for the dumps; a random one is used if None.

        Returns:
            Tensor of shape (B * args.sample_step, C, H, W): the purified
            result of each sampling round, concatenated along the batch dim.
        """
        with torch.no_grad():
            assert isinstance(img, torch.Tensor)
            batch_size = img.shape[0]

            if tag is None:
                tag = 'rnd' + str(random.randint(0, 10000))
            out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag)

            assert img.ndim == 4, img.ndim
            img = img.to(self.device)
            x0 = img

            if bs_id < 2:
                os.makedirs(out_dir, exist_ok=True)
                tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, 'original_input.png'))

            # Hoisted loop invariants: the noise level and the cumulative
            # alpha products depend only on args / self.betas, not on the
            # sampling iteration, so compute them once.
            total_noise_levels = self.args.t
            a = (1 - self.betas).cumprod(dim=0)

            xs = []
            for it in range(self.args.sample_step):
                # Forward diffusion in closed form:
                # x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps.
                e = torch.randn_like(x0)
                x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt()

                if bs_id < 2:
                    tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'init_{it}.png'))

                # Reverse diffusion from timestep t-1 down to 0.
                for i in reversed(range(total_noise_levels)):
                    t = torch.tensor([i] * batch_size, device=self.device)
                    x = self.diffusion.p_sample(self.model, x, t,
                                                clip_denoised=True,
                                                denoised_fn=None,
                                                cond_fn=None,
                                                model_kwargs=None)["sample"]
                    # Visualize intermediate denoising every 100 timesteps.
                    if (i - 99) % 100 == 0 and bs_id < 2:
                        tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'noise_t_{i}_{it}.png'))

                # Each subsequent round re-noises the current purified result.
                x0 = x

                if bs_id < 2:
                    torch.save(x0, os.path.join(out_dir, f'samples_{it}.pth'))
                    tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'samples_{it}.png'))

                xs.append(x0)

            return torch.cat(xs, dim=0)