forked from NVlabs/LSGM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_vada.py
345 lines (308 loc) · 18.8 KB
/
evaluate_vada.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for LSGM. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import argparse
import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.multiprocessing import Process
from torchvision.utils import save_image

from nvae import NVAE
from diffusion_discretized import DiffusionDiscretized
from diffusion_continuous import make_diffusion

try:
    from apex.optimizers import FusedAdam
except ImportError:
    print("No Apex Available. Using PyTorch's native Adam. Install Apex for faster training.")
    from torch.optim import Adam as FusedAdam

from util import utils, datasets
from util.ema import EMA
from evaluate_diffusion import test_dae_fid, generate_samples_vada, elbo_evaluation
def main(eval_args):
    """Evaluate, or sample from, a trained LSGM checkpoint (NVAE + latent diffusion model).

    The checkpoint at ``eval_args.checkpoint`` provides the training-time ``args``
    together with the VAE and DAE weights. What happens next depends on
    ``eval_args.eval_mode``:

    * ``'evaluate'`` -- run every evaluation enabled by the flags in ``eval_args``
      (ODE-based NLL, discretized ELBO, ODE/discretized FID, sampling NFE) and
      report the results through the logger and the summary writer.
    * ``'sample'`` -- draw batches of samples on this process and save each batch
      to ``<eval_args.save>/gpu_<global_rank>_samples_<i>.npz``.

    Args:
        eval_args: argparse namespace built by the ``__main__`` section of this file.

    Raises:
        ValueError: if the checkpoint file name is not one of the recognized names.
    """
    # common initialization
    logging, writer = utils.common_init(eval_args.global_rank, eval_args.seed, eval_args.save)

    args, vae, dae, diffusion_cont, diffusion_disc = _load_models(eval_args, logging)

    if eval_args.eval_mode == 'evaluate':
        _run_evaluation(eval_args, args, vae, dae, diffusion_cont, diffusion_disc, logging, writer)
    elif eval_args.eval_mode == 'sample':
        _run_sampling(eval_args, args, vae, dae, diffusion_cont, diffusion_disc, logging)

    # Close the summary writer in both modes, not only after evaluation.
    writer.close()


def _load_models(eval_args, logging):
    """Load the VAE, the DAE and both diffusion objects described by ``eval_args.checkpoint``.

    Returns:
        ``(args, vae, dae, diffusion_cont, diffusion_disc)``. ``args`` is the
        training-time namespace stored in the checkpoint, patched for backward
        compatibility and with ``diffusion_steps`` overridden when requested.
        ``vae`` is on the GPU in train or eval mode according to
        ``eval_args.vae_train_mode``. ``dae`` is on the GPU in eval mode, with EMA
        weights swapped in for ``checkpoint.pt`` files.

    Raises:
        ValueError: if the checkpoint file name is not recognized.
    """
    logging.info('#' * 80)
    logging.info('loading the model at:')
    logging.info(eval_args.checkpoint)
    checkpoint = torch.load(eval_args.checkpoint, map_location='cpu')
    args = checkpoint['args']

    # adding some arguments for backward compatibility with older checkpoints.
    for name, default in (('num_x_bits', 8), ('channel_mult', [1, 2]), ('mixing_logit_init', -3.0)):
        if not hasattr(args, name):
            logging.info('*** Setting %s manually ****', name)
            setattr(args, name, default)
    if eval_args.diffusion_steps > 0:
        args.diffusion_steps = eval_args.diffusion_steps

    logging.info('loaded the model at epoch %d', checkpoint['epoch'])
    arch_instance_nvae = utils.get_arch_cells(args.arch_instance, args.use_se)
    logging.info('args = %s', args)
    logging.info('evalargs = %s', eval_args)

    # load VAE
    vae = NVAE(args, arch_instance_nvae)
    vae.load_state_dict(checkpoint['vae_state_dict'])
    vae = vae.cuda()
    logging.info('VAE: param size = %fM ', utils.count_parameters_in_M(vae))

    # load DAE; its input channel count is taken from the VAE's latent structure
    num_input_channels = vae.latent_structure()[0]
    dae = utils.get_dae_model(args, num_input_channels)
    dae.load_state_dict(checkpoint['dae_state_dict'])
    diffusion_cont = make_diffusion(args)
    diffusion_disc = DiffusionDiscretized(args, diffusion_cont.var)
    logging.info('DAE: param size = %fM ', utils.count_parameters_in_M(dae))

    checkpoint_name = os.path.basename(eval_args.checkpoint)
    if checkpoint_name == 'checkpoint.pt':
        logging.info('Swapping the parameters of DAE with EMA parameters')
        # checkpoint.pt models require swapping EMA parameters: rebuild the DAE
        # optimizer wrapped in EMA, restore its saved state, then swap the EMA
        # values into the DAE's parameters.
        dae_optimizer = FusedAdam(dae.parameters(), args.learning_rate_dae,
                                  weight_decay=args.weight_decay, eps=1e-4)
        dae_optimizer = EMA(dae_optimizer, ema_decay=args.ema_decay)
        dae_optimizer.load_state_dict(checkpoint['dae_optimizer'])
        dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
    elif checkpoint_name in {'checkpoint_fid.pt', 'checkpoint_nll.pt', 'checkpoint_finetuned.pt'}:
        logging.info('swapping the parameters of DAE with EMA parameters is ** not ** required.')
    else:
        raise ValueError('Cannot recognize checkpoint name %s' % checkpoint_name)

    dae = dae.cuda()
    # set the model to eval() model.
    dae.eval()
    # set vae to train mode if the arg says
    vae.train(mode=eval_args.vae_train_mode)
    return args, vae, dae, diffusion_cont, diffusion_disc


def _run_evaluation(eval_args, args, vae, dae, diffusion_cont, diffusion_disc, logging, writer):
    """Run the evaluations enabled in ``eval_args``; log them and write them to ``writer``.

    ``args`` (the checkpoint's training arguments) is updated in place with the
    evaluation-time system settings (process counts, data path, optional batch
    size, ODE settings) so a model trained on one system can be evaluated on another.
    """
    # replace a few fields in args based on eval_args
    # this will allow train/evaluate on different systems
    args.num_proc_node = eval_args.num_proc_node
    args.num_process_per_node = eval_args.num_process_per_node
    args.data = eval_args.data
    if eval_args.batch_size > 0:
        args.batch_size = eval_args.batch_size
    ode_param_dict = {'ode_solver_tol': eval_args.ode_solver_tol, 'ode_eps': eval_args.ode_eps}

    # load train valid queue
    bpd_coeff = utils.get_bpd_coeff(args.dataset)
    train_queue, valid_queue, _ = datasets.get_loaders(args)
    if eval_args.eval_on_train:
        logging.info('Using the training data for eval.')
        valid_queue = train_queue

    # evaluate NLL ODE-style
    if eval_args.nll_ode_eval:
        args.ode_eps = eval_args.ode_eps
        args.ode_solver_tol = eval_args.ode_solver_tol
        logging.info('Running ODE-based NLL evaluation...')
        nelbo_ode, neg_log_p_ode, nfe_nll_ode, stddev, stderror = elbo_evaluation(
            valid_queue, diffusion_cont, dae, args, vae, max_step=None, ode_eval=True, ode_param_dict=ode_param_dict,
            num_samples=eval_args.num_iw_samples, num_inner_samples=eval_args.num_iw_inner_samples, report_std=True)
        logging.info('valid ODE-based NELBO: {} bpd (with {} func. evals on average, ode error tol. {}, ode eps {})'
                     .format(nelbo_ode * bpd_coeff, nfe_nll_ode, ode_param_dict['ode_solver_tol'],
                             ode_param_dict['ode_eps']))
        logging.info('valid ODE-based neg. log p: {} bpd'.format(neg_log_p_ode * bpd_coeff))
        if stddev is not None and stderror is not None:
            logging.info('valid ODE-based prior prob std. deviation: {} bpd ({} nats)'
                         .format(stddev * bpd_coeff, stddev))
            logging.info('valid ODE-based prior prob std. error: {} bpd ({} nats)'
                         .format(stderror * bpd_coeff, stderror))
            # A variance scales with the *square* of a unit-conversion factor, so
            # it is converted with bpd_coeff ** 2 (i.e. squaring the converted std).
            logging.info('valid ODE-based prior prob variance (based on std. deviation): {} bpd^2 ({} nats^2)'
                         .format((stddev * bpd_coeff) ** 2, stddev * stddev))
            logging.info('valid ODE-based prior prob variance (based on std. error): {} bpd^2 ({} nats^2)'
                         .format((stderror * bpd_coeff) ** 2, stderror * stderror))
        writer.add_scalar('val/final_ode_nelbo_bpd', nelbo_ode * bpd_coeff, 1)
        writer.add_scalar('val/final_ode_neg_log_p_bpd', neg_log_p_ode * bpd_coeff, 1)
        writer.add_scalar('val/final_ode_nfe', nfe_nll_ode, 1)

    # evaluate ELBO with regular fixed discretization
    if eval_args.elbo_eval:
        logging.info('Running regular/discretized ELBO evaluation...')
        val_elbo, val_log_p, _, _, _ = elbo_evaluation(valid_queue, diffusion_disc, dae, args, vae, max_step=None,
                                                       ode_eval=False, num_samples=eval_args.num_iw_samples)
        logging.info('valid NELBO: {} bpd'.format(val_elbo * bpd_coeff))
        logging.info('valid neg log P {} bpd'.format(val_log_p * bpd_coeff))
        writer.add_scalar('val/final_nelbo_bpd', val_elbo * bpd_coeff, 1)
        writer.add_scalar('val/final_neg_log_p_bpd', val_log_p * bpd_coeff, 1)

    # compute FID using samples from ODE-based sampling
    if eval_args.fid_ode_eval:
        args.fid_dir = eval_args.fid_dir
        num_fid_samples = eval_args.num_fid_samples
        logging.info('Running ODE-based FID evaluation...')
        val_fid_ema_ode = test_dae_fid(dae, diffusion_cont, writer, logging, args, num_fid_samples, vae,
                                       ode_param_dict, eval_args.temp, eval_args.vae_temp)
        logging.info('valid ODE-based FID: {} (ode error tol. {}, ode eps {})'
                     .format(val_fid_ema_ode, ode_param_dict['ode_solver_tol'], ode_param_dict['ode_eps']))
        writer.add_scalar('val/final_fid_ode_samples', val_fid_ema_ode, 1)

    # compute FID using samples from regular fixed discretization sampling
    if eval_args.fid_disc_eval:
        args.fid_dir = eval_args.fid_dir
        num_fid_samples = eval_args.num_fid_samples
        logging.info('Running regular/discretized FID evaluation...')
        val_fid_ema = test_dae_fid(dae, diffusion_disc, writer, logging, args, num_fid_samples, vae,
                                   temp=eval_args.temp, vae_temp=eval_args.vae_temp)
        logging.info('valid FID: {}'.format(val_fid_ema))
        writer.add_scalar('val/final_fid', val_fid_ema, 1)

    # compute average number of function evaluation for ODE-based sampling from the model
    if eval_args.nfe_eval:
        nfe_sum = 0.0
        odetime_sum = 0.0
        sampling_time_sum = 0.0
        num_iter = 50
        # Every process runs this loop, so the logs below report the batch count over all processes.
        num_batches = num_iter * args.num_process_per_node * args.num_proc_node
        logging.info('Starting to sample with ODE framework for average NFE calculation...')
        for _ in range(num_iter):
            # Note that this is just a quick hack... This sampling is repeated num_iter (50)
            # times on every process and the per-batch statistics are averaged below.
            _, nfe, odetime, sampling_time = generate_samples_vada(
                dae, diffusion_cont, vae, args.batch_size, enable_autocast=args.autocast_eval,
                ode_eps=eval_args.ode_eps, ode_solver_tol=eval_args.ode_solver_tol, ode_sample=True,
                prior_var=args.sigma2_max if args.sde_type == 'vesde' else 1.0)
            nfe_sum = nfe_sum + nfe
            odetime_sum = odetime_sum + odetime
            sampling_time_sum = sampling_time_sum + sampling_time
        logging.info('Sampling NFE (average over {} batches with batchsize {}): {}'
                     .format(num_batches, args.batch_size, nfe_sum / num_iter))
        logging.info('Sampling Time (average over {} batches with batchsize {}): {} seconds'
                     .format(num_batches, args.batch_size, sampling_time_sum / num_iter))
        logging.info('ODE Solve Time only (average over {} batches with batchsize {}): {} seconds'
                     .format(num_batches, args.batch_size, odetime_sum / num_iter))
        writer.add_scalar('ode_sampling_nfe/final_average', nfe_sum / num_iter, 1)


def _run_sampling(eval_args, args, vae, dae, diffusion_cont, diffusion_disc, logging):
    """Generate this process's share of ``eval_args.num_fid_samples`` samples and save them.

    Samples are drawn in batches of 100: ODE-based if ``eval_args.ode_sampling``,
    otherwise with the discretized diffusion. Batch ``i`` is written to
    ``<eval_args.save>/gpu_<global_rank>_samples_<i>.npz`` under the key ``samples``.
    """
    num_total_samples = eval_args.num_fid_samples
    # Split the work across the processes of *this* run (eval_args). The
    # checkpoint's args describe the training setup, which may have used a
    # different number of GPUs.
    num_gpus = eval_args.num_process_per_node * eval_args.num_proc_node
    num_sample_per_gpu = int(np.ceil(num_total_samples / num_gpus))
    num_samples = 100  # samples per batch
    n = 10  # tile grid rows for visualization (n * m == num_samples)
    m = 10  # tile grid columns for visualization
    num_iter = int(np.ceil(num_sample_per_gpu / num_samples))
    prior_var = args.sigma2_max if args.sde_type == 'vesde' else 1.0
    visualize = False  # set to True to display each batch instead of saving it
    all_nfe = []
    for i in range(num_iter):
        # Timing starts at the second batch; the first one is excluded from the
        # reported average (presumably as a warm-up).
        if i == 1:
            start = time()
        if eval_args.ode_sampling:
            logging.info('Starting to sample with ODE framework...')
            logging.info('ODE params: ODE eps %f, ODE tol %f', eval_args.ode_eps, eval_args.ode_solver_tol)
            samples, nfe, odetime, sampling_time = generate_samples_vada(
                dae, diffusion_cont, vae, num_samples, enable_autocast=args.autocast_eval,
                ode_eps=eval_args.ode_eps, ode_solver_tol=eval_args.ode_solver_tol, ode_sample=True,
                prior_var=prior_var, temp=eval_args.temp, vae_temp=eval_args.vae_temp)
        else:
            logging.info('Starting to sample with naive discretization...')
            samples, nfe, odetime, sampling_time = generate_samples_vada(
                dae, diffusion_disc, vae, num_samples, enable_autocast=args.autocast_eval,
                prior_var=prior_var, temp=eval_args.temp, vae_temp=eval_args.vae_temp)
        logging.info('Sampled new images ({} func. evals, {} seconds for ODE solve, {} seconds for full sampling) of latent space diffusion model.'.format(nfe, odetime, sampling_time))
        all_nfe.append(nfe.cpu().numpy())

        if visualize:
            # tiling is only needed for display, so it is done only here
            output_tiled = utils.tile_image(samples, n, m)
            plt.figure(figsize=(12, 12))
            img = output_tiled.permute(1, 2, 0).float().cpu().numpy()
            if img.shape[2] == 1:
                img = np.squeeze(img)
                plt.imshow(img, cmap=plt.get_cmap('gray'))
            else:
                plt.imshow(img)
            plt.show()
        else:
            file_path = os.path.join(eval_args.save, 'gpu_%d_samples_%d.npz' % (eval_args.global_rank, i))
            np.savez_compressed(file_path, samples=samples.cpu().numpy())
            logging.info('Saved at: {}'.format(file_path))

    if num_iter > 1:
        end = time() - start
        logging.info('timing %f, avg NFE %f' % (end / (num_iter - 1), np.mean(all_nfe)))
if __name__ == '__main__':
    # ---- command-line interface -------------------------------------------
    parser = argparse.ArgumentParser('parser')
    add = parser.add_argument

    # directories for experiment results and checkpoint
    add('--checkpoint', default='/path/to/checkpoint.pt', type=str, help='location of the checkpoint')
    add('--root', default='/tmp/nvae-diff/expr', type=str, help='location of the results')
    add('--save', default='debug_ode', type=str, help='id used for storing intermediate results')
    add('--eval_mode', default='evaluate', type=str, choices=['sample', 'evaluate'],
        help='evaluation mode. you can choose between sample or evaluate.')
    add('--eval_on_train', default=False, action='store_true',
        help='Settings this to true will evaluate the model on training data.')
    add('--data', default='/tmp/data', type=str, help='location of the data corpus')
    add('--fid_dir', default='/tmp/nvae-diff/fid-stats', type=str,
        help='path to directory where fid related files are stored')
    add('--readjust_bn', default=False, action='store_true',
        help='adding this flag will enable readjusting BN statistics.')

    # sampling temperatures and evaluation settings
    add('--temp', default=1.0, type=float, help='The temperature used for sampling.')
    add('--vae_temp', default=1.0, type=float, help='The temperature used for sampling in vae.')
    add('--vae_train_mode', default=False, action='store_true',
        help='evaluate vae in train mode, suitable for BN experiments.')
    add('--num_iw_samples', default=1, type=int,
        help='The number of samples from latent space used in IW evaluation.')
    add('--num_iw_inner_samples', default=1, type=int,
        help='How often we solve the ODE and average when calculating prior probability.')
    add('--num_fid_samples', default=50000, type=int, help='The number of samples used for FID computation.')
    add('--batch_size', default=0, type=int,
        help='Batch size used during evaluation. If set to zero, training batch size is used.')
    add('--elbo_eval', default=False, action='store_true', help='if True, we perform discrete ELBO evaluation.')
    add('--fid_disc_eval', default=False, action='store_true', help='if True, we perform FID evaluation.')
    add('--fid_ode_eval', default=False, action='store_true',
        help='if True, we perform FID evaluation using ODE-based model samples.')
    add('--nll_ode_eval', default=False, action='store_true', help='if True, we perform ODE-based NLL evaluation.')
    add('--nfe_eval', default=False, action='store_true',
        help='if True, we sample 50 batches of images and average NFEs.')
    add('--ode_sampling', default=False, action='store_true',
        help='if True, do ODE-based sampling, otherwise regular sampling. Only relevant when sampling.')
    add('--ode_eps', default=0.00001, type=float, help='ODE can only be integrated up to some epsilon > 0.')
    add('--ode_solver_tol', default=1e-5, type=float, help='ODE solver error tolerance.')
    add('--diffusion_steps', default=0, type=int, help='number of diffusion steps')

    # DDP.
    add('--seed', default=1, type=int, help='seed used for initialization')
    add('--master_address', default='127.0.0.1', type=str, help='address for master')
    add('--num_proc_node', default=1, type=int, help='The number of nodes in multi node env.')
    add('--node_rank', default=0, type=int, help='The index of node.')
    add('--local_rank', default=0, type=int, help='rank of process in the node')
    add('--global_rank', default=0, type=int, help='rank of process among all the processes')
    add('--num_process_per_node', default=1, type=int, help='number of gpus')

    args = parser.parse_args()
    args.save = args.root + '/' + args.save
    utils.create_exp_dir(args.save)

    # ---- launch --------------------------------------------------------------
    size = args.num_process_per_node
    if size <= 1:
        # for debugging: run main in this process as rank 0 of a world of `size`
        print('starting in debug mode')
        args.distributed = True
        utils.init_processes(0, size, main, args)
    else:
        # one worker per local GPU; each worker gets its own local/global rank
        args.distributed = True
        world_size = args.num_proc_node * args.num_process_per_node
        workers = []
        for local_idx in range(size):
            args.local_rank = local_idx
            rank_in_world = local_idx + args.node_rank * args.num_process_per_node
            args.global_rank = rank_in_world
            print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, local_idx, rank_in_world))
            worker = Process(target=utils.init_processes, args=(rank_in_world, world_size, main, args))
            worker.start()
            workers.append(worker)
        for worker in workers:
            worker.join()