-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrainer.py
72 lines (57 loc) · 2.13 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from network import GAN_Denoiser
class Trainer:
    """
    Trainer creates the model and optimizers, and uses them to
    update the weights of the network while reporting losses
    and the latest visuals to visualize the progress in training.
    """

    def __init__(self, args):
        """Build the GAN_Denoiser model and its two optimizers.

        Args:
            args: Options object. It is stored as ``self.args``, passed to
                ``GAN_Denoiser`` and later read for ``val_freq`` in
                ``normalize_loss``.
        """
        # Save args
        self.args = args
        self.model = GAN_Denoiser(args)
        self.generated = None
        self.loss = None
        # Start with empty per-step loss dicts so get_latest_losses() and
        # append_loss() work before the first generator/discriminator step
        # instead of raising AttributeError.
        self.g_losses = {}
        self.d_losses = {}
        # Create optimizers (generator first, then discriminator).
        self.optimizer_G, self.optimizer_D = self.model.create_optimizers()

    def run_generator_one_step(self, data, warp):
        """Run one optimisation step of the generator.

        The model is called in 'generator' mode. The returned loss terms are
        summed and averaged into a single value, and that value is
        backpropagated. Both the loss dict and the generated output are kept
        on the instance: ``run_discriminator_one_step`` and
        ``get_latest_generated`` read them afterwards.
        """
        self.optimizer_G.zero_grad()
        g_losses, generated = self.model(data, warp, mode='generator')
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.generated = generated

    def run_validation(self, data, warp):
        """Return the model's output in 'inference' mode.

        NOTE(review): whether gradients are disabled here depends on the
        model's 'inference' mode, which is not visible in this file.
        """
        return self.model(data, warp, mode='inference')

    def run_discriminator_one_step(self):
        """Run one optimisation step of the discriminator.

        This uses ``self.generated`` from the most recent
        ``run_generator_one_step`` call, so that method must run first.
        """
        self.optimizer_D.zero_grad()
        d_losses = self.model(self.generated, mode='discriminator')
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_losses(self):
        """Return the latest generator and discriminator losses in one dict.

        If a key appears in both dicts, the discriminator value wins.
        """
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        """Return the output of the most recent generator step (or None)."""
        return self.generated

    def save(self, epoch):
        """Delegate checkpoint saving for ``epoch`` to the model."""
        self.model.save(epoch)

    def get_lr(self, optimizer):
        """Return the learning rate of the optimizer's first param group.

        Returns None if the optimizer has no param groups.
        """
        return next((group['lr'] for group in optimizer.param_groups), None)

    def start_epoch(self):
        """Return the model's ``start_epoch`` attribute."""
        return self.model.start_epoch

    def reset_loss(self):
        """Zero the running loss accumulators for a new reporting interval."""
        self.loss = {'Reconstruction': 0,
                     'GAN': 0,
                     'GAN_Feat': 0,
                     'VGG': 0,
                     'D_Fake': 0,
                     'D_Real': 0}

    def append_loss(self):
        """Add the latest losses (as Python floats) to the running totals.

        If reset_loss() has not been called yet, the totals are initialised
        first. Loss keys that reset_loss() does not pre-declare are started
        at 0 instead of raising KeyError.
        """
        if self.loss is None:
            self.reset_loss()
        for key, value in self.get_latest_losses().items():
            # The step methods call .mean() after summing, which suggests
            # individual entries may be non-scalar tensors (presumably one
            # value per device). Reduce with .mean() before .item(); for
            # scalar tensors this gives the same result as .item() alone.
            self.loss[key] = self.loss.get(key, 0) + value.mean().item()

    def normalize_loss(self):
        """Turn the accumulated loss totals into averages.

        NOTE(review): divides by ``args.val_freq``, i.e. assumes append_loss()
        was called exactly that many times since the last reset_loss().
        Confirm this with the caller.
        """
        for key in self.loss:
            self.loss[key] /= self.args.val_freq