-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_modules.py
284 lines (257 loc) · 13.4 KB
/
train_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import os
from collections import defaultdict
from typing import Any
import jax
import numpy as np
import optax
from flax import linen as nn
from flax.training import checkpoints, train_state
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
class TrainState(train_state.TrainState):
    """Training state that carries batch statistics next to the usual fields.

    Everything else (parameters, optimizer state, ``apply_gradients``) is
    inherited unchanged from ``flax.training.train_state.TrainState``.
    """

    # Batch statistics of the model; stored as an opaque pytree-like value.
    batch_stats: Any
class TrainerModuleBatch:
    """Trainer for classification models that use batch normalization.

    Summarizes all training functionality for classification on CIFAR10:
    model initialization, jitted train/eval steps, optimizer and learning-rate
    schedule setup, TensorBoard logging, and checkpoint saving/loading.

    Inputs:
        model_name - String of the class name, used for logging and saving
        model_class - Class implementing the neural network
        model_hparams - Hyperparameters of the model, used as input to model constructor
        optimizer_name - String of the optimizer name, supporting ['sgd', 'adam', 'adamw']
        optimizer_hparams - Hyperparameters of the optimizer, including learning rate as 'lr'
        exmp_imgs - Example imgs, used as input to initialize the model
        checkpoint_dir - Directory to save checkpoints
        seed - Seed to use in the model initialization
    """

    def __init__(self,
                 model_name: str,
                 model_class: "type[nn.Module]",
                 model_hparams: dict,
                 optimizer_name: str,
                 optimizer_hparams: dict,
                 exmp_imgs: Any,
                 checkpoint_dir: str,
                 seed=42):
        super().__init__()
        self.model_name = model_name
        self.model_class = model_class
        self.model_hparams = model_hparams
        self.optimizer_name = optimizer_name
        self.optimizer_hparams = optimizer_hparams
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = self.model_class(**self.model_hparams)
        # Prepare logging: logs and checkpoints go to <checkpoint_dir>/<model_name>
        self.checkpoint_dir = checkpoint_dir
        self.log_dir = os.path.join(checkpoint_dir, self.model_name)
        self.logger = SummaryWriter(log_dir=self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model(exmp_imgs)

    def create_functions(self):
        """Build the jitted ``train_step`` and ``eval_step`` functions.

        Both are stored on the instance. ``train_step(state, batch)`` returns
        ``(state, loss, acc)``; ``eval_step(state, batch)`` returns the
        accuracy of a single batch.
        """
        def calculate_loss(params, batch_stats, batch, train):
            """Calculate loss and accuracy for a batch."""
            imgs, labels = batch
            # Run model. During training, we need to update the BatchNorm statistics,
            # so 'batch_stats' is marked mutable and apply returns (logits, new_state).
            outs = self.model.apply({'params': params, 'batch_stats': batch_stats},
                                    imgs,
                                    train=train,
                                    mutable=['batch_stats'] if train else False)
            logits, new_model_state = outs if train else (outs, None)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (acc, new_model_state)

        def train_step(state, batch):
            """Run one optimization step and return the updated state, loss and accuracy."""
            loss_fn = lambda params: calculate_loss(params, state.batch_stats, batch, train=True)
            # Get loss, gradients for loss, and other outputs of loss function
            (loss, (acc, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
            # Update parameters and batch statistics
            state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
            return state, loss, acc

        def eval_step(state, batch):
            """Return the accuracy for a single batch."""
            _, (acc, _) = calculate_loss(state.params, state.batch_stats, batch, train=False)
            return acc

        # jit for efficiency
        self.train_step = jax.jit(train_step)
        self.eval_step = jax.jit(eval_step)

    def init_model(self, exmp_imgs):
        """Initialize parameters and batch statistics from an example input.

        The training state itself is created later by ``init_optimizer`` (or
        ``load_model``); until then ``self.state`` is ``None``.
        """
        if "batch_norm" in self.model_hparams:
            assert self.model_hparams["batch_norm"], "Modul for training with batch size, use TrainerModule instead"
        init_rng = jax.random.PRNGKey(self.seed)
        variables = self.model.init(init_rng, exmp_imgs, train=True)
        self.init_params, self.init_batch_stats = variables['params'], variables['batch_stats']
        self.state = None

    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        """Initialize learning rate schedule, optimizer and training state.

        Inputs:
            num_epochs - Number of epochs the schedule is laid out for
            num_steps_per_epoch - Number of optimizer steps in one epoch

        Raises:
            ValueError - If ``optimizer_name`` is not 'sgd', 'adam' or 'adamw'
            KeyError - If ``optimizer_hparams`` has no 'lr' entry
        """
        name = self.optimizer_name.lower()
        if name == 'adam':
            opt_class = optax.adam
        elif name == 'adamw':
            opt_class = optax.adamw
        elif name == 'sgd':
            opt_class = optax.sgd
        else:
            raise ValueError(f'Unknown optimizer "{self.optimizer_name}"')
        # Work on a copy: popping from self.optimizer_hparams would alter the
        # caller's dict and break a second call (e.g. training twice).
        hparams = dict(self.optimizer_hparams)
        # We decrease the learning rate by a factor of 0.1 after 60% and 85% of the training
        total_steps = num_steps_per_epoch * num_epochs
        lr_schedule = optax.piecewise_constant_schedule(
            init_value=hparams.pop('lr'),
            boundaries_and_scales={int(total_steps * 0.6): 0.1,
                                   int(total_steps * 0.85): 0.1}
        )
        # Clip gradients at max value, and evt. apply weight decay
        transf = [optax.clip(1.0)]
        if name == 'sgd' and 'weight_decay' in hparams:  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(hparams.pop('weight_decay')))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **hparams)
        )
        # Initialize training state; reuse current weights if a state already exists
        fresh = self.state is None
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=self.init_params if fresh else self.state.params,
                                       batch_stats=self.init_batch_stats if fresh else self.state.batch_stats,
                                       tx=optimizer)

    def train_model(self, train_loader, val_loader, num_epochs=200):
        """Train model for defined number of epochs.

        The model is evaluated on ``val_loader`` every second epoch, and a
        checkpoint is saved whenever the validation accuracy is at least as
        good as the best one seen so far.
        """
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer(num_epochs, len(train_loader))
        # Track best eval accuracy
        best_eval = 0.0
        for epoch_idx in tqdm(range(1, num_epochs + 1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 2 == 0:
                eval_acc = self.eval_model(val_loader)
                self.logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
                if eval_acc >= best_eval:
                    best_eval = eval_acc
                    self.save_model(step=epoch_idx)
                self.logger.flush()

    def train_epoch(self, train_loader, epoch):
        """Train model for one epoch, and log avg loss and accuracy."""
        metrics = defaultdict(list)
        for batch in tqdm(train_loader, desc='Training', leave=False):
            self.state, loss, acc = self.train_step(self.state, batch)
            metrics['loss'].append(loss)
            metrics['acc'].append(acc)
        for key, values in metrics.items():
            avg_val = np.stack(jax.device_get(values)).mean()
            self.logger.add_scalar('train/' + key, avg_val, global_step=epoch)

    def eval_model(self, data_loader):
        """Test model on all images of a data loader and return avg accuracy.

        Per-batch accuracies are weighted by batch size (first dimension of
        the batch's first element).
        """
        correct_class, count = 0, 0
        for batch in data_loader:
            batch_size = batch[0].shape[0]
            acc = self.eval_step(self.state, batch)
            correct_class += acc * batch_size
            count += batch_size
        eval_acc = (correct_class / count).item()
        return eval_acc

    def save_model(self, step=0):
        """Save current parameters and batch statistics as a checkpoint in ``log_dir``."""
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params,
                                            'batch_stats': self.state.batch_stats},
                                    step=step,
                                    overwrite=True)

    def load_model(self, pretrained=False):
        """Load model. We use different checkpoint for pretrained models.

        With ``pretrained=False`` the checkpoint is read from ``log_dir``;
        otherwise from ``<checkpoint_dir>/<model_name>.ckpt``. The existing
        optimizer is kept if a state is already present.
        """
        if not pretrained:
            ckpt_path = self.log_dir
        else:
            ckpt_path = os.path.join(self.checkpoint_dir, f'{self.model_name}.ckpt')
        state_dict = checkpoints.restore_checkpoint(ckpt_dir=ckpt_path, target=None)
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=state_dict['params'],
                                       batch_stats=state_dict['batch_stats'],
                                       tx=self.state.tx if self.state else optax.sgd(0.1)  # Default optimizer
                                       )

    def checkpoint_exists(self):
        """Return whether a pretrained checkpoint file exists for this model."""
        return os.path.isfile(os.path.join(self.checkpoint_dir, f'{self.model_name}.ckpt'))
class TrainerModule(TrainerModuleBatch):
    """
    Module without batch normalizing for summarizing all training functionalities
    for classification on CIFAR10.

    Overrides the step functions, model/optimizer initialization and
    checkpointing of ``TrainerModuleBatch`` so that only ``params`` are
    tracked (no ``batch_stats`` collection).
    """

    def create_functions(self):
        """Build the jitted ``train_step`` and ``eval_step`` functions.

        ``train_step(state, batch)`` returns ``(state, loss, acc)``;
        ``eval_step(state, batch)`` returns the accuracy of a single batch.
        """
        # Function to calculate the classification loss and accuracy for a model
        def calculate_loss(params, batch, train):
            imgs, labels = batch
            # Run model. There is no mutable state, so apply returns the logits directly.
            logits = self.model.apply({'params': params}, imgs, train=train)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, acc

        # Training function
        def train_step(state, batch):
            loss_fn = lambda params: calculate_loss(params, batch, train=True)
            # Get loss, gradients for loss, and other outputs of loss function
            (loss, acc), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
            # Update parameters
            state = state.apply_gradients(grads=grads)
            return state, loss, acc

        # Eval function
        def eval_step(state, batch):
            # Return the accuracy for a single batch
            _, acc = calculate_loss(state.params, batch, train=False)
            return acc

        # jit for efficiency
        self.train_step = jax.jit(train_step)
        self.eval_step = jax.jit(eval_step)

    def init_model(self, exmp_imgs):
        """Initialize model parameters from an example input.

        A missing 'batch_norm' hyperparameter is treated as "no batch norm",
        mirroring how the parent class tolerates an absent key.
        """
        assert not self.model_hparams.get("batch_norm", False), "Modul for training without batch size, use TrainerModuleBatch instead"
        init_rng = jax.random.PRNGKey(self.seed)
        self.init_params = self.model.init(init_rng, exmp_imgs, train=True)['params']
        self.state = None

    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        """Initialize learning rate schedule, optimizer and training state.

        Inputs:
            num_epochs - Number of epochs the schedule is laid out for
            num_steps_per_epoch - Number of optimizer steps in one epoch

        Raises:
            ValueError - If ``optimizer_name`` is not 'sgd', 'adam' or 'adamw'
            KeyError - If ``optimizer_hparams`` has no 'lr' entry
        """
        name = self.optimizer_name.lower()
        if name == 'adam':
            opt_class = optax.adam
        elif name == 'adamw':
            opt_class = optax.adamw
        elif name == 'sgd':
            opt_class = optax.sgd
        else:
            raise ValueError(f'Unknown optimizer "{self.optimizer_name}"')
        # Work on a copy: popping from self.optimizer_hparams would alter the
        # caller's dict and break a second call (e.g. training twice).
        hparams = dict(self.optimizer_hparams)
        # We decrease the learning rate by a factor of 0.1 after 60% and 85% of the training
        total_steps = num_steps_per_epoch * num_epochs
        lr_schedule = optax.piecewise_constant_schedule(
            init_value=hparams.pop('lr'),
            boundaries_and_scales={int(total_steps * 0.6): 0.1,
                                   int(total_steps * 0.85): 0.1}
        )
        # Clip gradients at max value, and evt. apply weight decay
        transf = [optax.clip(1.0)]
        if name == 'sgd' and 'weight_decay' in hparams:  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(hparams.pop('weight_decay')))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **hparams)
        )
        # Initialize training state; reuse current weights if a state already exists
        self.state = train_state.TrainState.create(apply_fn=self.model.apply,
                                                   params=self.init_params if self.state is None else self.state.params,
                                                   tx=optimizer)

    def save_model(self, step=0):
        """Save current parameters as a checkpoint in ``log_dir``."""
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params},
                                    step=step,
                                    overwrite=True)

    def load_model(self, pretrained=False):
        """Load model. We use different checkpoint for pretrained models.

        With ``pretrained=False`` the checkpoint is read from ``log_dir``;
        otherwise from ``<checkpoint_dir>/<model_name>.ckpt``. The existing
        optimizer is kept if a state is already present.
        """
        if not pretrained:
            ckpt_path = self.log_dir
        else:
            ckpt_path = os.path.join(self.checkpoint_dir, f'{self.model_name}.ckpt')
        state_dict = checkpoints.restore_checkpoint(ckpt_dir=ckpt_path, target=None)
        self.state = train_state.TrainState.create(apply_fn=self.model.apply,
                                                   params=state_dict['params'],
                                                   tx=self.state.tx if self.state else optax.sgd(0.1)  # Default optimizer
                                                   )