
Bugfixes and address comments by @theorashid on PR pymc-devs#385
Check for `jax` installation before any computation if `gradient_backend = 'jax'`
jessegrabowski committed Feb 7, 2025
1 parent ea8a926 commit ec5dbef
Showing 4 changed files with 114 additions and 55 deletions.
52 changes: 36 additions & 16 deletions pymc_extras/inference/find_map.py
@@ -1,9 +1,9 @@
import logging

from collections.abc import Callable
from importlib.util import find_spec
from typing import Literal, cast, get_args

import jax
import numpy as np
import pymc as pm
import pytensor
@@ -30,13 +30,29 @@
def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
method_info = MINIMIZE_MODE_KWARGS[method].copy()

use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

if use_hess and use_hessp:
_log.warning(
'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
'Setting "use_hess" to False.'
)
use_hess = False

use_grad = use_grad if use_grad is not None else method_info["uses_grad"]

if use_hessp is not None and use_hess is None:
use_hess = not use_hessp

elif use_hess is not None and use_hessp is None:
use_hessp = not use_hess

elif use_hessp is None and use_hess is None:
use_hessp = method_info["uses_hessp"]
use_hess = method_info["uses_hess"]
if use_hessp and use_hess:
# If a method could use either hess or hessp, we default to using hessp
use_hess = False

return use_grad, use_hess, use_hessp
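
As a sanity check on the resolution logic above, a minimal standalone sketch (nothing is imported from pymc-extras; the method_info dict is a made-up stand-in for a MINIMIZE_MODE_KWARGS entry describing a method that accepts gradients, Hessians, and Hessian-vector products):

# Hypothetical stand-in for one MINIMIZE_MODE_KWARGS entry.
method_info = {"uses_grad": True, "uses_hess": True, "uses_hessp": True}

def resolve(use_grad=None, use_hess=None, use_hessp=None):
    # Mirrors the defaulting rules introduced in this commit.
    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp = method_info["uses_hessp"]
        use_hess = method_info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False  # prefer hessp when a method could use either
    return use_grad, use_hess, use_hessp

print(resolve(use_hessp=True))  # (True, False, True)
print(resolve(use_hess=True))   # (True, True, False)
print(resolve())                # (True, False, True): hessp wins by default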


@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
The nearest positive semi-definite matrix to the input matrix.
"""
C = (A + A.T) / 2
eigval, eigvec = np.linalg.eig(C)
eigval, eigvec = np.linalg.eigh(C)
eigval[eigval < 0] = 0

return eigvec @ np.diag(eigval) @ eigvec.T
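
A small NumPy-only sketch (not part of the commit) of the projection above: symmetrize, clip negative eigenvalues to zero, and reassemble. The matrix A is a made-up indefinite example.

import numpy as np

A = np.array([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues 3 and -1, so not PSD
C = (A + A.T) / 2
eigval, eigvec = np.linalg.eigh(C)      # eigh: real spectrum for symmetric input
eigval[eigval < 0] = 0
A_psd = eigvec @ np.diag(eigval) @ eigvec.T
print(np.linalg.eigvalsh(A_psd))        # [0. 3.], positive semi-definite now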
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
return f_untransform(posterior_draws)


def _compile_jax_gradients(
def _compile_grad_and_hess_to_jax(
f_loss: Function, use_hess: bool, use_hessp: bool
) -> tuple[Callable | None, Callable | None]:
"""
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
f_hessp: Callable | None
The compiled hessian-vector product function, or None if use_hessp is False.
"""
import jax

f_hess = None
f_hessp = None

@@ -152,7 +170,7 @@ def f_hess_jax(x):
return f_loss_and_grad, f_hess, f_hessp


def _compile_functions(
def _compile_functions_for_scipy_optimize(
loss: TensorVariable,
inputs: list[TensorVariable],
compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
compute_hessp: bool
Whether to compile a function that computes the Hessian-vector product of the loss function.
compile_kwargs: dict, optional
Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
Additional keyword arguments to pass to the ``pm.compile`` function.
Returns
-------
@@ -193,19 +211,19 @@ def _compile_functions(
if compute_grad:
grads = pytensor.gradient.grad(loss, inputs)
grad = pt.concatenate([grad.ravel() for grad in grads])
f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
else:
f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
f_loss = pm.compile(inputs, loss, **compile_kwargs)
return [f_loss]

if compute_hess:
hess = pytensor.gradient.jacobian(grad, inputs)[0]
f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
f_hess = pm.compile(inputs, hess, **compile_kwargs)

if compute_hessp:
p = pt.tensor("p", shape=inputs[0].type.shape)
hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

return [f_loss_and_grad, f_hess, f_hessp]

@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
gradient_backend: str, default "pytensor"
Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
compile_kwargs:
Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
Additional keyword arguments to pass to the ``pm.compile`` function.
Returns
-------
@@ -265,6 +283,8 @@ def scipy_optimize_funcs_from_loss(
)

use_jax_gradients = (gradient_backend == "jax") and use_grad
if use_jax_gradients and not find_spec("jax"):
raise ImportError("JAX must be installed to use JAX gradients")

mode = compile_kwargs.get("mode", None)
if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@
compute_hess = use_hess and not use_jax_gradients
compute_hessp = use_hessp and not use_jax_gradients

funcs = _compile_functions(
funcs = _compile_functions_for_scipy_optimize(
loss=loss,
inputs=[flat_input],
compute_grad=compute_grad,
@@ -301,7 +321,7 @@

if use_jax_gradients:
# f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

return f_loss, f_hess, f_hessp
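
For context, a rough standalone sketch of the pattern that _compile_grad_and_hess_to_jax and scipy_optimize_funcs_from_loss follow (this is not the pymc-extras implementation, and the quadratic loss is made up): wrap a scalar loss with jax.value_and_grad, build a Hessian-vector product, and hand both to scipy.optimize.minimize.

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

def loss(x):
    # Made-up scalar loss standing in for the compiled negative log-posterior.
    return jnp.sum((x - 3.0) ** 2)

f_loss_and_grad = jax.jit(jax.value_and_grad(loss))

def f_loss_np(x):
    # scipy expects (value, gradient) when jac=True.
    value, grad = f_loss_and_grad(x)
    return float(value), np.asarray(grad)

def f_hessp(x, p):
    # Hessian-vector product via a jvp of the gradient; no full Hessian is built.
    x = jnp.asarray(x)
    p = jnp.asarray(p, dtype=x.dtype)
    return np.asarray(jax.jvp(jax.grad(loss), (x,), (p,))[1])

res = minimize(f_loss_np, x0=np.zeros(3), jac=True, hessp=f_hessp, method="trust-ncg")
print(res.x)  # approximately [3. 3. 3.]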

27 changes: 17 additions & 10 deletions pymc_extras/inference/laplace.py
@@ -16,6 +16,7 @@
import logging

from functools import reduce
from importlib.util import find_spec
from itertools import product
from typing import Literal

@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
return idata


def fit_mvn_to_MAP(
def fit_mvn_at_MAP(
optimized_point: dict[str, np.ndarray],
model: pm.Model | None = None,
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
inverse_hessian: np.ndarray
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
"""
if gradient_backend == "jax" and not find_spec("jax"):
raise ImportError("JAX must be installed to use JAX gradients")

model = pm.modelcontext(model)
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(
Parameters
----------
mu
H_inv
mu: RaveledVars
The MAP estimate of the model parameters.
H_inv: np.ndarray
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
model : Model
A PyMC model
chains : int
@@ -384,9 +390,7 @@
constrained_rvs, replace={unconstrained_vector: batched_values}
)

f_constrain = pm.compile_pymc(
inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
)
f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
posterior_draws = f_constrain(posterior_draws)

else:
@@ -472,15 +476,17 @@ def fit_laplace(
and 1).
.. warning::
This argumnet should be considered highly experimental. It has not been verified if this method produces
This argument should be considered highly experimental. It has not been verified if this method produces
valid draws from the posterior. **Use at your own risk**.
gradient_backend: str, default "pytensor"
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
chains: int, default: 2
The number of sampling chains running in parallel.
The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
compatible with the ArviZ library.
draws: int, default: 500
The number of samples to draw from the approximated posterior.
The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
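
A usage sketch of the chains/draws semantics documented in this hunk (the toy model and the pymc_extras import alias are assumptions, not part of this commit):

import numpy as np
import pymc as pm
import pymc_extras as pmx  # assumed import path exposing fit_laplace

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.random.default_rng(0).normal(size=100))
    idata = pmx.fit_laplace(chains=2, draws=500, gradient_backend="pytensor")

print(idata.posterior["mu"].shape)  # expected (2, 500): 2 * 500 = 1000 total samples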
@@ -547,11 +553,12 @@ def fit_laplace(
**optimizer_kwargs,
)

mu, H_inv = fit_mvn_to_MAP(
mu, H_inv = fit_mvn_at_MAP(
optimized_point=optimized_point,
model=model,
on_bad_cov=on_bad_cov,
transform_samples=fit_in_unconstrained_space,
gradient_backend=gradient_backend,
zero_tol=zero_tol,
diag_jitter=diag_jitter,
compile_kwargs=compile_kwargs,
33 changes: 19 additions & 14 deletions tests/test_find_map.py
@@ -54,24 +54,28 @@ def compute_z(x):


@pytest.mark.parametrize(
"method, use_grad, use_hess",
"method, use_grad, use_hess, use_hessp",
[
("nelder-mead", False, False),
("powell", False, False),
("CG", True, False),
("BFGS", True, False),
("L-BFGS-B", True, False),
("TNC", True, False),
("SLSQP", True, False),
("dogleg", True, True),
("trust-ncg", True, True),
("trust-exact", True, True),
("trust-krylov", True, True),
("trust-constr", True, True),
("nelder-mead", False, False, False),
("powell", False, False, False),
("CG", True, False, False),
("BFGS", True, False, False),
("L-BFGS-B", True, False, False),
("TNC", True, False, False),
("SLSQP", True, False, False),
("dogleg", True, True, False),
("Newton-CG", True, True, False),
("Newton-CG", True, False, True),
("trust-ncg", True, True, False),
("trust-ncg", True, False, True),
("trust-exact", True, True, False),
("trust-krylov", True, True, False),
("trust-krylov", True, False, True),
("trust-constr", True, True, False),
],
)
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
extra_kwargs = {}
if method == "dogleg":
# HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
**extra_kwargs,
use_grad=use_grad,
use_hess=use_hess,
use_hessp=use_hessp,
progressbar=False,
gradient_backend=gradient_backend,
compile_kwargs={"mode": "JAX"},
