import torch
from model.utils import device


class KLScheduler:
    def __init__(
            self,
            kl_warm_steps,
            model,
            current_step=0,
            min_kl_coeff=0):
        self.current_step = current_step
        self.min_kl_coeff = min_kl_coeff
        self.kl_warm_steps = kl_warm_steps
        number_of_splits = model.initial_splits_per_scale
        # Each kl_loss is multiplied by a weight that depends on the scale it
        # belongs to: bigger spatial dimension -> bigger kl_multiplier.
        kl_multiplier = []
        for i in range(model.number_of_scales):
            for k in range(number_of_splits):
                kl_multiplier.append((4 ** (model.number_of_scales - 1 - i)) / number_of_splits)
            # Shrink the split count for the next (smaller) scale, clamped
            # from below at min_splits.
            number_of_splits = max(model.min_splits, number_of_splits // model.exponential_scaling)
        kl_multiplier.reverse()
        self.kl_multiplier = torch.FloatTensor(kl_multiplier).unsqueeze(1)
        self.kl_multiplier = self.kl_multiplier / torch.min(self.kl_multiplier)
        self.kl_multiplier = self.kl_multiplier.to(device)
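        # A worked example with assumed values (number_of_scales=2,
        # initial_splits_per_scale=2, exponential_scaling=2, min_splits=1):
        #   i=0: two splits, each weighted 4**1 / 2 = 2.0
        #   i=1: one split,  weighted      4**0 / 1 = 1.0
        # giving [2.0, 2.0, 1.0] -> reversed -> [1.0, 2.0, 2.0], which is
        # already normalized since the minimum is 1.0.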

    def warm_up_coeff(self):
        # Linear warm-up: the coefficient ramps from min_kl_coeff to 1.0
        # over kl_warm_steps steps.
        if self.kl_warm_steps == 0:
            return 1.0
        return max(
            min(self.current_step / self.kl_warm_steps, 1.0),
            self.min_kl_coeff)
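    # Example with assumed values: for kl_warm_steps=1000 and min_kl_coeff=0.1,
    # warm_up_coeff() returns 0.1 at step 50 (since 50/1000 < 0.1), 0.5 at
    # step 500, and 1.0 from step 1000 onward.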

    def balance(self, kl_losses):
        # During warm-up, the KL terms coming from different scales are
        # multiplied by different constants.
        if self.current_step < self.kl_warm_steps:
            kl_all = torch.stack(kl_losses, dim=0)
            kl_coeff_i = torch.abs(kl_all)
            # Average KL loss for each group across the batch.
            kl_coeff_i = torch.mean(kl_coeff_i, dim=1, keepdim=True) + 0.01
            # TODO: check whether this should divide or multiply
            kl_coeff_i = kl_coeff_i / self.kl_multiplier * torch.sum(kl_coeff_i)
            kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=0, keepdim=True)
            return torch.sum(kl_all * kl_coeff_i.detach(), dim=0)
        else:
            kl_all = torch.stack(kl_losses, dim=0)  # stack the per-split KLs
            kl_all = torch.mean(kl_all, dim=1)      # mean across the batch
            return torch.sum(kl_all)                # sum over all splits

    def step(self):
        self.current_step += 1
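
# A minimal usage sketch, not taken from the original training code: `model`
# is assumed to expose the attributes read in __init__ (number_of_scales,
# initial_splits_per_scale, exponential_scaling, min_splits) and to return a
# list with one per-sample KL tensor of shape [batch] for every split.
# `recon_loss`, `dataloader`, and `optimizer` are placeholder names.
#
# scheduler = KLScheduler(kl_warm_steps=30000, model=model)
# for images in dataloader:
#     reconstruction, kl_losses = model(images)
#     # balance() returns a per-sample tensor during warm-up and a scalar
#     # afterwards; taking the mean reduces both cases to a scalar loss.
#     kl_term = scheduler.warm_up_coeff() * scheduler.balance(kl_losses)
#     loss = torch.mean(recon_loss(reconstruction, images) + kl_term)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     scheduler.step()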