-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathAADN_GCANet_defense.py
152 lines (127 loc) · 6.12 KB
/
AADN_GCANet_defense.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# -*- coding: utf-8 -*-
import torch.optim as optim
import torch.nn as nn
import torch
import methods.GCA.GCA as GCANet
import methods.GCA.options_GCANet_defense as options_GCANet_defense
from defense_utils.dataset.RESIDEDataset import RESIDE_Dataset
import os
# from defense_utils import save
# from defense_utils.metric import cal_psnr_ssim
from defense_utils.loss.loss_writer import LossWriter
# Parse the run options for this defense-training script.
config = options_GCANet_defense.Options().parse()
# Everything below runs on CUDA; there is no CPU fallback in this script.
device = torch.device("cuda")

# Resolve dataset-specific data roots. Only FoggyCity is wired up here; any
# other value is rejected before data or models are loaded.
if config.dataset == "FoggyCity":
    data_root_train = config.data_root_train_FoggyCity
    data_root_val = config.data_root_val_FoggyCity
    # Forwarded to RESIDE_Dataset. NOTE(review): presumably controls how
    # hazy/clear file names are paired -- confirm against RESIDE_Dataset.
    if_identity_name = True
else:
    raise ValueError("dataset not support: {}".format(config.dataset))
# Image size handed to the datasets, built from config.img_h / config.img_w.
img_size = [config.img_h, config.img_w]

# Train and val datasets share size and naming options and differ only in
# their root directory and the if_train flag.
train_dataset = RESIDE_Dataset(
    data_root_train,
    img_size=img_size,
    if_train=True,
    if_identity_name=if_identity_name,
)
val_dataset = RESIDE_Dataset(
    data_root_val,
    img_size=img_size,
    if_train=False,
    if_identity_name=if_identity_name,
)

# Shuffled training batches; drop_last=True discards a trailing partial batch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=True,
)

# No validation loader is built. Restore this to re-enable validation:
# val_loader = torch.utils.data.DataLoader(val_dataset,
#                                          batch_size=config.val_batch_size, shuffle=False,
#                                          num_workers=config.num_workers, pin_memory=True,
#                                          drop_last=True)
# Output layout for this run:
#   results_train/GCANet/<results_dir>/{image_temp_test,models,ssim_psnr,loss}
res_dir = os.path.join("results_train/GCANet/", config.results_dir)
# makedirs(exist_ok=True) creates missing parents (os.mkdir failed when
# results_train/GCANet/ did not exist). It also fills in any subfolder missing
# from a pre-existing res_dir, which the old `if not exists` guard skipped.
for _sub_dir in ("image_temp_test", "models", "ssim_psnr", "loss"):
    os.makedirs(os.path.join(res_dir, _sub_dir), exist_ok=True)
loss_writer = LossWriter(os.path.join(res_dir, "loss"))
# Model being adversarially fine-tuned.
generator = GCANet.GCANet(in_c=config.inc,
                          out_c=config.outc,
                          only_residual=config.only_residual).to(device)
# Weight decay is currently disabled. Add
# `weight_decay=config.weight_decay_generator` to re-enable it.
optimizer = optim.Adam(generator.parameters(),
                       lr=config.g_lr,
                       betas=(config.beta1, config.beta2))

# Reconstruction loss applied to both the clean and the adversarial outputs.
if config.loss == "L1":
    loss_func = nn.L1Loss()
elif config.loss == "L2":
    loss_func = nn.MSELoss()
else:
    # Fail fast. Leaving loss_func as None only crashed later, at the first
    # training step, with "'NoneType' object is not callable".
    raise ValueError("loss not support: {}".format(config.loss))

# Global step counter used for loss logging.
iteration = 0
# Criterion handed to attack_predict_or_gt when crafting perturbations.
aadn_criterion = nn.MSELoss()

# Read the pretrained checkpoint once and use it for both the generator and
# the teacher. Loading on CPU first avoids depending on the device index the
# checkpoint was saved from; load_state_dict copies into the CUDA parameters.
pretrained_state = torch.load(config.ck_path, map_location="cpu")
generator.load_state_dict(pretrained_state)

from defense_utils.online_attack import attack_predict_or_gt

# Frozen copy of the pretrained model. Its outputs serve as targets when
# config.att_type == "predict_mse".
teacher = GCANet.GCANet(in_c=config.inc,
                        out_c=config.outc,
                        only_residual=config.only_residual).to(device)
teacher.load_state_dict(pretrained_state)
# The teacher is never optimised: use inference mode and stop autograd from
# tracking its weights.
teacher.eval()
teacher.requires_grad_(False)
# Drop the extra CPU copy of the weights.
del pretrained_state
# Validate the attack mode before training. An unknown value previously left
# g_loss_attack undefined and raised NameError partway through the first step.
if config.att_type not in ("predict_mse", "gt"):
    raise ValueError("att_type not support: {}".format(config.att_type))

for epoch in range(config.total_epoches):
    generator.train()
    for data in train_loader:
        image_haze = data["hazy"].to(device)
        image_clear = data["gt"].to(device)
        # #################################################
        # Choose the label the attack uses and the target of the adversarial loss.
        if config.att_type == "predict_mse":
            # Both use the teacher's prediction. no_grad keeps autograd from
            # recording the teacher forward pass.
            with torch.no_grad():
                adv_target = teacher(image_haze)
            attack_label = adv_target
        else:  # "gt" (validated above)
            adv_target = image_clear
            # The attack receives a copy, so the ground truth used by the losses
            # below is never touched.
            attack_label = image_clear.clone()

        # The epsilon, alpha and attack_iters can be dynamically generated.
        delta = attack_predict_or_gt(model=generator, hazy=image_haze, label=attack_label,
                                     epsilon=8, alpha=2, attack_iters=10,
                                     criterion=aadn_criterion)
        generated_image_attack = generator(image_haze + delta)
        g_loss_attack = loss_func(generated_image_attack, adv_target)

        # Clean-input reconstruction term, always measured against ground truth.
        generated_image_ori = generator(image_haze)
        g_loss_ori = loss_func(generated_image_ori, image_clear)
        g_loss = g_loss_attack + g_loss_ori

        optimizer.zero_grad()
        g_loss.backward()
        optimizer.step()

        # Read the scalar once (each .item() forces a device sync).
        loss_value = g_loss.item()
        loss_writer.add("g_loss", loss_value, iteration)
        iteration += 1
        if iteration % 100 == 0:
            print("Iter {}, Loss is {}".format(iteration, loss_value))
    # #################################################
    # Save checkpoints only for the final epochs.
    # NOTE(review): `>` keeps the last 9 epochs; use `>=` if 10 were intended.
    if epoch > config.total_epoches - 10:
        torch.save(generator.state_dict(),
                   os.path.join(res_dir, "models", str(epoch) + ".pth"))

    # Validation is disabled: val_loader and the `save` / `cal_psnr_ssim`
    # imports are commented out. Restore those and uncomment this to re-enable it.
    # generator.eval()
    # with torch.no_grad():
    #     num_samples = 0
    #     total_ssim = 0
    #     total_psnr = 0
    #     for data in val_loader:
    #         image_haze = data["hazy"].to(device)
    #         image_clear = data["gt"].to(device)
    #
    #         out = generator(image_haze)
    #         out = torch.clamp(out, min=0, max=1)
    #
    #         num_samples += 1
    #         total_psnr += cal_psnr_ssim.cal_batch_psnr(pred=out,
    #                                                    gt=image_clear)
    #         total_ssim += cal_psnr_ssim.cal_batch_ssim(pred=out,
    #                                                    gt=image_clear)
    #
    #         out_cat = torch.cat((image_haze, out, image_clear), dim=3)
    #         save.save_image(image_tensor=out_cat[0],
    #                         out_name=os.path.join(res_dir, "image_temp_test", data["name"][0][:-4] + ".png"))
    #
    #     psnr = total_psnr / num_samples
    #     ssim = total_ssim / num_samples
    #     with open(os.path.join(res_dir, "ssim_psnr/metric.txt"), mode='a') as f:
    #         info = str(epoch) + " " + str(ssim) + " " + str(psnr) + "\n"
    #         f.write(info)
    #
    #     print("GCANet: ||iterations: {}||, ||PSNR {:.4}||, ||SSIM {:.4}||".format(iteration,
    #                                                                              psnr,
    #                                                                              ssim))