Skip to content

Commit

Permalink
release constrained generation
Browse files Browse the repository at this point in the history
  • Loading branch information
ghliu committed Dec 1, 2023
1 parent 8041d04 commit 140759b
Show file tree
Hide file tree
Showing 26 changed files with 5,066 additions and 0 deletions.
Binary file added assets/constr_gen.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/mdm-dual-afhqv2-8x8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/mdm-dual-ffhq-8x8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 92 additions & 0 deletions eval_constr_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import pickle
import argparse
from pathlib import Path
from collections import defaultdict
import numpy as np

import torch

from mdm import dataset
from mdm import constraintset
from mdm.runner import Runner
from mdm.logger import Logger
from mdm.metrices import compute_metrics

import colored_traceback.always
from ipdb import set_trace as debug

REF_DIR = Path("data")

def get_ref_x0(opt, log, data_sampler):
ref_fn = REF_DIR / f"{opt.p0}_{opt.constraint}_d{opt.xdim}_b{opt.batch_size}.pt"
if ref_fn.exists():
ref_x0 = torch.load(ref_fn, map_location="cpu")
log.info(f"Loaded ref points from {ref_fn}!")
else:
ref_fn.parent.mkdir(exist_ok=True)
ref_x0 = data_sampler.sample()
torch.save(ref_x0.cpu(), ref_fn)
log.info(f"Sampled and saved ref points to {ref_fn}!")
return ref_x0

def build_ckpt_option(opt, log, ckpt_path):
ckpt_path = Path(ckpt_path)
opt_pkl_path = ckpt_path / "options.pkl"
assert opt_pkl_path.exists()
with open(opt_pkl_path, "rb") as f:
ckpt_opt = pickle.load(f)
log.info(f"Loaded options from {opt_pkl_path=}!")

overwrite_keys = ["device", "batch_size"]
for k in overwrite_keys:
assert hasattr(opt, k)
setattr(ckpt_opt, k, getattr(opt, k))

if not hasattr(ckpt_opt, "noise_sched"): ckpt_opt.noise_sched = "linear"

ckpt_opt.load = ckpt_path / "latest.pt"
return ckpt_opt

@torch.no_grad()
def main(opt):
log = Logger(".log")

# restore ckpt
ckpt_opt = build_ckpt_option(opt, log, opt.ckpt_dir)
constraint = constraintset.build(ckpt_opt)
data_sampler = dataset.build(ckpt_opt, constraint)
run = Runner(ckpt_opt, log, save_opt=False)

# sample reference points
ref_x0 = get_ref_x0(ckpt_opt, log, data_sampler)
ref_x0 = ref_x0.to(opt.device)

# compute metrics
metrics = defaultdict(list)
for _ in range(opt.n_run):
# sample predict points
pred_x0, *_ = run.generate(ckpt_opt, constraint)

# compute metrices
pred_x0 = pred_x0.to(opt.device)
m_per_run = compute_metrics(pred_x0, ref_x0, constraint=constraint)
for k, v in m_per_run.items(): metrics[k].append(v)

for k, v in metrics.items():
vv = torch.stack(v)
log.info(f"{k}: {vv.mean().item():.4f}±{vv.std().item():.4f}")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt-dir", type=Path, default=None)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument("--gpu", type=int, default=None)
parser.add_argument("--n-run", type=int, default=3)
opt = parser.parse_args()

opt.device='cuda' if opt.gpu is None else f'cuda:{opt.gpu}'

main(opt)
11 changes: 11 additions & 0 deletions guided_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# ---------------------------------------------------------------
# Taken from the following link as is from:
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/__init__.py
#
# The license for the original version of this file can be
# found in this directory (LICENSE_GUIDED_DIFFUSION).
# ---------------------------------------------------------------

"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""
245 changes: 245 additions & 0 deletions guided_diffusion/fp16_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# ---------------------------------------------------------------
# Taken from the following link as is from:
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/fp16_util.py
#
# The license for the original version of this file can be
# found in this directory (LICENSE_GUIDED_DIFFUSION).
# ---------------------------------------------------------------

"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
if l.bias is not None:
l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
"""
Copy model parameters into a (differently-shaped) list of full-precision
parameters.
"""
master_params = []
for param_group, shape in param_groups_and_shapes:
master_param = nn.Parameter(
_flatten_dense_tensors(
[param.detach().float() for (_, param) in param_group]
).view(shape)
)
master_param.requires_grad = True
master_params.append(master_param)
return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
for master_param, (param_group, shape) in zip(
master_params, param_groups_and_shapes
):
master_param.grad = _flatten_dense_tensors(
[param_grad_or_zeros(param) for (_, param) in param_group]
).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
"""
Copy the master parameter data back into the model parameters.
"""
# Without copying to a list, if a generator is passed, this will
# silently not copy any parameters.
for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
for (_, param), unflat_master_param in zip(
param_group, unflatten_master_params(param_group, master_param.view(-1))
):
param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
named_model_params = list(named_model_params)
scalar_vector_named_params = (
[(n, p) for (n, p) in named_model_params if p.ndim <= 1],
(-1),
)
matrix_named_params = (
[(n, p) for (n, p) in named_model_params if p.ndim > 1],
(1, -1),
)
return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
model, param_groups_and_shapes, master_params, use_fp16
):
if use_fp16:
state_dict = model.state_dict()
for master_param, (param_group, _) in zip(
master_params, param_groups_and_shapes
):
for (name, _), unflat_master_param in zip(
param_group, unflatten_master_params(param_group, master_param.view(-1))
):
assert name in state_dict
state_dict[name] = unflat_master_param
else:
state_dict = model.state_dict()
for i, (name, _value) in enumerate(model.named_parameters()):
assert name in state_dict
state_dict[name] = master_params[i]
return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
if use_fp16:
named_model_params = [
(name, state_dict[name]) for name, _ in model.named_parameters()
]
param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
master_params = make_master_params(param_groups_and_shapes)
else:
master_params = [state_dict[name] for name, _ in model.named_parameters()]
return master_params


def zero_master_grads(master_params):
for param in master_params:
param.grad = None


def zero_grad(model_params):
for param in model_params:
# Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()


def param_grad_or_zeros(param):
if param.grad is not None:
return param.grad.data.detach()
else:
return th.zeros_like(param)


class MixedPrecisionTrainer:
def __init__(
self,
*,
model,
use_fp16=False,
fp16_scale_growth=1e-3,
initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
):
self.model = model
self.use_fp16 = use_fp16
self.fp16_scale_growth = fp16_scale_growth

self.model_params = list(self.model.parameters())
self.master_params = self.model_params
self.param_groups_and_shapes = None
self.lg_loss_scale = initial_lg_loss_scale

if self.use_fp16:
self.param_groups_and_shapes = get_param_groups_and_shapes(
self.model.named_parameters()
)
self.master_params = make_master_params(self.param_groups_and_shapes)
self.model.convert_to_fp16()

def zero_grad(self):
zero_grad(self.model_params)

def backward(self, loss: th.Tensor):
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()

def optimize(self, opt: th.optim.Optimizer):
if self.use_fp16:
return self._optimize_fp16(opt)
else:
return self._optimize_normal(opt)

def _optimize_fp16(self, opt: th.optim.Optimizer):
logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
if check_overflow(grad_norm):
self.lg_loss_scale -= 1
logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
zero_master_grads(self.master_params)
return False

logger.logkv_mean("grad_norm", grad_norm)
logger.logkv_mean("param_norm", param_norm)

for p in self.master_params:
p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
opt.step()
zero_master_grads(self.master_params)
master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
self.lg_loss_scale += self.fp16_scale_growth
return True

def _optimize_normal(self, opt: th.optim.Optimizer):
grad_norm, param_norm = self._compute_norms()
logger.logkv_mean("grad_norm", grad_norm)
logger.logkv_mean("param_norm", param_norm)
opt.step()
return True

def _compute_norms(self, grad_scale=1.0):
grad_norm = 0.0
param_norm = 0.0
for p in self.master_params:
with th.no_grad():
param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
if p.grad is not None:
grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

def master_params_to_state_dict(self, master_params):
return master_params_to_state_dict(
self.model, self.param_groups_and_shapes, master_params, self.use_fp16
)

def state_dict_to_master_params(self, state_dict):
return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
return (value == float("inf")) or (value == -float("inf")) or (value != value)
Loading

0 comments on commit 140759b

Please sign in to comment.