adding the object oriented interface
coursekevin committed Nov 6, 2023
1 parent e165da8 commit 0712183
Showing 7 changed files with 194 additions and 16 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2019 coursekevin
Copyright (c) 2023 Kevin Course

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
48 changes: 38 additions & 10 deletions README.md
@@ -12,14 +12,14 @@
## What is GradOpTorch?

GradOpTorch is a suite of classical gradient-based optimization tools for
PyTorch. The toolkit includes conjugate gradients, BFGS, and some
methods for line-search.

## Why not [torch.optim](https://pytorch.org/docs/stable/optim.html)?

Not every problem is high-dimensional with noisy gradients.
For such problems, classical optimization techniques
can be more appropriate.
Not every problem is high-dimensional and nonlinear with noisy gradients.
For such problems, classical optimization techniques
can be more efficient.

## Installation

@@ -31,6 +31,40 @@ pip install gradoptorch

## Usage

There are two primary interfaces for using the library:

1. The standard PyTorch object-oriented interface:

```python
from gradoptorch import optimize_module
from torch import nn

class MyModel(nn.Module):
...

model = MyModel()

def loss_fn(model):
...

hist = optimize_module(model, loss_fn, opt_method="bfgs", ls_method="back_tracking")
```

2. The functional interface:

```python
from gradoptorch import optimizer

def f(x):
...

x_guess = ...

x_opt, hist = optimizer(f, x_guess, opt_method="conj_grad_pr", ls_method="quad_search")
```

Newton's method is only available in the functional interface.
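
For example, a Newton solve through the functional interface might look like the sketch below. This is a minimal illustration only: the method key `newton_exact` is taken from the check in `gradoptorch/module.py`, and the objective and starting point are made up for the example.

```python
import torch

from gradoptorch import optimizer

def f(x):
    # simple quadratic bowl, used only for illustration
    return (x ** 2).sum()

x_guess = torch.ones(3)

# "newton_exact" is assumed to be the Newton method key accepted by `optimizer`
x_opt, hist = optimizer(f, x_guess, opt_method="newton_exact", ls_method="back_tracking")
```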

### Included optimizers:

'grad_exact' : exact gradient optimization
@@ -44,9 +78,3 @@ pip install gradoptorch
'back_tracking' : backtracking based line-search
'quad_search' : quadratic line-search
'constant' : no line search, constant step size used
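
Presumably any optimizer from the first list can be paired with any line-search from the second. The package also exports `default_opt_settings` and `default_ls_settings`, which can be copied and passed explicitly; the sketch below assumes `f` and `x_guess` are defined as in the functional example above, and the keys inside the settings dictionaries are not shown here.

```python
from gradoptorch import optimizer, default_opt_settings, default_ls_settings

opt_params = dict(default_opt_settings)  # copy before tweaking any entries
ls_params = dict(default_ls_settings)

x_opt, hist = optimizer(
    f,
    x_guess,
    opt_method="bfgs",
    opt_params=opt_params,
    ls_method="constant",
    ls_params=ls_params,
)
```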

## Setup

```bash
pip install git+https://github.com/coursekevin/gradoptorch.git
```
7 changes: 3 additions & 4 deletions examples/functional_example.py
@@ -27,10 +27,6 @@ def main():
x_opt, hist = optimizer(f, x, opt_method=method, ls_method="quad_search")
histories.append(hist)

X, Y = torch.meshgrid(
torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
)

# ---------------------------------------------------------------------------------
# making some plots
_, axs = plt.subplots(1, 2)
@@ -44,6 +40,9 @@ def main():
ax2.plot(x_hist[:, 0], x_hist[:, 1], "x-")
ax1.legend(opt_method)

X, Y = torch.meshgrid(
torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
)
ax2.contourf(X, Y, f(torch.stack([X, Y], dim=0)), 50)

plt.show()
70 changes: 70 additions & 0 deletions examples/module_example.py
@@ -0,0 +1,70 @@
import torch
from torch import nn
from gradoptorch import optimize_module

import matplotlib.pyplot as plt # type: ignore

dim = 2

torch.set_default_dtype(torch.float64)
torch.manual_seed(42)


class SomeModule(nn.Module):
def __init__(self):
super().__init__()
self.x = nn.Parameter(torch.tensor(-0.5))
self.y = nn.Parameter(torch.tensor(3.0))


def loss_fn(model):
a = 1.0
b = 100.0
return (a - model.x).pow(2) + b * (model.y - model.x.pow(2)).pow(2)


def main():
# ---------------------------------------------------------------------------------
# optimizing using all of the different methods
opt_method = ["conj_grad_pr", "conj_grad_fr", "grad_exact", "bfgs"]
histories = []
for method in opt_method:
# using quadratic line search
model = SomeModule()
hist = optimize_module(
model, loss_fn, opt_method=method, ls_method="quad_search"
)
for n, p in model.named_parameters():
print(n, p)
histories.append(hist)

# ---------------------------------------------------------------------------------
# making some plots
_, axs = plt.subplots(1, 2)

ax1 = axs[0]
ax2 = axs[1]

for hist in histories:
ax1.plot(torch.tensor(hist.f_hist).log10().detach())
x_hist = torch.stack(hist.x_hist, dim=0).detach()
ax2.plot(x_hist[:, 0], x_hist[:, 1], "x-")
ax1.legend(opt_method)

X, Y = torch.meshgrid(
torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
)

# Define test objective function
def f(x):
a = 1.0
b = 100.0
return (a - x[0]).pow(2) + b * (x[1] - x[0].pow(2)).pow(2)

ax2.contourf(X, Y, f(torch.stack([X, Y], dim=0)), 50)

plt.show()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions gradoptorch/__init__.py
@@ -1 +1,2 @@
from .gradoptorch import optimizer, default_opt_settings, default_ls_settings
from .module import optimize_module
2 changes: 1 addition & 1 deletion gradoptorch/gradoptorch.py
@@ -48,7 +48,7 @@ def grad_fn(x: Float[Tensor, " d"]) -> Float[Tensor, "d 1"]:
def optimizer(
f: Callable[[Float[Tensor, " d"]], Float[Tensor, ""]],
x_guess: Float[Tensor, " d"],
g: Optional[Callable[[Float[Tensor, " d"]], Float[Tensor, ""]]] = None,
g: Optional[Callable[[Float[Tensor, " d"]], Float[Tensor, " d"]]] = None,
opt_method: str = "conj_grad_pr",
opt_params: dict[str, Any] = default_opt_settings,
ls_method: str = "back_tracking",
80 changes: 80 additions & 0 deletions gradoptorch/module.py
@@ -0,0 +1,80 @@
from typing import Callable, Any

import torch
from torch import nn, Tensor
from jaxtyping import Float

from .gradoptorch import optimizer, OptimLog, default_opt_settings, default_ls_settings


def update_model_params(model: nn.Module, new_params: Tensor) -> None:
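    # Write the flat parameter vector back into the module's parameters in-place,
    # slicing out each parameter's block and restoring its original shape.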
pointer = 0
for param in model.parameters():
num_param = param.numel()
param.data = new_params[pointer : pointer + num_param].view_as(param).data
pointer += num_param
param.requires_grad = True


def optimize_module(
model: nn.Module,
f: Callable[[nn.Module], Float[Tensor, ""]],
opt_method: str = "conj_grad_pr",
opt_params: dict[str, Any] = default_opt_settings,
ls_method: str = "back_tracking",
ls_params: dict[str, Any] = default_ls_settings,
) -> OptimLog:
"""
    Optimizes the parameters of a given nn.Module using a classical optimizer.
See the `optimizer` function for more details on the optimization methods
available.
INPUTS:
model < nn.Module > : The torch.nn model to be optimized.
f < Callable[[nn.Module], Float[Tensor, ""]] > : The loss function to be minimized
opt_method < str > : The optimization method to be used.
opt_params < dict[str, Any] > : The parameters to be used for the optimization method.
        ls_method < str > : The line search method to be used.
        ls_params < dict[str, Any] > : The parameters to be used for the line search method.
    OUTPUTS:
        OptimLog : the log of the optimization process.
"""
if opt_method == "newton_exact":
raise NotImplementedError(
"Exact Newton's method is not implemented for optimize_module."
)

    # Flatten the model parameters to use as the initial guess
params = torch.cat([param.view(-1) for param in model.parameters()])

def f_wrapper(params: Float[Tensor, " d"]) -> Float[Tensor, ""]:
update_model_params(model, params)
return f(model)

def grad_wrapper(params: Float[Tensor, " d"]) -> Float[Tensor, " d"]:
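        # Load the candidate parameters into the model, recompute the loss with
        # autograd enabled, and return the flattened gradient (zeros are substituted
        # for parameters that did not receive a gradient).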
update_model_params(model, params)
model.zero_grad()
with torch.enable_grad():
loss = f(model)
loss.backward()
return torch.cat(
[
param.grad.view(-1)
if param.grad is not None
else torch.zeros_like(param).view(-1)
for param in model.parameters()
]
)

final_params, hist = optimizer(
f=f_wrapper,
x_guess=params,
g=grad_wrapper,
opt_method=opt_method,
opt_params=opt_params,
ls_method=ls_method,
ls_params=ls_params,
)
update_model_params(model, final_params)
return hist
