Skip to content

Commit

Permalink
adding untested GAUSS
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jan 30, 2025
1 parent 016c5a7 commit 9152d28
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 0 deletions.
172 changes: 172 additions & 0 deletions sbi/inference/potentials/score_fn_iid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.utils.torchutils import ensure_theta_batched
from sbi.inference.potentials.score_utils import (
add_diag_or_dense,
denoise,
marginalize,
mv_diag_or_dense,
solve_diag_or_dense,
)


class ScoreFnIID:
Expand Down Expand Up @@ -132,3 +139,168 @@ def __call__(
score = (1 - N) * prior_score + base_score.sum(-2, keepdim=True)

return score


class AbstractGaussCorrectedScoreFn(ScoreFnIID):
    r"""Gaussian-corrected score function for IID observations.

    Combines the per-observation posterior scores of a score estimator with a
    prior correction term. Prior and posterior precisions are approximated as
    Gaussians (diagonal 1D/2D or dense 3D tensors) so the individual scores
    can be precision-weighted and recombined via a linear solve.
    """

    def __init__(
        self,
        score_estimator: ConditionalScoreEstimator,
        prior: Distribution,
    ) -> None:
        r"""Initializes the AbstractGaussCorrectedScoreFn class.

        Args:
            score_estimator: The neural network modelling the score.
            prior: The prior distribution.
        """
        super().__init__(score_estimator, prior)

    @abstractmethod
    def posterior_precision_est_fn(self, x_o: Tensor) -> Tensor:
        r"""Abstract method to estimate the posterior precision.

        Args:
            x_o: Observed data.

        Returns:
            Estimated posterior precision, either as diagonal entries
            (1D/2D tensor) or as dense precision matrices (3D tensor).
        """
        pass

    def denoising_prior(self, m: Tensor, std: Tensor, theta: Tensor) -> Distribution:
        r"""Denoise the prior distribution.

        Args:
            m: Mean tensor.
            std: Standard deviation tensor.
            theta: Parameters tensor.

        Returns:
            Denoised prior distribution.
        """
        return denoise(self.prior, m, std, theta)

    def marginal_prior(self, a: Tensor, theta: Tensor) -> Distribution:
        r"""Compute the marginal prior distribution at diffusion time ``a``.

        Args:
            a: Auxiliary variable tensor (diffusion time).
            theta: Parameters tensor (unused; kept for a uniform interface).

        Returns:
            Marginal prior distribution.
        """
        m = self.score_estimator.mean_t_fn(a)
        std = self.score_estimator.std_t_fn(a)
        return marginalize(self.prior, m, std)

    def marginal_posterior_precision_est_fn(
        self, a: Tensor, theta: Tensor, x_o: Tensor, N: int
    ) -> Tensor:
        r"""Estimates the marginal posterior precision.

        Args:
            a: Auxiliary variable tensor.
            theta: Parameters tensor.
            x_o: Observed data.
            N: Number of samples.

        Returns:
            Estimated marginal posterior precision.
        """
        precisions_posteriors = self.posterior_precision_est_fn(x_o)
        precisions_posteriors = torch.atleast_2d(precisions_posteriors)

        # If one constant precision is given, tile it once per observation.
        # Repeat counts must cover every dim, otherwise dense (3D) precision
        # tensors would raise in `Tensor.repeat`.
        if precisions_posteriors.shape[0] < N:
            reps = (N,) + (1,) * (precisions_posteriors.ndim - 1)
            precisions_posteriors = precisions_posteriors.repeat(*reps)

        # Denoising posterior via Bayes rule
        m = self.score_estimator.mean_t_fn(a)
        std = self.score_estimator.std_t_fn(a)

        if precisions_posteriors.ndim == 3:
            # Dense case: match dtype/device of the estimates so that
            # CUDA / non-default-dtype inputs work.
            identity = torch.eye(
                precisions_posteriors.shape[-1],
                dtype=precisions_posteriors.dtype,
                device=precisions_posteriors.device,
            )
        else:
            # Diagonal case: elementwise "identity".
            identity = torch.ones_like(precisions_posteriors)

        marginal_precisions = m**2 / std**2 * identity + precisions_posteriors
        return marginal_precisions

    def marginal_prior_score_fn(self, a: Tensor, theta: Tensor) -> Tensor:
        r"""Computes the score of the marginal prior distribution.

        Args:
            a: Auxiliary variable tensor.
            theta: Parameters tensor (must require grad).

        Returns:
            Marginal prior score, i.e. grad of log prob w.r.t. ``theta``.
        """
        p = self.marginal_prior(a, theta)
        log_p = p.log_prob(theta)
        return torch.autograd.grad(
            log_p,
            theta,
            grad_outputs=torch.ones_like(log_p),
            create_graph=True,
        )[0]

    def marginal_denoising_prior_precision_fn(self, a: Tensor, theta: Tensor) -> Tensor:
        r"""Computes the precision of the marginal denoising prior.

        Args:
            a: Auxiliary variable tensor.
            theta: Parameters tensor.

        Returns:
            Marginal denoising prior precision (reciprocal of the variance).
        """
        m = self.score_estimator.mean_t_fn(a)
        std = self.score_estimator.std_t_fn(a)
        p_denoise = self.denoising_prior(m, std, theta)
        return 1 / p_denoise.variance

    def __call__(self, a: Tensor, theta: Tensor, x_o: Tensor, **kwargs) -> Tensor:
        r"""Returns the corrected score function.

        Args:
            a: Auxiliary variable tensor.
            theta: Parameters tensor.
            x_o: Observed data.

        Returns:
            Corrected score function.
        """
        base_score = self.score_estimator(a, theta, x_o, **kwargs)
        prior_score = self.marginal_prior_score_fn(a, theta)
        # NOTE(review): N is taken from the leading dim of the estimator
        # output; presumably the number of IID observations — confirm.
        N = base_score.shape[0]

        # Marginal prior precision
        prior_precision = self.marginal_denoising_prior_precision_fn(a, theta)
        # Marginal posterior variance estimates
        posterior_precisions = self.marginal_posterior_precision_est_fn(
            a, theta, x_o, N
        )

        # Total precision: (1 - N) * prior + sum of per-observation posteriors.
        term1 = (1 - N) * prior_precision
        term2 = torch.sum(posterior_precisions, dim=0)
        Lam = add_diag_or_dense(term1, term2)

        # Precision-weighted scores
        weighted_prior_score = mv_diag_or_dense(prior_precision, prior_score)
        weighted_posterior_scores = torch.stack([
            mv_diag_or_dense(p, base_score[i])
            for i, p in enumerate(posterior_precisions)
        ])

        # Accumulate the scores
        score = (1 - N) * weighted_prior_score + torch.sum(
            weighted_posterior_scores, dim=0
        )

        # Solve the linear system Lam @ result = score
        score = solve_diag_or_dense(Lam, score)

        return score
58 changes: 58 additions & 0 deletions sbi/inference/potentials/score_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from torch import Tensor
import torch
from torch.distributions import Distribution, Independent, Normal

# Automatic denoising -----------------------------------------------------
Expand Down Expand Up @@ -130,3 +131,60 @@ def marginalize_gaussian(p: Normal, m: Tensor, s: Tensor) -> Normal:
std = var**0.5

return Normal(mu, std)


def mv_diag_or_dense(A_diag_or_dense: Tensor, b: Tensor) -> Tensor:
    """Matrix-vector product supporting diagonal or dense matrices.

    A 1D ``A_diag_or_dense`` is interpreted as the diagonal of a matrix,
    so the product reduces to an elementwise multiplication.

    Args:
        A_diag_or_dense (Tensor): Diagonal entries (1D) or dense matrix (2D).
        b (Tensor): Vector to multiply.

    Returns:
        Tensor: The product ``A @ b``.
    """
    is_diagonal = A_diag_or_dense.ndim == 1
    if is_diagonal:
        return A_diag_or_dense * b
    return torch.matmul(A_diag_or_dense, b)


def solve_diag_or_dense(A_diag_or_dense: Tensor, b: Tensor) -> Tensor:
    """Solve a linear system with a diagonal matrix or a dense matrix.

    A 1D ``A_diag_or_dense`` is treated as the diagonal of the system
    matrix, so the solve reduces to elementwise division.

    Args:
        A_diag_or_dense (Tensor): Diagonal entries (1D) or dense matrix (2D).
        b (Tensor): Right-hand side.

    Returns:
        Tensor: Solution ``x`` of ``A @ x = b``.
    """
    is_diagonal = A_diag_or_dense.ndim == 1
    if is_diagonal:
        return b / A_diag_or_dense
    return torch.linalg.solve(A_diag_or_dense, b)


def add_diag_or_dense(A_diag_or_dense: Tensor, B_diag_or_dense: Tensor) -> Tensor:
    """Add two matrices, each given as diagonal entries (1D) or dense (2D).

    Mixed inputs are promoted: a 1D argument is expanded to a dense
    diagonal matrix before the addition, so mixing always yields a dense
    result.

    Args:
        A_diag_or_dense (Tensor): Diagonal matrix or dense matrix.
        B_diag_or_dense (Tensor): Diagonal matrix or dense matrix.

    Returns:
        Tensor: Sum of the matrices.

    Raises:
        ValueError: If either argument has more than two dimensions.
    """
    a_dims, b_dims = A_diag_or_dense.ndim, B_diag_or_dense.ndim
    # Same representation: plain elementwise addition suffices.
    if a_dims == b_dims and a_dims in (1, 2):
        return A_diag_or_dense + B_diag_or_dense
    # Mixed representations: lift the diagonal one to a dense matrix.
    if a_dims == 2 and b_dims == 1:
        return A_diag_or_dense + torch.diag(B_diag_or_dense)
    if a_dims == 1 and b_dims == 2:
        return torch.diag(A_diag_or_dense) + B_diag_or_dense
    raise ValueError("Incompatible dimensions for addition")

0 comments on commit 9152d28

Please sign in to comment.