integrate edm with mirror map
ghliu committed Dec 1, 2023
1 parent 4150c54 commit 7d23acb
Showing 10 changed files with 59 additions and 26 deletions.
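In short, the commit does two things. First, the vendored edm code becomes a proper edm package: every top-level absolute import (import dnnlib, from torch_utils import ...) is rewritten as a package-qualified one (from edm import dnnlib, from edm.torch_utils import ...), and the entry points move from edm/train.py and edm/generate.py to train_watermark.py and generate_watermark.py at the repository root. Second, a mirror map is threaded through the pipeline as a watermark object built from mdm.watermark.ImageWatermark: training maps images into the dual space with to_dual() before the diffusion loss, saves watermark_kwargs into each network snapshot, and gains a --pretrain option for checkpoint initialization, while sampling reconstructs the mapping from the snapshot, decodes samples back with to_primal(), and checks them with is_feasible().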
2 changes: 1 addition & 1 deletion edm/torch_utils/misc.py
@@ -10,7 +10,7 @@
import numpy as np
import torch
import warnings
-import dnnlib
+from edm import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
2 changes: 1 addition & 1 deletion edm/torch_utils/persistence.py
@@ -19,7 +19,7 @@
import copy
import uuid
import types
-import dnnlib
+from edm import dnnlib

#----------------------------------------------------------------------------

2 changes: 1 addition & 1 deletion edm/torch_utils/training_stats.py
@@ -13,7 +13,7 @@
import re
import numpy as np
import torch
-import dnnlib
+from edm import dnnlib

from . import misc

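The three files above receive the same mechanical one-line change, as do the remaining files below. A minimal sketch of why it is needed, assuming the repository root is on sys.path and edm/ is importable as a package:

```python
# Before this commit, edm/ itself acted as the project root, so its modules
# used top-level absolute imports that resolve only when edm/ is on sys.path:
#
#   import dnnlib
#   from torch_utils import persistence
#
# After it, edm is a subpackage of the outer repository (note the new
# root-level entry points train_watermark.py and generate_watermark.py),
# so the same modules qualify their imports with the package name:

from edm import dnnlib
from edm.torch_utils import persistence
```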
4 changes: 2 additions & 2 deletions edm/training/augment.py
@@ -12,8 +12,8 @@

import numpy as np
import torch
-from torch_utils import persistence
-from torch_utils import misc
+from edm.torch_utils import persistence
+from edm.torch_utils import misc

#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.
2 changes: 1 addition & 1 deletion edm/training/dataset.py
@@ -13,7 +13,7 @@
import PIL.Image
import json
import torch
-import dnnlib
+from edm import dnnlib

try:
import pyspng
2 changes: 1 addition & 1 deletion edm/training/loss.py
@@ -9,7 +9,7 @@
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
-from torch_utils import persistence
+from edm.torch_utils import persistence

#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
2 changes: 1 addition & 1 deletion edm/training/networks.py
@@ -10,7 +10,7 @@

import numpy as np
import torch
-from torch_utils import persistence
+from edm.torch_utils import persistence
from torch.nn.functional import silu

#----------------------------------------------------------------------------
30 changes: 25 additions & 5 deletions edm/training/training_loop.py
@@ -15,10 +15,10 @@
import psutil
import numpy as np
import torch
-import dnnlib
-from torch_utils import distributed as dist
-from torch_utils import training_stats
-from torch_utils import misc
+from edm import dnnlib
+from edm.torch_utils import distributed as dist
+from edm.torch_utils import training_stats
+from edm.torch_utils import misc

#----------------------------------------------------------------------------

@@ -46,6 +46,8 @@ def training_loop(
resume_kimg = 0, # Start from the given training progress.
cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
device = torch.device('cuda'),
+pretrain = None, # (MDM) Pretrain ckpt for initialization
+watermark_kwargs = None, # (MDM) Options for mdm.watermark.ImageWatermark.
):
# Initialize.
start_time = time.time()
@@ -69,10 +71,19 @@
dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))

+# Construct dual space mapping
+watermark = dnnlib.util.construct_class_by_name(**watermark_kwargs)
+
# Construct network.
dist.print0('Constructing network...')
interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module

+if pretrain:
+dist.print0(f'Loading pretrained weights from {pretrain}...')
+net.load_state_dict(torch.load(pretrain, map_location="cpu"))
+net = net.to(device)
+
net.train().requires_grad_(True).to(device)
if dist.get_rank() == 0:
with torch.no_grad():
@@ -126,6 +137,9 @@
with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
images, labels = next(dataset_iterator)
images = images.to(device).to(torch.float32) / 127.5 - 1
+# images output: [-1,1]
+images = watermark.to_dual(images)
+
labels = labels.to(device)
loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
training_stats.report('Loss/loss', loss)
@@ -176,7 +190,13 @@

# Save network snapshot.
if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
-data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
+data = dict(
+ema=ema,
+loss_fn=loss_fn,
+augment_pipe=augment_pipe,
+dataset_kwargs=dict(dataset_kwargs),
+watermark_kwargs=dict(watermark_kwargs),
+)
for key, value in data.items():
if isinstance(value, torch.nn.Module):
value = copy.deepcopy(value).eval().requires_grad_(False)
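The training loop only touches a small surface of the new watermark object: construction via dnnlib.util.construct_class_by_name from watermark_kwargs, a to_dual() call on [-1,1] images before the loss, and (in generate_watermark.py below) to_primal() and is_feasible(). The mdm.watermark.ImageWatermark class itself is not part of this commit, so the sketch below is a hypothetical stand-in that implements the same interface with an entry-wise atanh/tanh mirror map for the box constraint [-1,1], purely to make the expected contract concrete:

```python
import torch

class BoxMirrorWatermark:
    """Hypothetical stand-in for mdm.watermark.ImageWatermark.

    Maps images between the constrained primal (pixel) space [-1, 1] and an
    unconstrained dual space, matching the interface used by training_loop()
    and generate_watermark.py: to_dual(), to_primal(), is_feasible().
    """
    def __init__(self, dataset_name=None, device='cpu', eps=1e-6):
        self.dataset_name = dataset_name
        self.device = device
        self.eps = eps  # keeps atanh away from the +/-1 singularities

    def to_dual(self, images):
        # Primal -> dual: entry-wise atanh, one valid mirror map for the
        # box constraint [-1, 1]; the diffusion loss is computed here.
        x = images.clamp(-1 + self.eps, 1 - self.eps)
        return torch.atanh(x)

    def to_primal(self, images):
        # Dual -> primal: the inverse map; outputs land in [-1, 1].
        return torch.tanh(images)

    def is_feasible(self, images):
        # Per-sample feasibility of primal images, as checked after sampling.
        return (images.flatten(start_dim=1).abs() <= 1.0).all(dim=1)
```

With this particular map, to_primal() outputs are feasible by construction; the is_feasible() guard in generate_watermark.py matters for mirror maps without such a guarantee.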
13 changes: 11 additions & 2 deletions edm/generate.py → generate_watermark.py
@@ -16,8 +16,8 @@
import numpy as np
import torch
import PIL.Image
-import dnnlib
-from torch_utils import distributed as dist
+from edm import dnnlib
+from edm.torch_utils import distributed as dist

#----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).
@@ -265,6 +265,12 @@ def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=
with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
net = pickle.load(f)['ema'].to(device)

+# Construct mirror space mapping
+with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
+watermark_kwargs = pickle.load(f)['watermark_kwargs']
+watermark_kwargs.update(device=device)
+watermark = dnnlib.util.construct_class_by_name(**watermark_kwargs)
+
# Other ranks follow.
if dist.get_rank() == 0:
torch.distributed.barrier()
@@ -292,6 +298,9 @@ def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=
have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)
+images = watermark.to_primal(images.float()) # output: [-1,1]
+if not watermark.is_feasible(images).all():
+print(f"{batch_seeds=}", watermark.is_feasible(images))

# Save images.
images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
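Note that the snapshot pickle is opened twice, once for the EMA network and once for watermark_kwargs, and that only snapshots written by the updated training_loop() above contain the new key; pickles from upstream edm would raise a KeyError here. A minimal sketch of reading such a snapshot directly (hypothetical path, and it assumes the edm package is importable so the persisted classes can be unpickled):

```python
import pickle

# Hypothetical snapshot path; any pickle written by the updated
# training_loop() works, since it stores 'watermark_kwargs' next to
# the EMA weights, loss_fn, augment_pipe, and dataset_kwargs.
with open('network-snapshot-200000.pkl', 'rb') as f:
    data = pickle.load(f)

net = data['ema']                            # EMA denoiser, a torch.nn.Module
watermark_kwargs = data['watermark_kwargs']  # dict: class_name, dataset_name, ...
# Snapshots from upstream edm lack 'watermark_kwargs' -> KeyError.
```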
26 changes: 15 additions & 11 deletions edm/train.py → train_watermark.py
@@ -13,9 +13,9 @@
import json
import click
import torch
-import dnnlib
-from torch_utils import distributed as dist
-from training import training_loop
+from edm import dnnlib
+from edm.torch_utils import distributed as dist
+from edm.training import training_loop

import warnings
warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
@@ -76,6 +76,7 @@ def parse_int_list(s):
@click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
@click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
+@click.option('--pretrain', help='Pretrain weights', metavar='PT', type=str)

def main(**kwargs):
"""Train diffusion-based generative model using the techniques described in the
@@ -94,11 +95,12 @@

# Initialize config dict.
c = dnnlib.EasyDict()
-c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
+c.dataset_kwargs = dnnlib.EasyDict(class_name='edm.training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
c.network_kwargs = dnnlib.EasyDict()
c.loss_kwargs = dnnlib.EasyDict()
c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
+c.watermark_kwargs = dnnlib.EasyDict(class_name='mdm.watermark.ImageWatermark', dataset_name=opts.data[5: -10])

# Validate dataset options.
try:
Expand All @@ -125,23 +127,23 @@ def main(**kwargs):

# Preconditioning & loss function.
if opts.precond == 'vp':
-c.network_kwargs.class_name = 'training.networks.VPPrecond'
-c.loss_kwargs.class_name = 'training.loss.VPLoss'
+c.network_kwargs.class_name = 'edm.training.networks.VPPrecond'
+c.loss_kwargs.class_name = 'edm.training.loss.VPLoss'
elif opts.precond == 've':
-c.network_kwargs.class_name = 'training.networks.VEPrecond'
-c.loss_kwargs.class_name = 'training.loss.VELoss'
+c.network_kwargs.class_name = 'edm.training.networks.VEPrecond'
+c.loss_kwargs.class_name = 'edm.training.loss.VELoss'
else:
assert opts.precond == 'edm'
-c.network_kwargs.class_name = 'training.networks.EDMPrecond'
-c.loss_kwargs.class_name = 'training.loss.EDMLoss'
+c.network_kwargs.class_name = 'edm.training.networks.EDMPrecond'
+c.loss_kwargs.class_name = 'edm.training.loss.EDMLoss'

# Network options.
if opts.cbase is not None:
c.network_kwargs.model_channels = opts.cbase
if opts.cres is not None:
c.network_kwargs.channel_mult = opts.cres
if opts.augment:
-c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
+c.augment_kwargs = dnnlib.EasyDict(class_name='edm.training.augment.AugmentPipe', p=opts.augment)
c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
c.network_kwargs.augment_dim = 9
c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
@@ -174,6 +176,8 @@ def main(**kwargs):
c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
c.resume_kimg = int(match.group(1))
c.resume_state_dump = opts.resume
+if opts.pretrain:
+c.pretrain = opts.pretrain

# Description string.
cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
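One detail worth flagging in the config above: the dataset_name handed to mdm.watermark.ImageWatermark comes from raw string slicing, opts.data[5:-10], which silently assumes a dataset path of the exact form data/<name>-32x32.zip. A quick check of the arithmetic with hypothetical paths:

```python
# The slice drops a 5-character 'data/' prefix and a 10-character
# '-32x32.zip' suffix; anything else produces a mangled name.
print('data/cifar10-32x32.zip'[5:-10])      # 'cifar10'      (intended)
print('datasets/cifar10-32x32.zip'[5:-10])  # 'ets/cifar10'  (silently wrong)
```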
