sine2pi/Maxfactor

MaxFactor combines proven optimization techniques from several established algorithms, with implementation details tuned specifically for the transformer architectures used in speech recognition.

Parameter-specific learning rate scaling is crucial for transformer-based ASR models. Uniform optimization techniques often struggle with ASR because they apply the same update rules to fundamentally different computational structures.

MaxFactor can be characterized as a memory-efficient adaptive optimizer that combines several techniques (minimal sketches of the matrix update path and the learning-rate scaling follow this list):

  1. Factored Second Moments for Matrices

    • Like Adafactor, it uses row and column statistics rather than storing full matrices
    • Significantly reduces memory requirements compared to Adam-style optimizers
  2. Sign-Based Matrix Updates with Max-Pooling

    • For matrix parameters, takes sign of updates and scales by the maximum value along rows
    • This unique approach bridges sign-based methods and adaptive learning
  3. Dual Normalization Strategy

    • RMS-based clipping controls overall update magnitude
    • Optional infinity norm normalization ensures balanced updates across dimensions
  4. Adaptive Learning Rate Scheduling

    • Incorporates automatic learning rate decay based on step count
    • Parameter-specific scaling based on RMS values
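
To make items 1 and 2 concrete, here is a standalone sketch of the factored second moments and the max-pooled sign update for a single matrix parameter. The names (P, G, row_var, col_var) are illustrative, and the exponential averaging is collapsed to a single step; the full optimizer further down is the actual implementation.

import torch

torch.manual_seed(0)
P = torch.randn(4, 6)    # a matrix parameter
G = torch.randn(4, 6)    # its gradient
eps = 1e-10

# (1) Factored second moments: one statistic per row and one per column
# instead of a full 4x6 matrix of squared gradients (the optimizer keeps
# these as exponential moving averages; a one-step estimate is shown here).
row_var = G.square().mean(dim=-1, keepdim=True)        # shape (4, 1)
col_var = G.square().mean(dim=-2, keepdim=True)        # shape (1, 6)
var_estimate = (row_var @ col_var) / row_var.max().clamp(min=eps)

# Precondition the gradient by the inverse square root of the estimate.
update = G * var_estimate.clamp(min=eps * eps).rsqrt()

# (2) Sign-based update scaled by the per-row maximum magnitude.
max_vals = update.abs().max(dim=-1, keepdim=True)[0]   # one scale per row
P = P - 0.01 * update.sign() * max_vals                # 0.01 stands in for the adaptive step size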
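
Items 3 and 4 reduce to a few scalar quantities per parameter. The sketch below mirrors the corresponding lines in step() and get_lr(), again with illustrative names and values.

import torch

lr, min_lr, eps2, d = 0.025, 1e-7, 1e-4, 1.0
step_count = 100
P = torch.randn(512, 512)          # a parameter
update = torch.randn(512, 512)     # stand-in for its preconditioned update

# (4) Adaptive learning rate: decay with the step count, then scale by the
# parameter's RMS so tensors of different magnitude get comparable relative updates.
rho_t = max(min_lr, min(lr, 1.0 / step_count ** 0.5))
alpha = max(eps2, (P.norm() / P.numel() ** 0.5).item()) * rho_t

# (3) Dual normalization: the infinity norm balances the update across
# dimensions, and an RMS-based clip bounds its overall magnitude.
update = update / update.norm(float("inf")).clamp(min=1e-10)
denom = max(1.0, update.norm(2).item() / (update.numel() ** 0.5 * d))

scaled_update = (alpha / denom) * update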

Technical Strengths

  1. Memory Efficiency

    • O(m + n) storage for the second moments of an m×n weight matrix, rather than O(mn) as in Adam (see the worked example after this list)
    • Especially beneficial for large language models with massive embedding matrices
  2. Numerical Stability

    • Multiple safeguards against exploding/vanishing gradients
    • Bounded beta values prevent extreme adaptation rates
  3. Flexibility

    • Works in both minimization and maximization modes
    • Configurable for different parameter shapes (vectors vs matrices)
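
To put the memory claim in concrete terms, a back-of-the-envelope comparison for a single weight matrix (sizes chosen purely for illustration):

# Second-moment storage for one m x n weight matrix.
m, n = 1024, 4096
full = m * n          # per-element second moment, as in Adam: 4,194,304 entries
factored = m + n      # one row vector plus one column vector: 5,120 entries
print(f"full: {full:,}  factored: {factored:,}  ({full / factored:.0f}x smaller)")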

MaxFactor is essentially a hybrid: it combines the memory efficiency of Adafactor, the adaptivity of Adam, and the robustness of sign-based methods, and adds its own max-pooling scaling for matrix parameters. It is particularly well suited to training large models where memory constraints are significant.



MaxFactor family tree:

Adam
├── Adaptive learning rates 
└── EMA of second moments

Adafactor
├── Factorized second moments
└── Relative step sizing

SignSGD
└── Sign-based updates

LAMB/LARS
├── Layer-wise adaptivity
└── Gradient normalization

AdamW
└── Decoupled weight decay

Adamax
└── Infinity normalization

RMSprop
└── Root mean squared gradient scaling

Gradient Clipping
└── Max norm constraints

MaxFactor
└── Combines all above features and eventually FAM

import torch

class MaxFactor(torch.optim.Optimizer):
    __version__ = "1.0"
    
    def __init__(self, params, lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-4), d=1.0, 
                 weight_decay=0.025, gamma=0.99, max=False, min_lr=1e-7):
        
        print(f"Using MaxFactor optimizer v{self.__version__}")
        
        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay, 
                        gamma=gamma, max=max, min_lr=min_lr)
        super().__init__(params=params, defaults=defaults)


    def get_lr(self):
        """Return current learning rates for all parameter groups."""
        param_specific_lrs = []
        
        for group in self.param_groups:
            group_lrs = []
            min_lr = group.get("min_lr", 1e-7)
            eps1, eps2 = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if "step" not in state:
                    continue
                step_float = state["step"].item()
                rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
                
                param_norm = (p.norm() / (p.numel() ** 0.5 + 1e-12)).item()
                alpha = max(eps2, param_norm) * rho_t
                group_lrs.append(alpha)
            if group_lrs:
                param_specific_lrs.append(sum(group_lrs) / len(group_lrs))
            else:
                param_specific_lrs.append(group["lr"])
        return param_specific_lrs
    
    def get_last_lr(self):
        return self.get_lr()

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]
            min_lr = group.get("min_lr", 1e-7)
            
            for p in group["params"]:
                if p.grad is None:
                    continue
                    
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.dim() > 1:
                        row_shape, col_shape = list(p.shape), list(p.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"] = p.new_zeros(row_shape)
                        state["col_var"] = p.new_zeros(col_shape)
                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                state = self.state[param]

                if group["max"]:
                    grad = -grad
                    
                step_t = state_steps[i]
                row_var, col_var, vi = row_vars[i], col_vars[i], v[i]

                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps
                
                step_t += 1
                step_float = step_t.item()
                
                one_minus_beta2_t = min(0.999, max(0.001, step_float ** group["beta2_decay"]))
                
                rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
                alpha = max(eps2, (param.norm() / (param.numel() ** 0.5 + 1e-12)).item()) * rho_t

                if group["weight_decay"] > 0:
                    param.mul_(1 - group["lr"] * group["weight_decay"])

                if grad.dim() > 1:
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
                    row_mean.div_(grad.size(-1) + eps1)
                    
                    row_var.lerp_(row_mean, one_minus_beta2_t)
                    
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
                    col_mean.div_(grad.size(-2) + eps1)
                    
                    col_var.lerp_(col_mean, one_minus_beta2_t)
                    
                    var_estimate = row_var @ col_var
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]  
                    var_estimate.div_(max_row_var.clamp_(min=eps1))
                else:
                    # non-in-place square so p.grad is not mutated here
                    vi.mul_(group["gamma"]).add_(grad.square(), alpha=1 - group["gamma"])
                    var_estimate = vi

                # non-in-place clamp so the running average stored in state["v"] is not overwritten
                update = var_estimate.clamp(min=eps1 * eps1).rsqrt_().mul_(grad)
                
                inf_norm = torch.norm(update, float('inf'))
                if inf_norm > 0:
                    update.div_(inf_norm.clamp_(min=eps1))
                
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
                
                if param.dim() > 1:
                    max_vals = update.abs().max(dim=-1, keepdim=True)[0]
                    param.add_(-alpha / denom * update.sign() * max_vals)
                else:
                    param.add_(-alpha / denom * update)
                
                state["step"] = step_t
                
        return loss
    

# Optional: expose MaxFactor under the torch.optim namespace for convenience.
torch.optim.MaxFactor = MaxFactor

class LRMonitor:
    """Learning rate monitor that doesn't modify learning rates but reports them."""
    
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.lr_history = []
        self._last_lr = None
    
    def step(self):
        """Collect current learning rates without modifying them."""
        current_lrs = self.optimizer.get_lr()
        self.lr_history.append(current_lrs)
        self._last_lr = current_lrs
        return self
    
    def get_last_lr(self):
        """Return the last learning rates observed."""
        if self._last_lr is None:
            return [group['lr'] for group in self.optimizer.param_groups]
        return self._last_lr
    
    def plot_lr_history(self, figsize=(10, 6)):
        """Plot learning rate history if matplotlib is available."""
        try:
            import matplotlib.pyplot as plt
            plt.figure(figsize=figsize)
            
            for i, group_lrs in enumerate(zip(*self.lr_history)):
                plt.plot(group_lrs, label=f'Group {i}')
            
            plt.xlabel('Steps')
            plt.ylabel('Learning Rate')
            plt.title('MaxFactor Learning Rate History')
            plt.legend()
            plt.yscale('log')
            plt.grid(True, which='both', linestyle='--', alpha=0.5)
            plt.show()
        except ImportError:
            print("Matplotlib not available for plotting")
            print("LR History:", self.lr_history)


torch.optim.LRMonitor = LRMonitor
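
A minimal usage sketch, assuming the MaxFactor and LRMonitor definitions above are in scope; the model, data, and training loop are placeholders rather than part of the repository:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 40))
optimizer = MaxFactor(model.parameters(), lr=0.025, weight_decay=0.025)
monitor = LRMonitor(optimizer)

for _ in range(10):
    x, target = torch.randn(32, 80), torch.randn(32, 40)
    loss = nn.functional.mse_loss(model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    monitor.step()             # record the effective per-group learning rates

print(monitor.get_last_lr())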
