Commit
adding corrector back
manuelgloeckler committed Jan 30, 2025
1 parent 384f36f commit 016c5a7
Showing 3 changed files with 68 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sbi/inference/potentials/score_based_potential.py
@@ -178,6 +178,8 @@ def gradient(
            )
        else:
            assert self.prior is not None, "Prior is required for iid methods."
            # NOTE: Different methods for accumulating the score can be added here.
            # TODO: Warn for FNPE -> it effectively needs a "corrector" to sample well.
            score_fn_iid = FNPEScoreFn(
                self.score_estimator, self.prior, device=self.device
            )
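For context on what FNPEScoreFn accumulates: for N iid observations, the factorized (FNPE-style) posterior score sums the per-observation posterior scores and subtracts the prior score (N - 1) times, so the prior enters only once. A minimal standalone sketch of that accumulation, where single_score and prior_score are hypothetical callables rather than the sbi API:

import torch
from torch import Tensor

def fnpe_accumulated_score(theta: Tensor, xs: Tensor, single_score, prior_score) -> Tensor:
    # Sum the per-observation conditional scores over the N iid observations ...
    per_obs = torch.stack([single_score(theta, x) for x in xs]).sum(dim=0)
    # ... and subtract the prior score (N - 1) times so it is counted only once.
    return per_obs - (xs.shape[0] - 1) * prior_score(theta)

This factorization is exact only at diffusion time zero; at intermediate times it is an approximation, which is why the TODO above notes that FNPE benefits from a corrector.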
65 changes: 65 additions & 0 deletions sbi/samplers/score/correctors.py
@@ -1,6 +1,8 @@
import math
from abc import ABC, abstractmethod
from typing import Callable, Optional, Type

import torch
from torch import Tensor

from sbi.samplers.score.predictors import Predictor
@@ -63,3 +65,66 @@ def __call__(
    @abstractmethod
    def correct(self, theta: Tensor, t0: Tensor, t1: Optional[Tensor] = None) -> Tensor:
        pass
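For readers unfamiliar with the register_corrector decorator used below: such registries are typically a dict keyed by name, so correctors can be looked up at sampling time. A minimal sketch of the pattern (the actual sbi implementation may differ):

from typing import Callable, Dict, Type

CORRECTORS: Dict[str, Type] = {}

def register_corrector(name: str) -> Callable:
    # Map `name` to the decorated class so it can be looked up by name later.
    def decorator(cls: Type) -> Type:
        CORRECTORS[name] = cls
        return cls
    return decorator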


@register_corrector("langevin")
class LangevinCorrector(Corrector):
def __init__(
self,
predictor: Predictor,
step_size: float = 1e-4,
num_steps: int = 5,
):
"""Basic Langevin corrector.
Ref: https://en.wikipedia.org/wiki/Langevin_dynamics
Args:
predictor: Associated predictor.
step_size (optional): Unadjusted Langevin dynamics are only valid for small
step sizes. Defaults to 1e-4.
num_steps (optional): Number of steps to correct. Defaults to 5.
"""
super().__init__(predictor)
self.step_size = step_size
self.std = math.sqrt(2 * self.step_size)
self.num_steps = num_steps

def correct(self, theta: Tensor, t0: Tensor, t1: Optional[Tensor] = None) -> Tensor:
# TODO: Why is this impacting performance
for _ in range(self.num_steps):
score = self.predictor.potential_fn.gradient(theta, t1)
eps = self.std * torch.randn_like(theta, device=self.device)
theta = theta + self.step_size * score + eps

return theta
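To see the update in isolation: unadjusted Langevin dynamics move samples along the score with matched Gaussian noise, theta <- theta + eta * score + sqrt(2 * eta) * eps. A self-contained sketch on a known score (a standard Gaussian, whose score is -theta), independent of the sbi classes:

import math
import torch

def langevin_steps(theta, score_fn, step_size=1e-4, num_steps=5):
    # Unadjusted Langevin dynamics; only valid for small step sizes.
    std = math.sqrt(2 * step_size)
    for _ in range(num_steps):
        theta = theta + step_size * score_fn(theta) + std * torch.randn_like(theta)
    return theta

# Example: nudge overdispersed samples toward a standard Gaussian.
samples = langevin_steps(3.0 * torch.randn(100, 2), lambda t: -t)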


@register_corrector("gibbs")
class GibbsCorrector(Corrector):
def __init__(self, predictor: Predictor, num_steps: int = 5):
"""(Pseudo) Gibbs sampling corrector. Iteratively adds back noise according to
the correct forward SDE, then removes noise using the predictor. Hence,
approximatly sampling form the joint distribution using Gibbs sampling (if the
two conditional distributions are compatible).
Args:
predictor (Predictor): Associated predictor.
num_steps (int, optional): Number of steps. Defaults to 5.
"""
super().__init__(predictor)
self.num_steps = num_steps

def noise(self, theta: Tensor, t0: Tensor, t1: Tensor) -> Tensor:
f = self.predictor.drift(theta, t0)
g = self.predictor.diffusion(theta, t0)
eps = torch.randn_like(theta, device=self.device)
dt = t1 - t0
dt_sqrt = torch.sqrt(dt)
return theta + f * dt + g * eps * dt_sqrt

def correct(self, theta: Tensor, t0: Tensor, t1: Tensor) -> Tensor:
for _ in range(self.num_steps):
theta = self.noise(theta, t0, t1)
theta = self.predictor(theta, t1, t0)
return theta
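The noise method above is a single forward Euler-Maruyama step of the SDE. For a concrete instance, here is that step for a variance-preserving SDE with a constant beta; beta is an assumed toy schedule here, not the predictor's actual drift and diffusion:

import torch

def forward_noise_step(theta, t0, t1, beta=1.0):
    # Euler-Maruyama step of the VP-SDE: d(theta) = -0.5*beta*theta dt + sqrt(beta) dW.
    dt = t1 - t0
    drift = -0.5 * beta * theta
    diffusion = beta ** 0.5
    return theta + drift * dt + diffusion * (dt ** 0.5) * torch.randn_like(theta)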
2 changes: 1 addition & 1 deletion tests/linearGaussian_npse_test.py
@@ -163,7 +163,7 @@ def simulator(theta):
# strict=True,
# match="Score accumulation*",
# )
-@pytest.mark.parametrize("num_trials", [2, 10])
+@pytest.mark.parametrize("num_trials", [2,])
def test_npse_iid_inference(num_trials):
"""Test whether NPSE infers well a simple example with available ground truth."""

