-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlr_schedule.py
28 lines (20 loc) · 946 Bytes
/
lr_schedule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.
def inv_lr_scheduler(optimizer, iter_num, gamma, power, lr=0.001, weight_decay=0.0005):
    """Apply an inverse polynomial decay schedule to an optimizer in place.

    Every param group's learning rate is set to
    ``lr * (1 + gamma * iter_num) ** (-power)`` scaled by the group's
    ``lr_mult`` entry, and its weight decay to ``weight_decay`` scaled by
    the group's ``decay_mult`` entry.

    Args:
        optimizer: Optimizer whose ``param_groups`` dicts each carry
            ``lr_mult`` and ``decay_mult`` keys (set at group creation).
        iter_num: Current training iteration (0 gives the base ``lr``).
        gamma: Decay rate of the inverse schedule.
        power: Exponent of the inverse schedule.
        lr: Base learning rate before per-group scaling.
        weight_decay: Base weight decay before per-group scaling.

    Returns:
        The same ``optimizer``, with ``lr`` and ``weight_decay`` updated
        in every param group.
    """
    # Inverse decay: equals `lr` at iter_num == 0 and shrinks monotonically.
    lr = lr * (1 + gamma * iter_num) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = weight_decay * param_group['decay_mult']
    return optimizer
def inv_lr_scheduler_mmd(param_lr, optimizer, iter_num, gamma, power, init_lr=0.001):
    """Apply an inverse polynomial decay schedule with external multipliers.

    The i-th param group's learning rate becomes
    ``init_lr * (1 + gamma * iter_num) ** (-power) * param_lr[i]``.

    Args:
        param_lr: Per-group learning-rate multipliers, one per param group.
        optimizer: Optimizer whose ``param_groups`` are updated in place.
        iter_num: Current training iteration (0 gives the base ``init_lr``).
        gamma: Decay rate of the inverse schedule.
        power: Exponent of the inverse schedule.
        init_lr: Base learning rate before per-group scaling.

    Returns:
        The same ``optimizer`` with updated learning rates.
    """
    decayed = init_lr * (1 + gamma * iter_num) ** (-power)
    for idx, group in enumerate(optimizer.param_groups):
        group['lr'] = decayed * param_lr[idx]
    return optimizer
schedule_dict = {"inv": inv_lr_scheduler, "inv_mmd": inv_lr_scheduler_mmd}