-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #99 from corochann/dcgan_example
Dcgan example
- Loading branch information
Showing
6 changed files
with
376 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# DCGAN | ||
|
||
This is an example implementation of DCGAN (https://arxiv.org/abs/1511.06434) | ||
trained on multi-GPU using `chainermn`. | ||
|
||
This code uses Cifar-10 dataset by default. | ||
You can use your own dataset by specifying `--dataset` argument to the directory consisting of image files for training. | ||
The model assumes the resolution of an input image is 32x32. | ||
If you want to use another image resolution, you need to change the network architecture in net.py. | ||
|
||
Below is an example learning result using cifar-10 dataset after 200 epoch, | ||
where the model is trained using 4 GPUs with minibatch size 50 for each process. | ||
|
||
 | ||
|
||
## Implementation | ||
|
||
The original implementation is referenced from [chainer examples](https://github.com/chainer/chainer/tree/79d6bf6f4f5c86ba705b8fd377368519bc1fd264/examples/dcgan). | ||
|
||
It is worth noting that only main training code, `train_dcgan.py`, is modified from original code. | ||
The model definition code in `net.py`, the updater code (which defines how to calculate the loss to train generator and discriminator) in `updater.py`, | ||
and the training extension code in `visualize.py` are completely same with the original code. | ||
|
||
We can reuse most of the code developed in `chainer`, to support multi-GPU training with `chainermn`. | ||
|
||
## How to run the code | ||
|
||
For example, below command is to train the model using 4 GPUs (= processes). | ||
|
||
``` | ||
mpiexec -n 4 python train_dcgan.py -g | ||
``` | ||
|
||
If you want to restart the training to fine tune the trained model, | ||
specify the file path saved by `snapshot_object`. | ||
Below command loads the models which are trained 30000 iterations. | ||
``` | ||
mpiexec -n 4 python train_dcgan.py -g --gen_model=result/gen_iter_30000.npz --dis_model=result/dis_iter_30000.npz --out=result_finetune | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import print_function | ||
|
||
import numpy | ||
|
||
import chainer | ||
from chainer import cuda | ||
import chainer.functions as F | ||
import chainer.links as L | ||
|
||
|
||
def add_noise(h, sigma=0.2): | ||
xp = cuda.get_array_module(h.data) | ||
if chainer.config.train: | ||
return h + sigma * xp.random.randn(*h.shape) | ||
else: | ||
return h | ||
|
||
|
||
class Generator(chainer.Chain): | ||
|
||
def __init__(self, n_hidden, bottom_width=4, ch=512, wscale=0.02): | ||
super(Generator, self).__init__() | ||
self.n_hidden = n_hidden | ||
self.ch = ch | ||
self.bottom_width = bottom_width | ||
|
||
with self.init_scope(): | ||
w = chainer.initializers.Normal(wscale) | ||
self.l0 = L.Linear(self.n_hidden, bottom_width * bottom_width * ch, | ||
initialW=w) | ||
self.dc1 = L.Deconvolution2D(ch, ch // 2, 4, 2, 1, initialW=w) | ||
self.dc2 = L.Deconvolution2D(ch // 2, ch // 4, 4, 2, 1, initialW=w) | ||
self.dc3 = L.Deconvolution2D(ch // 4, ch // 8, 4, 2, 1, initialW=w) | ||
self.dc4 = L.Deconvolution2D(ch // 8, 3, 3, 1, 1, initialW=w) | ||
self.bn0 = L.BatchNormalization(bottom_width * bottom_width * ch) | ||
self.bn1 = L.BatchNormalization(ch // 2) | ||
self.bn2 = L.BatchNormalization(ch // 4) | ||
self.bn3 = L.BatchNormalization(ch // 8) | ||
|
||
def make_hidden(self, batchsize): | ||
return numpy.random.uniform(-1, 1, (batchsize, self.n_hidden, 1, 1))\ | ||
.astype(numpy.float32) | ||
|
||
def __call__(self, z): | ||
h = F.reshape(F.relu(self.bn0(self.l0(z))), | ||
(len(z), self.ch, self.bottom_width, self.bottom_width)) | ||
h = F.relu(self.bn1(self.dc1(h))) | ||
h = F.relu(self.bn2(self.dc2(h))) | ||
h = F.relu(self.bn3(self.dc3(h))) | ||
x = F.sigmoid(self.dc4(h)) | ||
return x | ||
|
||
|
||
class Discriminator(chainer.Chain): | ||
|
||
def __init__(self, bottom_width=4, ch=512, wscale=0.02): | ||
w = chainer.initializers.Normal(wscale) | ||
super(Discriminator, self).__init__() | ||
with self.init_scope(): | ||
self.c0_0 = L.Convolution2D(3, ch // 8, 3, 1, 1, initialW=w) | ||
self.c0_1 = L.Convolution2D(ch // 8, ch // 4, 4, 2, 1, initialW=w) | ||
self.c1_0 = L.Convolution2D(ch // 4, ch // 4, 3, 1, 1, initialW=w) | ||
self.c1_1 = L.Convolution2D(ch // 4, ch // 2, 4, 2, 1, initialW=w) | ||
self.c2_0 = L.Convolution2D(ch // 2, ch // 2, 3, 1, 1, initialW=w) | ||
self.c2_1 = L.Convolution2D(ch // 2, ch // 1, 4, 2, 1, initialW=w) | ||
self.c3_0 = L.Convolution2D(ch // 1, ch // 1, 3, 1, 1, initialW=w) | ||
self.l4 = L.Linear(bottom_width * bottom_width * ch, 1, initialW=w) | ||
self.bn0_1 = L.BatchNormalization(ch // 4, use_gamma=False) | ||
self.bn1_0 = L.BatchNormalization(ch // 4, use_gamma=False) | ||
self.bn1_1 = L.BatchNormalization(ch // 2, use_gamma=False) | ||
self.bn2_0 = L.BatchNormalization(ch // 2, use_gamma=False) | ||
self.bn2_1 = L.BatchNormalization(ch // 1, use_gamma=False) | ||
self.bn3_0 = L.BatchNormalization(ch // 1, use_gamma=False) | ||
|
||
def __call__(self, x): | ||
h = add_noise(x) | ||
h = F.leaky_relu(add_noise(self.c0_0(h))) | ||
h = F.leaky_relu(add_noise(self.bn0_1(self.c0_1(h)))) | ||
h = F.leaky_relu(add_noise(self.bn1_0(self.c1_0(h)))) | ||
h = F.leaky_relu(add_noise(self.bn1_1(self.c1_1(h)))) | ||
h = F.leaky_relu(add_noise(self.bn2_0(self.c2_0(h)))) | ||
h = F.leaky_relu(add_noise(self.bn2_1(self.c2_1(h)))) | ||
h = F.leaky_relu(add_noise(self.bn3_0(self.c3_0(h)))) | ||
return self.l4(h) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import print_function | ||
import argparse | ||
import os | ||
|
||
import chainer | ||
from chainer import training | ||
from chainer.training import extensions | ||
from mpi4py import MPI | ||
|
||
from net import Discriminator | ||
from net import Generator | ||
from updater import DCGANUpdater | ||
from visualize import out_generated_image | ||
|
||
import chainermn | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='ChainerMN example: DCGAN') | ||
parser.add_argument('--batchsize', '-b', type=int, default=50, | ||
help='Number of images in each mini-batch') | ||
parser.add_argument('--communicator', type=str, | ||
default='hierarchical', help='Type of communicator') | ||
parser.add_argument('--epoch', '-e', type=int, default=1000, | ||
help='Number of sweeps over the dataset to train') | ||
parser.add_argument('--gpu', '-g', action='store_true', | ||
help='Use GPU') | ||
parser.add_argument('--dataset', '-i', default='', | ||
help='Directory of image files. Default is cifar-10.') | ||
parser.add_argument('--out', '-o', default='result', | ||
help='Directory to output the result') | ||
parser.add_argument('--gen_model', '-r', default='', | ||
help='Use pre-trained generator for training') | ||
parser.add_argument('--dis_model', '-d', default='', | ||
help='Use pre-trained discriminator for training') | ||
parser.add_argument('--n_hidden', '-n', type=int, default=100, | ||
help='Number of hidden units (z)') | ||
parser.add_argument('--seed', type=int, default=0, | ||
help='Random seed of z at visualization stage') | ||
parser.add_argument('--snapshot_interval', type=int, default=1000, | ||
help='Interval of snapshot') | ||
parser.add_argument('--display_interval', type=int, default=100, | ||
help='Interval of displaying log to console') | ||
args = parser.parse_args() | ||
|
||
# Prepare ChainerMN communicator. | ||
|
||
if args.gpu: | ||
if args.communicator == 'naive': | ||
print("Error: 'naive' communicator does not support GPU.\n") | ||
exit(-1) | ||
comm = chainermn.create_communicator(args.communicator) | ||
device = comm.intra_rank | ||
else: | ||
if args.communicator != 'naive': | ||
print('Warning: using naive communicator ' | ||
'because only naive supports CPU-only execution') | ||
comm = chainermn.create_communicator('naive') | ||
device = -1 | ||
|
||
if comm.mpi_comm.rank == 0: | ||
print('==========================================') | ||
print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size())) | ||
if args.gpu: | ||
print('Using GPUs') | ||
print('Using {} communicator'.format(args.communicator)) | ||
print('Num hidden unit: {}'.format(args.n_hidden)) | ||
print('Num Minibatch-size: {}'.format(args.batchsize)) | ||
print('Num epoch: {}'.format(args.epoch)) | ||
print('==========================================') | ||
|
||
# Set up a neural network to train | ||
gen = Generator(n_hidden=args.n_hidden) | ||
dis = Discriminator() | ||
|
||
if device >= 0: | ||
# Make a specified GPU current | ||
chainer.cuda.get_device_from_id(device).use() | ||
gen.to_gpu() # Copy the model to the GPU | ||
dis.to_gpu() | ||
|
||
# Setup an optimizer | ||
def make_optimizer(model, comm, alpha=0.0002, beta1=0.5): | ||
# Create a multi node optimizer from a standard Chainer optimizer. | ||
optimizer = chainermn.create_multi_node_optimizer( | ||
chainer.optimizers.Adam(alpha=alpha, beta1=beta1), comm) | ||
optimizer.setup(model) | ||
optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec') | ||
return optimizer | ||
|
||
opt_gen = make_optimizer(gen, comm) | ||
opt_dis = make_optimizer(dis, comm) | ||
|
||
# Split and distribute the dataset. Only worker 0 loads the whole dataset. | ||
# Datasets of worker 0 are evenly split and distributed to all workers. | ||
if comm.rank == 0: | ||
if args.dataset == '': | ||
# Load the CIFAR10 dataset if args.dataset is not specified | ||
train, _ = chainer.datasets.get_cifar10(withlabel=False, | ||
scale=255.) | ||
else: | ||
all_files = os.listdir(args.dataset) | ||
image_files = [f for f in all_files if ('png' in f or 'jpg' in f)] | ||
print('{} contains {} image files' | ||
.format(args.dataset, len(image_files))) | ||
train = chainer.datasets\ | ||
.ImageDataset(paths=image_files, root=args.dataset) | ||
else: | ||
train = None | ||
|
||
train = chainermn.scatter_dataset(train, comm) | ||
|
||
train_iter = chainer.iterators.SerialIterator(train, args.batchsize) | ||
|
||
# Set up a trainer | ||
updater = DCGANUpdater( | ||
models=(gen, dis), | ||
iterator=train_iter, | ||
optimizer={ | ||
'gen': opt_gen, 'dis': opt_dis}, | ||
device=device) | ||
trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) | ||
|
||
# Some display and output extensions are necessary only for one worker. | ||
# (Otherwise, there would just be repeated outputs.) | ||
if comm.rank == 0: | ||
snapshot_interval = (args.snapshot_interval, 'iteration') | ||
display_interval = (args.display_interval, 'iteration') | ||
# Save only model parameters. | ||
# `snapshot` extension will save all the trainer module's attribute, | ||
# including `train_iter`. | ||
# However, `train_iter` depends on scattered dataset, which means that | ||
# `train_iter` may be different in each process. | ||
# Here, instead of saving whole trainer module, only the network models | ||
# are saved. | ||
trainer.extend(extensions.snapshot_object( | ||
gen, 'gen_iter_{.updater.iteration}.npz'), | ||
trigger=snapshot_interval) | ||
trainer.extend(extensions.snapshot_object( | ||
dis, 'dis_iter_{.updater.iteration}.npz'), | ||
trigger=snapshot_interval) | ||
trainer.extend(extensions.LogReport(trigger=display_interval)) | ||
trainer.extend(extensions.PrintReport([ | ||
'epoch', 'iteration', 'gen/loss', 'dis/loss', 'elapsed_time', | ||
]), trigger=display_interval) | ||
trainer.extend(extensions.ProgressBar(update_interval=10)) | ||
trainer.extend( | ||
out_generated_image( | ||
gen, dis, | ||
10, 10, args.seed, args.out), | ||
trigger=snapshot_interval) | ||
|
||
# Start the training using pre-trained model, saved by snapshot_object | ||
if args.gen_model: | ||
chainer.serializers.load_npz(args.gen_model, gen) | ||
if args.dis_model: | ||
chainer.serializers.load_npz(args.dis_model, dis) | ||
|
||
# Run the training | ||
trainer.run() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import print_function | ||
|
||
import chainer | ||
import chainer.functions as F | ||
from chainer import Variable | ||
|
||
|
||
class DCGANUpdater(chainer.training.StandardUpdater): | ||
|
||
def __init__(self, *args, **kwargs): | ||
self.gen, self.dis = kwargs.pop('models') | ||
super(DCGANUpdater, self).__init__(*args, **kwargs) | ||
|
||
def loss_dis(self, dis, y_fake, y_real): | ||
batchsize = len(y_fake) | ||
L1 = F.sum(F.softplus(-y_real)) / batchsize | ||
L2 = F.sum(F.softplus(y_fake)) / batchsize | ||
loss = L1 + L2 | ||
chainer.report({'loss': loss}, dis) | ||
return loss | ||
|
||
def loss_gen(self, gen, y_fake): | ||
batchsize = len(y_fake) | ||
loss = F.sum(F.softplus(-y_fake)) / batchsize | ||
chainer.report({'loss': loss}, gen) | ||
return loss | ||
|
||
def update_core(self): | ||
gen_optimizer = self.get_optimizer('gen') | ||
dis_optimizer = self.get_optimizer('dis') | ||
|
||
batch = self.get_iterator('main').next() | ||
x_real = Variable(self.converter(batch, self.device)) / 255. | ||
xp = chainer.cuda.get_array_module(x_real.data) | ||
|
||
gen, dis = self.gen, self.dis | ||
batchsize = len(batch) | ||
|
||
y_real = dis(x_real) | ||
|
||
z = Variable(xp.asarray(gen.make_hidden(batchsize))) | ||
x_fake = gen(z) | ||
y_fake = dis(x_fake) | ||
|
||
dis_optimizer.update(self.loss_dis, dis, y_fake, y_real) | ||
gen_optimizer.update(self.loss_gen, gen, y_fake) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#!/usr/bin/env python | ||
|
||
import os | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
import chainer | ||
import chainer.cuda | ||
from chainer import Variable | ||
|
||
|
||
def out_generated_image(gen, dis, rows, cols, seed, dst): | ||
@chainer.training.make_extension() | ||
def make_image(trainer): | ||
np.random.seed(seed) | ||
n_images = rows * cols | ||
xp = gen.xp | ||
z = Variable(xp.asarray(gen.make_hidden(n_images))) | ||
with chainer.using_config('train', False): | ||
x = gen(z) | ||
x = chainer.cuda.to_cpu(x.data) | ||
np.random.seed() | ||
|
||
x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8) | ||
_, _, H, W = x.shape | ||
x = x.reshape((rows, cols, 3, H, W)) | ||
x = x.transpose(0, 3, 1, 4, 2) | ||
x = x.reshape((rows * H, cols * W, 3)) | ||
|
||
preview_dir = '{}/preview'.format(dst) | ||
preview_path = preview_dir +\ | ||
'/image{:0>8}.png'.format(trainer.updater.iteration) | ||
if not os.path.exists(preview_dir): | ||
os.makedirs(preview_dir) | ||
Image.fromarray(x).save(preview_path) | ||
return make_image |