Parameter-specific learning rate scaling is crucial for transformer-based ASR models. Uniform optimization techniques often struggle with ASR because they apply the same update rule to fundamentally different computational structures.
MaxFactor can be characterized as a memory-efficient adaptive optimizer that combines several complementary techniques:

**Factored Second Moments for Matrices**
- Like Adafactor, it tracks row and column statistics instead of a full second-moment matrix
- Significantly reduces memory requirements compared to Adam-style optimizers
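
A minimal sketch of the factorization (the tensor shapes and variable names here are illustrative only, not part of the optimizer's API). Classic Adafactor normalizes the outer product by the row mean; the `step()` implementation below uses the row maximum instead:

```python
import torch

grad = torch.randn(512, 1024)  # gradient of a hypothetical 512 x 1024 weight

# Store only per-row and per-column mean-squared statistics...
row_var = grad.square().mean(dim=-1, keepdim=True)  # shape (512, 1)
col_var = grad.square().mean(dim=-2, keepdim=True)  # shape (1, 1024)

# ...and reconstruct a rank-1 estimate of the full second moment on demand.
var_estimate = (row_var @ col_var) / row_var.mean().clamp(min=1e-30)
# Storage: 512 + 1024 values instead of 512 * 1024.
```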

**Sign-Based Matrix Updates with Max-Pooling**
- For matrix parameters, takes sign of updates and scales by the maximum value along rows
- This unique approach bridges sign-based methods and adaptive learning
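
Stripped of everything else, the matrix update looks roughly like this (shapes illustrative):

```python
import torch

update = torch.randn(4, 8)  # preconditioned update for a matrix parameter

# Sign of each element, scaled by the largest magnitude in its row.
max_vals = update.abs().max(dim=-1, keepdim=True)[0]  # shape (4, 1)
pooled_update = update.sign() * max_vals              # shape (4, 8)
```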

**Dual Normalization Strategy**
- RMS-based clipping controls overall update magnitude
- Optional infinity norm normalization ensures balanced updates across dimensions
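
A standalone sketch of the two normalizations, using the default `eps` and `d` values from the implementation below:

```python
import torch

update = torch.randn(4, 8)
eps1, d = 1e-10, 1.0  # defaults from the MaxFactor constructor below

# Infinity-norm normalization: rescale so the largest component is 1.
inf_norm = torch.norm(update, float("inf"))
if inf_norm > 0:
    update = update / inf_norm.clamp(min=eps1)

# RMS-based clipping: shrink only if the RMS of the update exceeds d.
denom = max(1.0, update.norm(2).item() / (update.numel() ** 0.5 * d))
update = update / denom
```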

**Adaptive Learning Rate Scheduling**
- Incorporates automatic learning rate decay based on step count
- Parameter-specific scaling based on RMS values
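
The schedule reduces to a few lines. The constants are the defaults used in the implementation below; `effective_lr` is a name invented for this sketch:

```python
def effective_lr(step: int, param_rms: float, lr: float = 0.025,
                 min_lr: float = 1e-7, eps2: float = 1e-4) -> float:
    # Global rate decays as 1/sqrt(step), capped at lr, floored at min_lr.
    # step counts from 1.
    rho_t = max(min_lr, min(lr, 1.0 / step ** 0.5))
    # Scaled per parameter by its RMS value (floored at eps2).
    return max(eps2, param_rms) * rho_t
```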

**Memory Efficiency**
- O(n + m) second-moment storage for an n × m matrix, versus O(n·m) for Adam-style optimizers
- Especially beneficial for large language models with massive embedding matrices
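
As a concrete (hypothetical) example, the second-moment storage for a 51,200 × 1,024 embedding matrix:

```python
n_rows, n_cols = 51_200, 1_024
adam_slots = n_rows * n_cols      # 52,428,800 entries for a full second moment
factored_slots = n_rows + n_cols  # 52,224 entries, roughly a 1000x reduction
```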

**Numerical Stability**
- Multiple safeguards against exploding/vanishing gradients
- Bounded beta values prevent extreme adaptation rates
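
The bound on the adaptation rate is a one-liner; at step 1 it evaluates to 0.999 and decays toward the 0.001 floor as training proceeds:

```python
def one_minus_beta2(step: float, beta2_decay: float = -0.8) -> float:
    # Clamped so the second-moment EMA neither freezes (0) nor
    # becomes instantaneous (1), whatever the step count.
    return min(0.999, max(0.001, step ** beta2_decay))
```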

**Flexibility**
- Works in both minimization and maximization modes
- Configurable for different parameter shapes (vectors vs matrices)
MaxFactor is essentially a hybrid: it combines the memory efficiency of Adafactor, the adaptivity of Adam, and the robustness of sign-based methods, and adds its own max-pooling twist for matrix parameters. It is particularly well suited to training large models under tight memory constraints.
MaxFactor family tree:

```
Adam
├── Adaptive learning rates
└── EMA of second moments

Adafactor
├── Factorized second moments
└── Relative step sizing

SignSGD
└── Sign-based updates

LAMB/LARS
├── Layer-wise adaptivity
└── Gradient normalization

AdamW
└── Decoupled weight decay

Adamax
└── Infinity normalization

RMSprop
└── Root-mean-squared gradient scaling

Gradient Clipping
└── Max-norm constraints

MaxFactor
└── Combines all of the above, and eventually FAM
```
```python
import torch


class MaxFactor(torch.optim.Optimizer):
    __version__ = "1.0"

    def __init__(self, params, lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-4), d=1.0,
                 weight_decay=0.025, gamma=0.99, max=False, min_lr=1e-7):
        print(f"Using MaxFactor optimizer v{self.__version__}")
        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d,
                        weight_decay=weight_decay, gamma=gamma, max=max, min_lr=min_lr)
        super().__init__(params=params, defaults=defaults)

    def get_lr(self):
        """Return the current average effective learning rate of each parameter group."""
        param_specific_lrs = []
        for group in self.param_groups:
            group_lrs = []
            min_lr = group.get("min_lr", 1e-7)
            eps1, eps2 = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if "step" not in state:
                    continue
                step_float = state["step"].item()
                # Global rate: decays as 1/sqrt(step), capped at lr, floored at min_lr.
                rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
                # Parameter-specific scale: the parameter's RMS value, floored at eps2.
                param_norm = (p.norm() / (p.numel() ** 0.5 + 1e-12)).item()
                group_lrs.append(max(eps2, param_norm) * rho_t)
            param_specific_lrs.append(
                sum(group_lrs) / len(group_lrs) if group_lrs else group["lr"])
        return param_specific_lrs

    def get_last_lr(self):
        return self.get_lr()

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]
            min_lr = group.get("min_lr", 1e-7)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.dim() > 1:
                        # Matrices: factored second moments (one row and one
                        # column statistic instead of a full matrix).
                        row_shape, col_shape = list(p.shape), list(p.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"] = p.new_zeros(row_shape)
                        state["col_var"] = p.new_zeros(col_shape)
                    else:
                        # Vectors/scalars: a full EMA of squared gradients.
                        state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state.get("v", None))
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                state = self.state[param]
                if group["max"]:
                    grad = -grad  # maximization mode: ascend instead of descend
                step_t = state_steps[i]
                row_var, col_var, vi = row_vars[i], col_vars[i], v[i]
                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps

                step_t += 1
                step_float = step_t.item()
                # Adaptation rate from the step count, clamped to [0.001, 0.999].
                one_minus_beta2_t = min(0.999, max(0.001, step_float ** group["beta2_decay"]))
                # Built-in schedule: lr capped by 1/sqrt(step), floored at min_lr,
                # then scaled by the parameter's RMS value (floored at eps2).
                rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
                alpha = max(eps2, (param.norm() / (param.numel() ** 0.5 + 1e-12)).item()) * rho_t

                if group["weight_decay"] > 0:
                    # Decoupled weight decay, as in AdamW.
                    param.mul_(1 - group["lr"] * group["weight_decay"])

                if grad.dim() > 1:
                    # Factored second-moment update for matrix parameters.
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
                    row_mean.div_(grad.size(-1) + eps1)
                    row_var.lerp_(row_mean, one_minus_beta2_t)
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
                    col_mean.div_(grad.size(-2) + eps1)
                    col_var.lerp_(col_mean, one_minus_beta2_t)
                    var_estimate = row_var @ col_var
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                    var_estimate.div_(max_row_var.clamp_(min=eps1))
                else:
                    # grad.square(), not square_(), so the gradient itself is
                    # left intact for the update below.
                    vi.mul_(group["gamma"]).add_(grad.square(), alpha=1 - group["gamma"])
                    # Clone so the in-place clamp/rsqrt below cannot corrupt state.
                    var_estimate = vi.clone()

                update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
                # Infinity-norm normalization: balance the update across dimensions.
                inf_norm = torch.norm(update, float("inf"))
                if inf_norm > 0:
                    update.div_(inf_norm.clamp_(min=eps1))
                # RMS-based clipping: cap the update's RMS at d.
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))

                if param.dim() > 1:
                    # Max-pooled sign update: each element moves by the largest
                    # magnitude in its row, in the direction of its own sign.
                    max_vals = update.abs().max(dim=-1, keepdim=True)[0]
                    param.add_(-alpha / denom * update.sign() * max_vals)
                else:
                    param.add_(-alpha / denom * update)

                state["step"] = step_t
        return loss


# Expose MaxFactor alongside the built-in optimizers.
torch.optim.MaxFactor = MaxFactor
```
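
A minimal training sketch (the model, data, and loop here are placeholders, not part of MaxFactor itself):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
opt = MaxFactor(model.parameters(), lr=0.025, weight_decay=0.025)  # max=True would maximize instead

for _ in range(100):
    x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
    loss = F.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```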
```python
class LRMonitor:
    """Learning rate monitor that reports learning rates without modifying them."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.lr_history = []
        self._last_lr = None

    def step(self):
        """Record the current learning rates without modifying them."""
        current_lrs = self.optimizer.get_lr()
        self.lr_history.append(current_lrs)
        self._last_lr = current_lrs
        return self

    def get_last_lr(self):
        """Return the last learning rates observed."""
        if self._last_lr is None:
            return [group["lr"] for group in self.optimizer.param_groups]
        return self._last_lr

    def plot_lr_history(self, figsize=(10, 6)):
        """Plot learning rate history if matplotlib is available."""
        try:
            import matplotlib.pyplot as plt
            plt.figure(figsize=figsize)
            # Transpose the history from per-step rows to per-group series.
            for i, group_lrs in enumerate(zip(*self.lr_history)):
                plt.plot(group_lrs, label=f"Group {i}")
            plt.xlabel("Steps")
            plt.ylabel("Learning Rate")
            plt.title("MaxFactor Learning Rate History")
            plt.legend()
            plt.yscale("log")
            plt.grid(True, which="both", linestyle="--", alpha=0.5)
            plt.show()
        except ImportError:
            print("Matplotlib not available for plotting")
            print("LR History:", self.lr_history)


torch.optim.LRMonitor = LRMonitor
```
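
Paired with the training sketch above (reusing its `opt` and `model`), the monitor simply piggybacks on each optimizer step:

```python
monitor = LRMonitor(opt)

for _ in range(100):
    x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
    loss = F.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    monitor.step()  # record, never modify, the current rates

print(monitor.get_last_lr())
monitor.plot_lr_history()
```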