callbacks.py
import numpy as np
import tensorflow as tf


class PerformanceCallback(tf.keras.callbacks.Callback):
    def __init__(self, same_line: bool = True, early_stop_patience: int = 0):
        """
        Custom performance monitoring callback with early stopping.
        :param same_line: whether to print progress updates on the same line
        :param early_stop_patience: early stopping patience (a non-positive value disables early stopping)
        """
        super().__init__()
        # save printing configuration
        self.same_line = same_line
        # early stopping configuration
        self.patience = early_stop_patience
        # training state (reset in on_train_begin so the instance can be reused)
        self.nan_inf = False
        self.wait = 0
        self.stopped_epoch = 0
        self.best_weights = None
        self.best = None

    def on_train_begin(self, logs=None):
        # reinitialize state so the same callback instance can be reused across fits
        self.nan_inf = False
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.inf
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        # build the progress string from the current training and validation losses
        update_string = 'Train Loss = {:.4f} | '.format(logs['loss'])
        update_string += 'Validation Loss = {:.4f} | '.format(logs['val_loss'])
        # early stopping logic
        if self.patience > 0:
            if tf.less(logs['val_loss'], self.best):
                # new best validation loss: reset the patience counter and snapshot the weights
                self.best = logs['val_loss']
                self.wait = 0
                self.best_weights = self.model.get_weights()
            else:
                # no improvement: increment the patience counter and stop if it is exhausted
                self.wait += 1
                if self.wait >= self.patience:
                    self.stopped_epoch = epoch
                    self.model.stop_training = True
            update_string += 'Best Validation Loss = {:.4f} | '.format(self.best)
            update_string += 'Patience: {:d}/{:d}'.format(self.wait, self.patience)
        # test for NaN and Inf in the training loss and abort training if found
        if tf.math.is_nan(logs['loss']) or tf.math.is_inf(logs['loss']):
            self.nan_inf = True
            self.stopped_epoch = epoch
            self.model.stop_training = True
        # print the progress update
        if self.same_line:
            print('\r' + self.model.name + ' Epoch {:d} | '.format(epoch + 1) + update_string, end='')
        else:
            print(update_string)

    def on_train_end(self, logs=None):
        # finish the single-line printout with a newline
        if self.same_line:
            print('')
        # report NaN/Inf termination
        if self.nan_inf:
            print('Epoch {:d}: NaN or Inf detected!'.format(self.stopped_epoch + 1))
        # restore the best weights if training stopped early after an improvement was recorded
        if self.stopped_epoch > 0 and self.best_weights is not None:
            self.model.set_weights(self.best_weights)
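
A minimal usage sketch (not part of the original file): the model, data, and hyperparameter values below are placeholder assumptions chosen only for illustration. The callback needs validation data so that logs['val_loss'] is populated, and verbose=0 keeps Keras's built-in progress bar from interleaving with the callback's single-line output.

import numpy as np
import tensorflow as tf

from callbacks import PerformanceCallback

# toy regression data (placeholder values)
x = np.random.rand(256, 8).astype(np.float32)
y = np.random.rand(256, 1).astype(np.float32)

# small placeholder model; the model name appears in the callback's printout
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
], name='demo')
model.compile(optimizer='adam', loss='mse')

# validation_split provides the 'val_loss' entry the callback relies on;
# early_stop_patience=5 stops training after 5 epochs without improvement
model.fit(x, y,
          epochs=50,
          validation_split=0.2,
          verbose=0,
          callbacks=[PerformanceCallback(same_line=True, early_stop_patience=5)])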