forked from adamian98/pulse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSphericalOptimizer.py
26 lines (22 loc) · 909 Bytes
/
SphericalOptimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import math
import torch
from torch.optim import Optimizer
# Spherical Optimizer Class
# Uses the first two dimensions as batch information
# Optimizes over the surface of a sphere using the initial radius throughout
#
# Example Usage:
# opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)
class SphericalOptimizer(Optimizer):
    """Wraps a torch optimizer and, after every step, projects each parameter
    back onto the sphere of its *initial* radius.

    Dims 0 and 1 of each parameter are treated as batch dimensions; the radius
    is computed over all remaining dims (keepdim, so it broadcasts back).
    """

    def __init__(self, optimizer, params, **kwargs):
        """
        optimizer: an optimizer *class* (e.g. torch.optim.SGD), not an instance.
        params:    iterable of tensors to optimize; reiterated each step, so it
                   must be a concrete sequence (list/tuple), not a generator.
        kwargs:    forwarded to the wrapped optimizer (lr, momentum, ...).

        NOTE(review): the Optimizer base class is never initialized
        (no super().__init__), so only the methods defined/delegated here are
        safe to call on this object.
        """
        self.opt = optimizer(params, **kwargs)
        self.params = params
        with torch.no_grad():
            # Record each parameter's initial radius over its non-batch dims.
            # The 1e-9 epsilon guards against a zero radius (all-zero slice).
            self.radii = {
                param: (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt()
                for param in params
            }

    @torch.no_grad()
    def step(self, closure=None):
        """Take one step with the wrapped optimizer, then renormalize each
        parameter to its stored initial radius. Returns the wrapped loss."""
        loss = self.opt.step(closure)
        for param in self.params:
            # Normalize to unit radius, then rescale to the initial radius.
            # (We are inside @torch.no_grad(), so in-place ops on the
            # parameter itself are safe — no need for the deprecated .data.)
            param.div_((param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt())
            param.mul_(self.radii[param])
        return loss

    def zero_grad(self, set_to_none=True):
        """Delegate to the wrapped optimizer. The base-class zero_grad would
        fail because Optimizer.__init__ was never called (no param_groups)."""
        self.opt.zero_grad(set_to_none=set_to_none)