diff --git a/edm/torch_utils/misc.py b/edm/torch_utils/misc.py
index f0d3184..fe564aa 100644
--- a/edm/torch_utils/misc.py
+++ b/edm/torch_utils/misc.py
@@ -10,7 +10,7 @@
 import numpy as np
 import torch
 import warnings
-import dnnlib
+from edm import dnnlib
 
 #----------------------------------------------------------------------------
 # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
diff --git a/edm/torch_utils/persistence.py b/edm/torch_utils/persistence.py
index fbecbe2..a924e48 100644
--- a/edm/torch_utils/persistence.py
+++ b/edm/torch_utils/persistence.py
@@ -19,7 +19,7 @@
 import copy
 import uuid
 import types
-import dnnlib
+from edm import dnnlib
 
 #----------------------------------------------------------------------------
diff --git a/edm/torch_utils/training_stats.py b/edm/torch_utils/training_stats.py
index 727c4e8..8e11b10 100644
--- a/edm/torch_utils/training_stats.py
+++ b/edm/torch_utils/training_stats.py
@@ -13,7 +13,7 @@
 import re
 import numpy as np
 import torch
-import dnnlib
+from edm import dnnlib
 from . import misc
diff --git a/edm/training/augment.py b/edm/training/augment.py
index a8d474d..7c4bf9e 100644
--- a/edm/training/augment.py
+++ b/edm/training/augment.py
@@ -12,8 +12,8 @@
 import numpy as np
 import torch
-from torch_utils import persistence
-from torch_utils import misc
+from edm.torch_utils import persistence
+from edm.torch_utils import misc
 
 #----------------------------------------------------------------------------
 # Coefficients of various wavelet decomposition low-pass filters.
diff --git a/edm/training/dataset.py b/edm/training/dataset.py
index ef4bd02..54544b8 100644
--- a/edm/training/dataset.py
+++ b/edm/training/dataset.py
@@ -13,7 +13,7 @@
 import PIL.Image
 import json
 import torch
-import dnnlib
+from edm import dnnlib
 
 try:
     import pyspng
diff --git a/edm/training/loss.py b/edm/training/loss.py
index ff045c5..84d5bdf 100644
--- a/edm/training/loss.py
+++ b/edm/training/loss.py
@@ -9,7 +9,7 @@
 "Elucidating the Design Space of Diffusion-Based Generative Models"."""
 
 import torch
-from torch_utils import persistence
+from edm.torch_utils import persistence
 
 #----------------------------------------------------------------------------
 # Loss function corresponding to the variance preserving (VP) formulation
diff --git a/edm/training/networks.py b/edm/training/networks.py
index d2326c7..868d677 100644
--- a/edm/training/networks.py
+++ b/edm/training/networks.py
@@ -10,7 +10,7 @@
 import numpy as np
 import torch
-from torch_utils import persistence
+from edm.torch_utils import persistence
 from torch.nn.functional import silu
 
 #----------------------------------------------------------------------------
diff --git a/edm/training/training_loop.py b/edm/training/training_loop.py
index 109d7d2..e50222b 100644
--- a/edm/training/training_loop.py
+++ b/edm/training/training_loop.py
@@ -15,10 +15,10 @@
 import psutil
 import numpy as np
 import torch
-import dnnlib
-from torch_utils import distributed as dist
-from torch_utils import training_stats
-from torch_utils import misc
+from edm import dnnlib
+from edm.torch_utils import distributed as dist
+from edm.torch_utils import training_stats
+from edm.torch_utils import misc
 
 #----------------------------------------------------------------------------
 
@@ -46,6 +46,8 @@ def training_loop(
     resume_kimg = 0, # Start from the given training progress.
     cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
     device = torch.device('cuda'),
+    pretrain = None, # (MDM) Pretrained checkpoint for initialization.
+    watermark_kwargs = None, # (MDM) Options for mdm.watermark.ImageWatermark.
 ):
     # Initialize.
     start_time = time.time()
@@ -69,10 +71,19 @@ def training_loop(
     dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
     dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
 
+    # Construct dual (mirror) space mapping.
+    watermark = dnnlib.util.construct_class_by_name(**watermark_kwargs)
+
     # Construct network.
     dist.print0('Constructing network...')
     interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
     net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
+
+    if pretrain:
+        dist.print0(f'Loading pretrained weights from {pretrain}...')
+        net.load_state_dict(torch.load(pretrain, map_location="cpu"))
+        net = net.to(device)
+
     net.train().requires_grad_(True).to(device)
     if dist.get_rank() == 0:
         with torch.no_grad():
@@ -126,6 +137,9 @@ def training_loop(
             with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
                 images, labels = next(dataset_iterator)
                 images = images.to(device).to(torch.float32) / 127.5 - 1
+                # images are now in [-1, 1]; map them to the dual space.
+                images = watermark.to_dual(images)
+
                 labels = labels.to(device)
                 loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
                 training_stats.report('Loss/loss', loss)
@@ -176,7 +190,13 @@ def training_loop(
 
         # Save network snapshot.
         if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
-            data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
+            data = dict(
+                ema=ema,
+                loss_fn=loss_fn,
+                augment_pipe=augment_pipe,
+                dataset_kwargs=dict(dataset_kwargs),
+                watermark_kwargs=dict(watermark_kwargs),
+            )
             for key, value in data.items():
                 if isinstance(value, torch.nn.Module):
                     value = copy.deepcopy(value).eval().requires_grad_(False)
diff --git a/edm/generate.py b/generate_watermark.py
similarity index 96%
rename from edm/generate.py
rename to generate_watermark.py
index 45c1241..3420a79 100644
--- a/edm/generate.py
+++ b/generate_watermark.py
@@ -16,8 +16,8 @@
 import numpy as np
 import torch
 import PIL.Image
-import dnnlib
-from torch_utils import distributed as dist
+from edm import dnnlib
+from edm.torch_utils import distributed as dist
 
 #----------------------------------------------------------------------------
 # Proposed EDM sampler (Algorithm 2).
@@ -265,6 +265,12 @@ def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=
     with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
         net = pickle.load(f)['ema'].to(device)
 
+    # Construct mirror (dual) space mapping from the kwargs stored in the snapshot.
+    with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
+        watermark_kwargs = pickle.load(f)['watermark_kwargs']
+    watermark_kwargs.update(device=device)
+    watermark = dnnlib.util.construct_class_by_name(**watermark_kwargs)
+
     # Other ranks follow.
     if dist.get_rank() == 0:
         torch.distributed.barrier()
@@ -292,6 +298,10 @@ def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=
         have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
         sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
         images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)
+        images = watermark.to_primal(images.float()) # output: [-1, 1]
+        feasible = watermark.is_feasible(images)
+        if not feasible.all():
+            print(f"{batch_seeds=}", feasible)
 
         # Save images.
         images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
diff --git a/edm/train.py b/train_watermark.py
similarity index 91%
rename from edm/train.py
rename to train_watermark.py
index 6851604..4fa80b8 100644
--- a/edm/train.py
+++ b/train_watermark.py
@@ -13,9 +13,9 @@
 import json
 import click
 import torch
-import dnnlib
-from torch_utils import distributed as dist
-from training import training_loop
+from edm import dnnlib
+from edm.torch_utils import distributed as dist
+from edm.training import training_loop
 import warnings
 warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
 
@@ -76,6 +76,7 @@ def parse_int_list(s):
 @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
 @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
 @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
+@click.option('--pretrain', help='Pretrained weights for initialization', metavar='PT', type=str)
 
 def main(**kwargs):
     """Train diffusion-based generative model using the techniques described in the
@@ -94,11 +95,12 @@ def main(**kwargs):
 
     # Initialize config dict.
     c = dnnlib.EasyDict()
-    c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
+    c.dataset_kwargs = dnnlib.EasyDict(class_name='edm.training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
     c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
     c.network_kwargs = dnnlib.EasyDict()
     c.loss_kwargs = dnnlib.EasyDict()
     c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
+    c.watermark_kwargs = dnnlib.EasyDict(class_name='mdm.watermark.ImageWatermark', dataset_name=opts.data[5:-10]) # presumably strips a 'data/' prefix and a 10-char suffix such as '-64x64.zip'
 
     # Validate dataset options.
     try:
@@ -125,15 +127,15 @@ def main(**kwargs):
 
     # Preconditioning & loss function.
     if opts.precond == 'vp':
-        c.network_kwargs.class_name = 'training.networks.VPPrecond'
-        c.loss_kwargs.class_name = 'training.loss.VPLoss'
+        c.network_kwargs.class_name = 'edm.training.networks.VPPrecond'
+        c.loss_kwargs.class_name = 'edm.training.loss.VPLoss'
     elif opts.precond == 've':
-        c.network_kwargs.class_name = 'training.networks.VEPrecond'
-        c.loss_kwargs.class_name = 'training.loss.VELoss'
+        c.network_kwargs.class_name = 'edm.training.networks.VEPrecond'
+        c.loss_kwargs.class_name = 'edm.training.loss.VELoss'
     else:
         assert opts.precond == 'edm'
-        c.network_kwargs.class_name = 'training.networks.EDMPrecond'
-        c.loss_kwargs.class_name = 'training.loss.EDMLoss'
+        c.network_kwargs.class_name = 'edm.training.networks.EDMPrecond'
+        c.loss_kwargs.class_name = 'edm.training.loss.EDMLoss'
 
     # Network options.
     if opts.cbase is not None:
@@ -141,7 +143,7 @@ def main(**kwargs):
     if opts.cres is not None:
         c.network_kwargs.channel_mult = opts.cres
     if opts.augment:
-        c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
+        c.augment_kwargs = dnnlib.EasyDict(class_name='edm.training.augment.AugmentPipe', p=opts.augment)
         c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
         c.network_kwargs.augment_dim = 9
     c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
@@ -174,6 +176,8 @@ def main(**kwargs):
             c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
             c.resume_kimg = int(match.group(1))
             c.resume_state_dump = opts.resume
+    if opts.pretrain:
+        c.pretrain = opts.pretrain
 
     # Description string.
     cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
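
Note on the assumed watermark interface: the patch constructs mdm.watermark.ImageWatermark
via dnnlib.util.construct_class_by_name() but does not include the class itself. From the
call sites above, it must accept dataset_name (training) plus an optional device (sampling),
map pixel-space images in [-1, 1] into the dual/mirror space with to_dual(), invert that map
with to_primal(), and report a per-sample feasibility mask with is_feasible(). Below is a
minimal sketch of that assumed interface only; the method bodies are hypothetical identity
placeholders, not the real mirror maps.

    import torch

    class ImageWatermark:
        """Sketch of the interface exercised by train_watermark.py / generate_watermark.py."""

        def __init__(self, dataset_name, device=torch.device('cpu')):
            # Training constructs this with dataset_name only; sampling adds device.
            self.dataset_name = dataset_name
            self.device = device

        def to_dual(self, images):
            # Primal -> dual (mirror) space; the diffusion loss is computed here.
            # Identity placeholder; the real map embeds the watermark constraint.
            return images

        def to_primal(self, images):
            # Dual -> primal (pixel) space; inverse of to_dual(), output in [-1, 1].
            return images

        def is_feasible(self, images):
            # Per-sample boolean mask: True where the decoded sample lies in the
            # feasible (correctly watermarked) region. Placeholder: a range check.
            return images.flatten(1).abs().amax(dim=1) <= 1

Because training_loop() now stores watermark_kwargs in every network snapshot,
generate_watermark.py can rebuild the identical mapping from the pickle alone, so primal
and dual spaces stay consistent between training and sampling.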