Lower Precision CG Solver #1

Status: Open · wants to merge 2 commits into base: main
14 changes: 13 additions & 1 deletion linear_operator/operators/_linear_operator.py
@@ -729,14 +729,26 @@ def _solve(self, rhs: torch.Tensor, preconditioner: Callable, num_tridiag: int =
         r"""
         TODO
         """
-        return utils.linear_cg(
+        if rhs.dtype == torch.float16:
+            output = utils.linear_log_re_cg(
+                matmul_closure=self._matmul,
+                rhs=rhs,
+                n_tridiag=num_tridiag,
+                max_iter=settings.max_cg_iterations.value(),
+                max_tridiag_iter=settings.max_lanczos_quadrature_iterations.value(),
+                preconditioner=preconditioner,
+            )
+        else:
+            output = utils.linear_cg(
             self._matmul,
             rhs,
             n_tridiag=num_tridiag,
             max_iter=settings.max_cg_iterations.value(),
             max_tridiag_iter=settings.max_lanczos_quadrature_iterations.value(),
             preconditioner=preconditioner,
         )
+        return output


     def _solve_preconditioner(self) -> Callable:
         r"""
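For context: float16 right-hand sides are routed to the log-space solver (this assumes linear_log_re_cg is re-exported from linear_operator.utils, which the diff does not show) because the inner products CG accumulates, such as r^T z and p^T A p, underflow quickly in half precision. A standalone illustration of the failure mode, with numbers chosen for the sketch:

import torch

# 1e-4 is representable in float16, but its square (1e-8) sits below the
# smallest float16 subnormal (~6e-8), so every elementwise product vanishes.
r = torch.full((1000, 1), 1e-4, dtype=torch.float16)
print((r * r).sum())  # tensor(0., dtype=torch.float16)

# The same quantity survives if the logs are taken in float16 and only the
# reduction is done in higher precision: log(r^T r) = logsumexp(2 * log|r|).
log_r = torch.log(r.abs())  # ~ -9.21, comfortably inside float16 range
print((2.0 * log_r).double().logsumexp(dim=0).exp())  # ~1e-5, the true r^T r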
124 changes: 124 additions & 0 deletions linear_operator/utils/linear_log_re_cg.py
@@ -0,0 +1,124 @@
import torch
from .. import settings


def _default_preconditioner(x):
    return x.clone()


def linear_log_re_cg(
    matmul_closure,
    rhs,
    max_iter,
    initial_guess=None,
    preconditioner=None,
    tolerance=None,
    eps=1e-10,
    stop_updating_after=1e-10,
    max_tridiag_iter=0,
    n_tridiag=0,
):
    # `stop_updating_after` and `max_tridiag_iter` are accepted for signature
    # compatibility with utils.linear_cg; this solver does not use them.
    if preconditioner is None:
        preconditioner = _default_preconditioner
    if tolerance is None:
        tolerance = settings.cg_tolerance.value()
    if initial_guess is None:
        initial_guess = torch.zeros_like(rhs)
    x0 = initial_guess

    # Normalize the right-hand side so residual norms are scale-free.
    # All-zero columns keep a norm of 1 to avoid dividing by zero.
    rhs_norm = rhs.norm(2, dim=-2, keepdim=True)
    rhs_is_zero = rhs_norm.lt(eps)
    rhs_norm = rhs_norm.masked_fill_(rhs_is_zero, 1)
    rhs = rhs.div(rhs_norm)

    state = initialize_log_re(matmul_closure, rhs, preconditioner, x0, max_iter)
    for k in range(max_iter):
        state = take_cg_step_log_re(state, matmul_closure, preconditioner)
        if cond_fun(state, tolerance, max_iter):
            break

    # Undo the normalization on the solution.
    x0 = state[0]
    x0 = x0.mul(rhs_norm)
    if n_tridiag > 0:
        # Unlike utils.linear_cg, no Lanczos tridiagonal matrices are built;
        # return a placeholder in their slot.
        return x0, None
    else:
        return x0
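For reference, a minimal usage sketch of the solver on its own, mirroring the new unit tests (the explicit tolerance is an assumption here, not something the PR sets):

import torch

from linear_operator.utils.linear_log_re_cg import linear_log_re_cg

size = 20
matrix = torch.randn(size, size, dtype=torch.float64)
matrix = matrix @ matrix.mT + 0.1 * torch.eye(size, dtype=torch.float64)  # SPD
rhs = torch.randn(size, 3, dtype=torch.float64)

solves = linear_log_re_cg(matrix.matmul, rhs=rhs, max_iter=size, tolerance=1e-8)
print((matrix @ solves - rhs).norm())  # small once CG has converged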


def initialize_log_re(A, b, preconditioner, x0, max_iters):
    # Standard preconditioned-CG initialization, with gamma = r^T z kept in
    # log space.  u_all stores normalized residuals for reorthogonalization.
    r0 = b - A(x0)
    z0 = preconditioner(r0)
    p0 = z0
    log_gamma0 = update_log_gamma_unclipped(r=r0, z=z0)
    u_all = torch.zeros(size=(max_iters,) + b.shape, dtype=x0.dtype, device=x0.device)
    return (x0, r0, log_gamma0, p0, u_all, torch.tensor(0, dtype=torch.int32))


def take_cg_step_log_re(state, A, preconditioner):
    x0, r0, log_gamma0, p0, u_all, k = state
    r_norm = torch.linalg.norm(r0, dim=-2, keepdim=True)
    # Freeze columns that have already converged: their alpha/beta are zeroed.
    has_converged = r_norm < torch.tensor(1.0e-6, dtype=p0.dtype)
    Ap0 = A(p0)

    # alpha = gamma / (p^T A p), with the denominator formed in log space.
    alpha = update_alpha_log_unclipped(log_gamma0, p0, Ap0, has_converged)
    x1 = x0 + alpha * p0
    r1 = r0 - alpha * Ap0
    # Gram-Schmidt sweep: reorthogonalize the new residual against the stored
    # normalized residuals (all but the most recent one).
    for i in range(k - 1):
        proj = torch.sum(r1 * u_all[i], dim=-2, keepdim=True) * u_all[i]
        r1 = torch.where(has_converged, r1, r1 - proj)
    z1 = preconditioner(r1)
    log_gamma1, beta = update_log_gamma_beta_unclipped(
        r1, z1, log_gamma0, has_converged
    )
    # Store r1 / sqrt(gamma1); exp(0.5 * log_gamma1) avoids overflowing
    # exp(log_gamma1) before taking the square root.
    u_all[k] = r1 / torch.exp(0.5 * log_gamma1)
    p1 = z1 + beta * p0

    return (x1, r1, log_gamma1, p1, u_all, k + 1)
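The loop over u_all above is a Gram-Schmidt sweep: with the default identity preconditioner, gamma = r^T r, so each stored vector is a unit residual, and subtracting the projection removes the new residual's component along it. The same operation in isolation:

import torch

u = torch.randn(100, 1, dtype=torch.float64)
u = u / u.norm()  # a stored, normalized residual direction
r = torch.randn(100, 1, dtype=torch.float64)

proj = torch.sum(r * u, dim=-2, keepdim=True) * u  # component of r along u
r_orth = r - proj
print(torch.sum(r_orth * u).abs())  # ~0: r_orth is orthogonal to u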


def update_alpha_log_unclipped(log_gamma, p, Ap, has_converged):
    # alpha = exp(log(gamma) - log(p^T A p)); the denominator is a signed
    # log-sum-exp over the elementwise terms log|p_i| + log|(Ap)_i|.
    log_alpha_abs, sign = compute_robust_denom_unclipped(p, Ap)
    log_denom = logsumexp(tensor=log_alpha_abs, dim=-2, mask=sign)
    alpha = torch.exp(log_gamma - log_denom)
    alpha = torch.where(has_converged, torch.zeros_like(alpha), alpha)
    return alpha


def compute_robust_denom_unclipped(p, Ap):
    # Clip magnitudes away from zero so the logs below stay finite; the signs
    # are carried separately for the signed log-sum-exp.
    p_abs = torch.clip(torch.abs(p), min=1.0e-8)
    Ap_abs = torch.clip(torch.abs(Ap), min=1.0e-8)
    sign = torch.sign(p) * torch.sign(Ap)
    log_alpha_abs = torch.log(p_abs) + torch.log(Ap_abs)
    return log_alpha_abs, sign
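Together these two helpers recover an ordinary inner product in log space: up to the 1e-8 magnitude clipping, the signed log-sum-exp of log|p_i| + log|(Ap)_i| equals log(p^T A p). A quick standalone check, using an SPD matrix so that the inner product is positive:

import torch

from linear_operator.utils.linear_log_re_cg import (
    compute_robust_denom_unclipped,
    logsumexp,
)

A = torch.randn(50, 50, dtype=torch.float64)
A = A @ A.mT + torch.eye(50, dtype=torch.float64)  # SPD, so p^T A p > 0
p = torch.randn(50, 1, dtype=torch.float64)
Ap = A @ p

log_abs, sign = compute_robust_denom_unclipped(p, Ap)
log_denom = logsumexp(tensor=log_abs, dim=-2, mask=sign)
print(torch.exp(log_denom), (p * Ap).sum(dim=-2, keepdim=True))  # agree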


def update_log_gamma_beta_unclipped(r, z, log_gamma0, has_converged):
    # beta = gamma_1 / gamma_0, formed as a difference of logs.
    log_gamma1 = update_log_gamma_unclipped(r, z)
    beta = torch.exp(log_gamma1 - log_gamma0)
    beta = torch.where(has_converged, torch.zeros_like(beta), beta)
    return log_gamma1, beta


def update_log_gamma_unclipped(r, z):
    # log(gamma) = log(r^T z) via a signed log-sum-exp over the elementwise
    # terms log|r_i| + log|z_i|, with signs sign(r_i) * sign(z_i).
    r_abs = torch.abs(r)
    z_abs = torch.abs(z)
    sign = torch.sign(r) * torch.sign(z)
    log_gamma_abs = torch.log(r_abs) + torch.log(z_abs)
    log_gamma = logsumexp(tensor=log_gamma_abs, dim=-2, mask=sign)
    return log_gamma


def cond_fun(state, tolerance, max_iters):
    # Stop once the mean residual norm drops below tolerance, but only after
    # at least min(10, max_iters) iterations; always stop beyond max_iters.
    _, r, *_, k = state
    rs = torch.linalg.norm(r, dim=-2)
    res_meet = torch.mean(rs) < tolerance
    min_val = torch.minimum(
        torch.tensor(10, dtype=torch.int32), torch.tensor(max_iters, dtype=torch.int32)
    )
    flag = (res_meet & (k >= min_val)) | (k > max_iters)
    return flag


def logsumexp(tensor, dim=-1, mask=None):
    # Signed log-sum-exp: returns log(sum(mask * exp(tensor))), computed
    # stably by factoring out the max entry.  `mask` carries the term signs.
    max_entry = torch.max(tensor, dim, keepdim=True)[0]
    summ = torch.sum((tensor - max_entry).exp() * mask, dim, keepdim=True)
    return max_entry + summ.log()
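One caveat worth keeping in mind: because the mask carries signs, a signed sum that is zero or negative yields -inf or NaN. The solver relies on gamma and p^T A p staying positive, which holds for SPD systems in exact arithmetic. A two-line demonstration:

import torch

from linear_operator.utils.linear_log_re_cg import logsumexp

x = torch.zeros(2, 1)
print(logsumexp(tensor=x, dim=-2, mask=torch.tensor([[1.0], [-1.0]])))  # -inf
print(logsumexp(tensor=x, dim=-2, mask=torch.tensor([[-1.0], [-1.0]])))  # nan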
58 changes: 58 additions & 0 deletions test/utils/test_linear_log_re_cg.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

import os
import random
import unittest

import torch

from linear_operator.utils.linear_log_re_cg import linear_log_re_cg as linear_cg


class TestLinearCG(unittest.TestCase):
    def setUp(self):
        if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
            self.rng_state = torch.get_rng_state()
            torch.manual_seed(0)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(0)
            random.seed(0)

    def tearDown(self):
        if hasattr(self, "rng_state"):
            torch.set_rng_state(self.rng_state)

    def test_cg(self):
        size = 100
        matrix = torch.randn(size, size, dtype=torch.float64)
        matrix = matrix.matmul(matrix.mT)
        matrix.div_(matrix.norm())
        matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))

        rhs = torch.randn(size, 50, dtype=torch.float64)
        solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size)

        # Check cg
        matrix_chol = torch.linalg.cholesky(matrix)
        actual = torch.cholesky_solve(rhs, matrix_chol)
        self.assertTrue(torch.allclose(solves, actual, atol=1e-3, rtol=1e-4))

    def test_batch_cg(self):
        batch = 5
        size = 100
        matrix = torch.randn(batch, size, size, dtype=torch.float64)
        matrix = matrix.matmul(matrix.mT)
        matrix.div_(matrix.norm())
        matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))

        rhs = torch.randn(batch, size, 50, dtype=torch.float64)
        solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size)

        # Check cg
        matrix_chol = torch.linalg.cholesky(matrix)
        actual = torch.cholesky_solve(rhs, matrix_chol)
        self.assertTrue(torch.allclose(solves, actual, atol=1e-3, rtol=1e-4))


if __name__ == "__main__":
    unittest.main()
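Both tests exercise the solver in float64, where it can be validated against a Cholesky solve to tight tolerances; the suite runs under unittest directly (python test/utils/test_linear_log_re_cg.py) or via pytest.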