-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathproject_base.py
650 lines (523 loc) · 30.5 KB
/
project_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
# -*- coding: utf-8 -*-
"""project.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1-fOP3nDvx51pIPmr0hWUA-Q3UQ8sMmzi
"""
import torch.nn as nn
import torch.nn.functional as F
import torch
## Network building blocks (the weight-initialisation helper is defined after the models)
##############################
## Residual block: ResidualBlock
##############################
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions with a skip connection.

    Each convolution is followed by instance normalisation, and a ReLU sits
    between the two stages. Padding of 1 offsets the 3x3 kernel, and the
    channel count is kept at ``in_features``, so the output has the same
    shape as the input.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # pad(1) + 3x3 conv + instance norm; keeps H, W and channel count.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # Layer order (and therefore the state_dict indices) is:
        # pad, conv, norm, relu, pad, conv, norm.
        self.block = nn.Sequential(
            *conv_stage(),
            nn.ReLU(inplace=True),
            *conv_stage(),
        )

    def forward(self, x):
        """Return ``x`` plus the residual computed by ``self.block``."""
        residual = self.block(x)
        return x + residual
##############################
## 生成器网络GeneratorResNet
##############################
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 stem conv -> two stride-2 downsampling convs ->
    ``num_residual_blocks`` residual blocks -> two (upsample x2 + conv)
    stages -> 7x7 output conv with ``tanh``.

    Args:
        input_shape: ``(channels, height, width)``. Only ``channels`` is read
            here; the network itself is fully convolutional.
        num_residual_blocks: number of ``ResidualBlock`` modules inserted at
            the lowest resolution.

    The output has the same channel count as the input, with values in
    ``[-1, 1]`` (tanh). Spatial size is preserved when height and width are
    multiples of 4 (two halvings followed by two doublings).
    """

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()
        channels = input_shape[0]

        # A 7x7 convolution needs 3 pixels of padding on each side to keep
        # H and W. The padding used to be set to `channels`, which only
        # coincides with 3 for RGB input and shrank the output otherwise.
        edge_pad = 3

        # Stem: pad + 7x7 conv + norm + relu, channels -> 64.
        out_features = 64
        model = [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling, twice: 64 -> 128 -> 256 features, H and W halved each time.
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks at the lowest resolution.
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling, twice: 256 -> 128 -> 64 features, H and W doubled each time.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: pad + 7x7 conv back to `channels`, squashed to [-1, 1].
        model += [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(out_features, channels, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Translate a batch of images ``x`` of shape ``(N, channels, H, W)``."""
        return self.model(x)
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
    """Patch-based image discriminator (called PatchGAN in the original notes).

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 features for
    RGB input) each halve the spatial size; a final zero-padded 4x4
    convolution maps to a single-channel score map.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.

    Attributes:
        output_shape: ``(1, height // 16, width // 16)``, the per-sample shape
            of the score map, used by callers to build target tensors.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape

        # Four halvings -> spatial size divided by 2**4 = 16.
        self.output_shape = (1, height // 16, width // 16)

        widths = [channels, 64, 128, 256, 512]
        layers = []
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            # The first stage is the only one without normalisation.
            if stage > 0:
                layers.append(nn.InstanceNorm2d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Asymmetric zero padding (left and top) so the last conv keeps H and W.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score map for ``img``, shape ``(N, *self.output_shape)``."""
        return self.model(img)
def weights_init_normal(m):
    """Initialise one module in place; intended for ``model.apply(...)``.

    Dispatch is by class name:

    * names containing ``"Conv"``: weights ~ N(0, 0.02); bias, if present,
      set to 0.
    * names containing ``"BatchNorm2d"``: weights ~ N(1, 0.02); bias set to 0.
    * anything else is left untouched.
    """
    kind = type(m).__name__

    if "Conv" in kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Convolutions created with bias=False have `bias` set to None.
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
        return

    if "BatchNorm2d" in kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
## 如果输入的数据集是灰度图像,将图片转化为rgb图像(本次采用的facades不需要这个)
def to_rgb(image):
    """Return a new RGB image of the same size with ``image`` pasted onto it.

    Used for inputs whose mode is not ``"RGB"`` (e.g. greyscale) so that the
    rest of the pipeline always sees three channels.
    """
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas
## 构建数据集
class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Args:
        root: dataset directory containing the ``<mode>A`` / ``<mode>B`` folders.
        transforms_: list of transforms applied to both images; ``None`` means
            no transforms (images are returned as loaded).
        unaligned: if true, the B image is drawn at random instead of by index.
        mode: folder prefix, e.g. ``"train"`` or ``"test"``.

    Each item is a dict ``{"A": ..., "B": ...}`` holding the transformed images.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) would only fail later, on the first __getitem__ call,
        # so treat the default None as an empty pipeline.
        self.transform = transforms.Compose([] if transforms_ is None else transforms_)
        self.unaligned = unaligned

        # Files are matched with "*.*", i.e. any name containing a dot.
        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))

    def __getitem__(self, index):
        # The index wraps around, so the shorter domain is cycled.
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Unpaired data: pick any B image uniformly at random.
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Force three channels for non-RGB inputs (e.g. greyscale).
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Length of the larger of the two domains."""
        return max(len(self.files_A), len(self.files_B))
import random
import time
import datetime
import sys
import torch
import numpy as np
from torchvision.utils import save_image
## 先前生成的样本的缓冲区
class ReplayBuffer:
    """Bounded pool of previously generated samples.

    Until the pool holds ``max_size`` samples, every incoming sample is stored
    and returned unchanged. Once full, each incoming sample is, with
    probability 1/2, swapped with a randomly chosen stored sample (the stored
    one is returned); otherwise it is returned as is and not stored.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same size.

        Args:
            data: tensor whose first dimension is the batch; ``data.data`` is
                iterated, so the stored samples carry no autograd history.
        """
        picked = []
        for sample in data.data:
            sample = sample.unsqueeze(0)  # restore the batch dimension

            if len(self.data) < self.max_size:
                # Still filling up: store and pass through.
                self.data.append(sample)
                picked.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Hand back an older sample and store the new one in its slot.
                slot = random.randint(0, self.max_size - 1)
                picked.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                picked.append(sample)

        return torch.cat(picked)
## 设置学习率为初始学习率乘以给定lr_lambda函数的值
class LambdaLR:
    """Linear learning-rate decay factor for ``torch.optim.lr_scheduler.LambdaLR``.

    The factor is 1.0 until ``decay_start_epoch`` and then falls linearly,
    reaching 0.0 at ``n_epochs``.

    Args:
        n_epochs: total number of training epochs.
        offset: epoch the run starts from (added to the scheduler's epoch).
        decay_start_epoch: epoch at which the decay begins; must be smaller
            than ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative LR factor for scheduler epoch ``epoch``."""
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / decay_span
# Commented out IPython magic to ensure Python compatibility.
import argparse
import os
from tkinter import Image
import numpy as np
import math
import itertools
import datetime
import time
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch
from PIL import Image
## 超参数配置
# Command-line hyper-parameters. Parsing happens at module level, so the
# resulting `opt` namespace is shared by everything below.
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=4, help="epoch to start training from")  # non-zero => resume from a checkpoint
parser.add_argument("--n_epochs", type=int, default=25, help="number of epochs of training")  # upper bound of the epoch range
parser.add_argument("--dataset_name", type=str, default="facades", help="name of the dataset")  # used for the images/ and save/ sub-folders
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0003, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# The help text used to repeat the --b1 description; b2 is Adam's second-moment decay.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=3, help="epoch from which to start lr decay")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
# When running inside Colab/Jupyter, use this line instead of the one above:
# opt = parser.parse_args(args=[])
print(opt)
## Output folders for sample images and saved weights.
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("save/%s" % opt.dataset_name, exist_ok=True)

## (channels, height, width) of the images fed to the networks.
input_shape = (opt.channels, opt.img_height, opt.img_width)

## Two generators (A->B, B->A) and one discriminator per domain.
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

## Losses: the adversarial loss is mean-squared error (not cross-entropy);
## the cycle and identity losses are L1.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

## Move everything to the GPU when one is available.
if torch.cuda.is_available():
    print(torch.cuda.is_available())
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

## opt.epoch == 0: fresh weight initialisation.
## opt.epoch == n: resume from the weights saved under epoch n.
if opt.epoch != 0:
    # map_location="cpu" lets checkpoints written on a GPU machine be read on
    # a CPU-only host; load_state_dict then copies the values onto whatever
    # device the model already lives on.
    G_AB.load_state_dict(torch.load("save/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu"))
    G_BA.load_state_dict(torch.load("save/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu"))
    D_A.load_state_dict(torch.load("save/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu"))
    D_B.load_state_dict(torch.load("save/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu"))
else:
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

## Optimisers: one shared by both generators, one per discriminator.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

## Learning-rate schedules: constant, then linear decay from opt.decay_epoch.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

## Pools of previously generated images, used when training the discriminators.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
## Image transformations shared by the training and validation loaders.
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12)),           # enlarge by a factor of 1.12
    transforms.RandomCrop((opt.img_height, opt.img_width)),  # random crop back to the target size
    transforms.RandomHorizontalFlip(),                       # random horizontal flip
    transforms.ToTensor(),                                   # PIL image -> tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per-channel (x - 0.5) / 0.5
]

# Loader settings common to both splits.
# NOTE(review): the dataset root is hard-coded and does not follow
# --dataset_name; change it here to point at your own data.
_dataset_root = "vangogh2photo"
_loader_options = dict(shuffle=True, num_workers=0, pin_memory=False)

## Training data loader (unpaired sampling of domain B).
dataloader = DataLoader(
    ImageDataset(_dataset_root, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    **_loader_options
)

## Test data loader, used for the periodic sample grids (5 images per batch).
val_dataloader = DataLoader(
    ImageDataset(_dataset_root, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    **_loader_options
)
## 每间隔100次打印图片
def sample_images(batches_done):
    """Save a grid of real and translated images from the validation loader.

    One batch is drawn from ``val_dataloader``, passed through both
    generators, and written to ``images/<dataset_name>/<batches_done>.png``
    as four stacked rows: real A, fake B, real B, fake A.

    Both generators are switched to eval mode and left that way; the training
    loop switches them back with ``.train()`` on its next iteration.

    Args:
        batches_done: global batch counter, used only for the file name.
    """
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()

    # Follow the device the generator lives on instead of forcing CUDA, so
    # sampling also works on CPU-only machines.
    device = next(G_AB.parameters()).device

    # No gradients are needed for visualisation; skip building the graph.
    with torch.no_grad():
        real_A = imgs["A"].to(device)
        fake_B = G_AB(real_A)
        real_B = imgs["B"].to(device)
        fake_A = G_BA(real_B)

    # Arrange each batch along the x-axis (one row of up to 5 images).
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)

    # Stack the four rows along the y-axis and save them as one image.
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)
def train():
    """Run the training loop from ``opt.epoch`` up to ``opt.n_epochs``.

    Per batch: one generator step (adversarial + cycle + identity losses),
    then one step for each discriminator, using the replay buffers for the
    fake samples. Progress is written to stdout, sample grids are saved every
    ``opt.sample_interval`` batches, and the learning-rate schedulers are
    stepped once per epoch. The four networks are saved under
    ``save/<dataset_name>/`` when training ends.
    """
    # Use whatever device the models were placed on instead of forcing CUDA,
    # so training also runs on CPU-only machines.
    device = next(G_AB.parameters()).device

    prev_time = time.time()
    last_epoch = None  # stays None if the epoch range is empty

    for epoch in range(opt.epoch, opt.n_epochs):
        last_epoch = epoch
        print('epoch: '+ str(epoch))
        for i, batch in enumerate(dataloader):
            # Real images from both domains.
            real_A = batch["A"].to(device)
            real_B = batch["B"].to(device)

            # Target score maps: all ones for "real", all zeros for "fake".
            valid = torch.ones((real_A.size(0), *D_A.output_shape), device=device)
            fake = torch.zeros((real_A.size(0), *D_A.output_shape), device=device)

            ## -----------------
            ## Train generators
            ## The generators are updated so that the discriminators score
            ## their outputs as real; only generator parameters are stepped.
            ## -----------------
            # sample_images() leaves the generators in eval mode.
            G_AB.train()
            G_BA.train()

            # Identity loss: an image already in the target domain should
            # pass through the generator unchanged.
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # Adversarial loss: translated images should be scored as real.
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss: A -> B -> A (and B -> A -> B) should give back the input.
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total generator loss.
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            ## -----------------------
            ## Train discriminator A: real images -> 1, fake images -> 0
            ## -----------------------
            loss_real = criterion_GAN(D_A(real_A), valid)
            # The fake sample comes through the replay buffer and is detached
            # so no gradient reaches the generator.
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            loss_D_A = (loss_real + loss_fake) / 2
            optimizer_D_A.zero_grad()
            loss_D_A.backward()
            optimizer_D_A.step()

            ## -----------------------
            ## Train discriminator B: real images -> 1, fake images -> 0
            ## -----------------------
            loss_real = criterion_GAN(D_B(real_B), valid)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            loss_D_B = (loss_real + loss_fake) / 2
            optimizer_D_B.zero_grad()
            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            ## ----------------------
            ## Log progress
            ## ----------------------
            # ETA = remaining batches * duration of the batch just finished.
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                )
            )

            # Save a grid of validation samples every opt.sample_interval batches.
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update the learning rates once per epoch.
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

    # NOTE(review): the source's indentation was lost; saving once after the
    # loop follows the accompanying comment ("save the model after training
    # ends") -- confirm against the original notebook.
    if last_epoch is None:
        # range(opt.epoch, opt.n_epochs) was empty: `epoch` was never bound,
        # so there is no trained state to name or save.
        print("no epochs to run (epoch >= n_epochs); nothing saved")
        return

    torch.save(G_AB.state_dict(), "save/%s/G_AB_%d.pth" % (opt.dataset_name, last_epoch))
    torch.save(G_BA.state_dict(), "save/%s/G_BA_%d.pth" % (opt.dataset_name, last_epoch))
    torch.save(D_A.state_dict(), "save/%s/D_A_%d.pth" % (opt.dataset_name, last_epoch))
    torch.save(D_B.state_dict(), "save/%s/D_B_%d.pth" % (opt.dataset_name, last_epoch))
    print("save my model finished !!")

    # NOTE: --checkpoint_interval is parsed but not used; periodic
    # checkpointing during training is currently disabled.


train()
def test():
    """Translate the test split with both saved generators and write the results.

    Loads the A->B and B->A generator weights, runs every batch of the test
    split through them, and saves the outputs (rescaled from [-1, 1] to
    [0, 1]) as ``output/A/NNNN.png`` (fake A) and ``output/B/NNNN.png``
    (fake B), one file per batch.
    """
    ## Hyper-parameters. They are parsed from an empty argument list, so only
    ## the defaults below take effect.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSize', type=int, default=2, help='size of the batches')
    parser.add_argument('--dataroot', type=str, default='vangogh2photo', help='root directory of the dataset')
    parser.add_argument('--channels', type=int, default=3, help='number of channels of input data')
    # The help text used to describe output channels; this option is the residual block count.
    parser.add_argument('--n_residual_blocks', type=int, default=9, help='number of residual blocks in generator')
    parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
    # NOTE: type=bool would turn any non-empty string (even "False") into
    # True; harmless here because only the default is ever used.
    parser.add_argument('--cuda', type=bool, default=True, help='use GPU computation')
    parser.add_argument('--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
    parser.add_argument('--generator_A2B', type=str, default='save/facades/G_AB_4.pth', help='A2B generator checkpoint file')
    parser.add_argument('--generator_B2A', type=str, default='save/facades/G_BA_4.pth', help='B2A generator checkpoint file')
    opt = parser.parse_args(args=[])
    print(opt)

    #################################
    ##        test preparation     ##
    #################################
    input_shape = (opt.channels, opt.size, opt.size)

    netG_A2B = GeneratorResNet(input_shape, opt.n_residual_blocks)
    netG_B2A = GeneratorResNet(input_shape, opt.n_residual_blocks)

    # Only use CUDA when it was requested AND is actually present.
    use_cuda = opt.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    netG_A2B.to(device)
    netG_B2A.to(device)

    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    netG_A2B.load_state_dict(torch.load(opt.generator_A2B, map_location=device))
    netG_B2A.load_state_dict(torch.load(opt.generator_B2A, map_location=device))

    netG_A2B.eval()
    netG_B2A.eval()

    # Test data: no augmentation, only tensor conversion and normalisation.
    transforms_ = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataloader = DataLoader(
        ImageDataset(opt.dataroot, transforms_=transforms_, mode='test'),
        batch_size=opt.batchSize,
        shuffle=False,
        num_workers=opt.n_cpu,
    )

    #################################
    ##          test run           ##
    #################################
    # Output folders for the generated images.
    os.makedirs('output/A', exist_ok=True)
    os.makedirs('output/B', exist_ok=True)

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            # Move the batch as is. Copying into a fixed-size preallocated
            # tensor broke (or silently broadcast) on a short final batch.
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)
            print(real_A.shape)

            # Generator outputs lie in [-1, 1]; rescale to [0, 1] for saving.
            fake_B = 0.5 * (netG_A2B(real_A) + 1.0)
            fake_A = 0.5 * (netG_B2A(real_B) + 1.0)

            save_image(fake_A, 'output/A/%04d.png' % (i+1))
            save_image(fake_B, 'output/B/%04d.png' % (i+1))
            print('processing (%04d)-th image...' % (i))
    print("测试完成")


test()
## Single-image inference: translate one photo with the saved A->B generator.

# NOTE(review): this reuses the module-level training `transforms_`, which
# include a random crop and a random flip, so the result is not deterministic
# -- confirm this is intended for inference.
transform = transforms.Compose(transforms_)
input_shape = (3, 256, 256)
netG_A2B = GeneratorResNet(input_shape, 9)

# Use the GPU when present, otherwise fall back to the CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG_A2B.to(_device)

# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
netG_A2B.load_state_dict(torch.load('save/facades/G_AB_4.pth', map_location=_device))
netG_A2B.eval()

# Kept for compatibility with the notebook; `input_A` is not used below.
Tensor = torch.cuda.FloatTensor if _device.type == "cuda" else torch.FloatTensor
input_A = Tensor(1, 3, 256, 256)

# convert("RGB") guards against greyscale/CMYK/alpha inputs, which would not
# fit the (1, 3, 256, 256) reshape below.
image_A = Image.open('DSC00166.jpg').convert("RGB")
# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
image_A = image_A.resize((256, 256), Image.LANCZOS)
image_A = transform(image_A)
image_A = torch.reshape(image_A, (1, 3, 256, 256))
print(image_A.shape)

real_A = image_A.to(_device)
with torch.no_grad():
    # Generator output lies in [-1, 1]; rescale to [0, 1] for saving.
    fake_B = 0.5 * (netG_A2B(real_A) + 1.0)
save_image(fake_B, 'DSC00166.png')

# Notebook cell echoes; they have no effect when run as a script.
real_A
image_A