diff --git a/pyproject.toml b/pyproject.toml
index 71c270459..cf180362d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-xdist",
+ "pytest-harvest",
"torchtestcase",
]
@@ -131,7 +132,8 @@ testpaths = ["tests"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"gpu: marks tests that require a gpu (deselect with '-m \"not gpu\"')",
- "mcmc: marks tests that require MCMC sampling (deselect with '-m \"not mcmc\"')"
+ "mcmc: marks tests that require MCMC sampling (deselect with '-m \"not mcmc\"')",
+ "benchmark: marks test that are soley for benchmarking purposes"
]
xfail_strict = true
diff --git a/tests/bm_test.py b/tests/bm_test.py
new file mode 100644
index 000000000..e38fd9659
--- /dev/null
+++ b/tests/bm_test.py
@@ -0,0 +1,272 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+
+import pytest
+import torch
+from pytest_harvest import ResultsBag
+
+from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
+from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.trainers.npe import NPE_C
+from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C
+from sbi.utils.metrics import c2st
+
+from .mini_sbibm import get_task
+from .mini_sbibm.base_task import Task
+
+# Global settings
+SEED = 0
+TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"]
+NUM_SIMULATIONS = 2000
+NUM_EVALUATION_OBS = 3 # Currently only 3 observation tested for speed
+NUM_ROUNDS_SEQUENTIAL = 2
+NUM_EVALUATION_OBS_SEQ = 1
+TRAIN_KWARGS = {}
+
+# Density estimators to test
+DENSITY_ESTIMATORS = ["mdn", "made", "maf", "nsf", "maf_rqs"] # "Kinda exhaustive"
+CLASSIFIERS = ["mlp", "resnet"]
+NNS = ["mlp", "resnet"]
+SCORE_ESTIMATORS = ["mlp", "ada_mlp"]
+
+# Benchmarking method groups i.e. what to run for different --bm-mode
+METHOD_GROUPS = {
+ "none": [NPE, NRE, NLE, FMPE, NPSE],
+ "npe": [NPE],
+ "nle": [NLE],
+ "nre": [NRE_A, NRE_B, NRE_C, BNRE],
+ "fmpe": [FMPE],
+ "npse": [NPSE],
+ "snpe": [NPE_C], # NPE_B not implemented, NPE_A need Gaussian prior
+ "snle": [NLE],
+ "snre": [NRE_A, NRE_B, NRE_C, BNRE],
+}
+METHOD_PARAMS = {
+ "none": [{}],
+ "npe": [{"density_estimator": de} for de in DENSITY_ESTIMATORS],
+ "nle": [{"density_estimator": de} for de in ["maf", "nsf"]],
+ "nre": [{"classifier": cl} for cl in CLASSIFIERS],
+ "fmpe": [{"density_estimator": nn} for nn in NNS],
+ "npse": [
+ {"score_estimator": nn, "sde_type": sde}
+ for nn in SCORE_ESTIMATORS
+ for sde in ["ve", "vp"]
+ ],
+ "snpe": [{}],
+ "snle": [{}],
+ "snre": [{}],
+}
+
+
+@pytest.fixture
+def method_list(benchmark_mode: str) -> list:
+ """
+ Fixture to get the list of methods based on the benchmark mode.
+
+ Args:
+ benchmark_mode (str): The benchmark mode.
+
+ Returns:
+ list: List of methods for the given benchmark mode.
+ """
+ name = str(benchmark_mode).lower()
+ if name not in METHOD_GROUPS:
+ raise ValueError(f"Benchmark mode '{benchmark_mode}' is not supported.")
+ return METHOD_GROUPS[name]
+
+
+@pytest.fixture
+def kwargs_list(benchmark_mode: str) -> list:
+ """
+ Fixture to get the list of kwargs based on the benchmark mode.
+
+ Args:
+ benchmark_mode (str): The benchmark mode.
+
+ Returns:
+ list: List of kwargs for the given benchmark mode.
+ """
+ name = str(benchmark_mode).lower()
+ if name not in METHOD_PARAMS:
+ raise ValueError(f"Benchmark mode '{benchmark_mode}' is not supported.")
+ return METHOD_PARAMS[name]
+
+
+# Use pytest.mark.parametrize dynamically
+# Generates a list of methods to test based on the benchmark mode
+def pytest_generate_tests(metafunc):
+ """
+ Dynamically generates a list of methods to test based on the benchmark mode.
+
+ Args:
+ metafunc: The metafunc object from pytest.
+ """
+ if "inference_class" in metafunc.fixturenames:
+ method_list = metafunc.config.getoption("--bm-mode")
+ name = str(method_list).lower()
+ method_group = METHOD_GROUPS.get(name, [])
+ metafunc.parametrize("inference_class", method_group)
+ if "extra_kwargs" in metafunc.fixturenames:
+ kwargs_list = metafunc.config.getoption("--bm-mode")
+ name = str(kwargs_list).lower()
+ kwargs_group = METHOD_PARAMS.get(name, [])
+ metafunc.parametrize("extra_kwargs", kwargs_group)
+
+
+def standard_eval_c2st_loop(posterior: NeuralPosterior, task: Task) -> float:
+ """
+ Evaluates the C2ST metric for the given posterior and task.
+
+ Args:
+ posterior: The posterior distribution.
+ task: The task object.
+
+ Returns:
+ float: The mean C2ST value.
+ """
+ c2st_scores = []
+ for i in range(1, NUM_EVALUATION_OBS + 1):
+ c2st_val = eval_c2st(posterior, task, i)
+ c2st_scores.append(c2st_val)
+
+ mean_c2st = sum(c2st_scores) / len(c2st_scores)
+ # Convert to float rounded to 3 decimal places
+ mean_c2st = float(f"{mean_c2st:.3f}")
+ return mean_c2st
+
+
+def eval_c2st(
+ posterior: NeuralPosterior,
+ task: Task,
+ idx_observation: int,
+ num_samples: int = 1000,
+) -> float:
+ """
+ Evaluates the C2ST metric for a specific observation.
+
+ Args:
+ posterior: The posterior distribution.
+ task: The task object.
+ i (int): The observation index.
+
+ Returns:
+ float: The C2ST value.
+ """
+ x_o = task.get_observation(idx_observation)
+ posterior_samples = task.get_reference_posterior_samples(idx_observation)
+ approx_posterior_samples = posterior.sample((num_samples,), x=x_o)
+ if isinstance(approx_posterior_samples, tuple):
+ approx_posterior_samples = approx_posterior_samples[0]
+ assert posterior_samples.shape[0] >= num_samples, "Not enough reference samples"
+ c2st_val = c2st(posterior_samples[:num_samples], approx_posterior_samples)
+ return float(c2st_val)
+
+
+def train_and_eval_amortized_inference(
+ inference_class, task_name: str, extra_kwargs: dict, results_bag: ResultsBag
+) -> None:
+ """
+ Performs amortized inference evaluation.
+
+ Args:
+ method: The inference method.
+ task_name: The name of the task.
+ extra_kwargs: Additional keyword arguments for the method.
+ results_bag: The results bag to store evaluation results. Subclass of dict, but
+ allows item assignment with dot notation.
+ """
+ torch.manual_seed(SEED)
+ task = get_task(task_name)
+ thetas, xs = task.get_data(NUM_SIMULATIONS)
+ prior = task.get_prior()
+
+ inference = inference_class(prior, **extra_kwargs)
+ _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
+
+ posterior = inference.build_posterior()
+
+ mean_c2st = standard_eval_c2st_loop(posterior, task)
+
+ # Cache results
+ results_bag.metric = mean_c2st
+ results_bag.num_simulations = NUM_SIMULATIONS
+ results_bag.task_name = task_name
+ results_bag.method = inference_class.__name__ + str(extra_kwargs)
+
+
+def train_and_eval_sequential_inference(
+ inference_class, task_name: str, extra_kwargs: dict, results_bag: ResultsBag
+) -> None:
+ """
+ Performs sequential inference evaluation.
+
+ Args:
+ method: The inference method.
+ task_name (str): The name of the task.
+ extra_kwargs (dict): Additional keyword arguments for the method.
+ results_bag: The results bag to store evaluation results.
+ """
+ torch.manual_seed(SEED)
+ task = get_task(task_name)
+ num_simulations = NUM_SIMULATIONS // NUM_ROUNDS_SEQUENTIAL
+ thetas, xs = task.get_data(num_simulations)
+ prior = task.get_prior()
+ idx_eval = NUM_EVALUATION_OBS_SEQ
+ x_o = task.get_observation(idx_eval)
+ simulator = task.get_simulator()
+
+ # Round 1
+ inference = inference_class(prior, **extra_kwargs)
+ _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
+
+ for _ in range(NUM_ROUNDS_SEQUENTIAL - 1):
+ proposal = inference.build_posterior().set_default_x(x_o)
+ thetas_i = proposal.sample((num_simulations,))
+ xs_i = simulator(thetas_i)
+ if "npe" in inference_class.__name__.lower():
+ # NPE_C requires a Gaussian prior
+ _ = inference.append_simulations(thetas_i, xs_i, proposal=proposal).train(
+ **TRAIN_KWARGS
+ )
+ else:
+ inference.append_simulations(thetas_i, xs_i).train(**TRAIN_KWARGS)
+
+ posterior = inference.build_posterior()
+
+ c2st_val = eval_c2st(posterior, task, idx_eval)
+
+ # Cache results
+ results_bag.metric = c2st_val
+ results_bag.num_simulations = NUM_SIMULATIONS
+ results_bag.task_name = task_name
+ results_bag.method = inference_class.__name__ + str(extra_kwargs)
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize("task_name", TASKS, ids=str)
+def test_run_benchmark(
+ inference_class,
+ task_name: str,
+ results_bag,
+ extra_kwargs: dict,
+ benchmark_mode: str,
+) -> None:
+ """
+ Benchmark test for amortized and sequential inference methods.
+
+ Args:
+ inference_class: The inference class to test i.e. NPE, NLE, NRE ...
+ task_name: The name of the task.
+ results_bag: The results bag to store evaluation results.
+ extra_kwargs: Additional keyword arguments for the method.
+ benchmark_mode: The benchmark mode. This is a fixture which based on user
+ input, determines which type of methods should be run.
+ """
+ if benchmark_mode in ["snpe", "snle", "snre"]:
+ train_and_eval_sequential_inference(
+ inference_class, task_name, extra_kwargs, results_bag
+ )
+ else:
+ train_and_eval_amortized_inference(
+ inference_class, task_name, extra_kwargs, results_bag
+ )
diff --git a/tests/conftest.py b/tests/conftest.py
index 9cedd97e9..04ac14069 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,11 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
+import pickle
+import re
+import shutil
+from logging import warning
+from pathlib import Path
+from shutil import rmtree
import pytest
import torch
@@ -8,6 +14,7 @@
# Seed for `set_seed` fixture. Change to random state of all seeded tests.
seed = 1
+harvested_fixture_data = None
# Use seed automatically for every test function.
@@ -29,10 +36,171 @@ def pytest_collection_modifyitems(config, items):
)
if not gpu_device_available:
skip_gpu = pytest.mark.skip(reason="No devices available")
+
for item in items:
if "gpu" in item.keywords:
item.add_marker(skip_gpu)
+ if not config.getoption("--bm"):
+ # Skip marked benchmarking tests
+ skip_bm = pytest.mark.skip(reason="Benchmarking disabled")
+ for item in items:
+ if "benchmark" in item.keywords:
+ item.add_marker(skip_bm)
+ else:
+ # Filter tests to only those with the 'benchmark' marker
+ filtered_items = []
+ for item in items:
+ if item.get_closest_marker("benchmark"):
+ filtered_items.append(item)
+
+ items[:] = filtered_items # Inplace!
+
+
+# Run mini-benchmark tests with `pytest --print-harvest`
+def pytest_addoption(parser):
+ parser.addoption(
+ "--bm",
+ action="store_true",
+ default=False,
+ help="Run mini-benchmark tests with specified mode",
+ )
+ parser.addoption(
+ "--bm-mode",
+ action="store",
+ default=None,
+ help="Run mini-benchmark tests with specified mode",
+ )
+
+
+@pytest.fixture
+def benchmark_mode(request):
+ """Fixture to access the --bm value in test files."""
+ return request.config.getoption("--bm-mode")
+
+
+@pytest.fixture(scope="session", autouse=True)
+def finalize_fixture_store(request, fixture_store):
+ # The code before `yield` runs at the start of the session (before tests).
+ yield
+ # The code after `yield` runs after all tests have completed.
+ # At this point, fixture_store should have all the harvested data.
+ global harvested_fixture_data
+ harvested_fixture_data = dict(fixture_store)
+
+
+def strip_ansi_escape_codes(text):
+ ansi_escape = re.compile(r'\x1b\[.*?m')
+ return ansi_escape.sub('', text)
+
+
+# Function to center text with ANSI colors, adjusting for escape codes
+def center_colored_text(text, width):
+ visible_length = len(strip_ansi_escape_codes(text))
+ padding = max(0, (width - visible_length) // 2)
+ return " " * padding + text + " " * (width - visible_length - padding)
+
+
+def pytest_terminal_summary(terminalreporter, exitstatus, config):
+ """
+ Custom pytest terminal summary to display mini SBIBM results with relative coloring
+ per task.
+
+ This function is called after the test session ends and generates a summary
+ of the results if the `--bm` option is specified. It displays the results
+ in a formatted table with methods as rows and tasks as columns, applying
+ relative coloring to metrics based on their performance within each task.
+ """
+ if config.getoption("--bm"):
+ terminal_width = shutil.get_terminal_size().columns
+ summary_text = " mini SBIBM results "
+ centered_line = summary_text.center(terminal_width, '=')
+ colored_line = f"\033[96m{centered_line}\033[0m"
+ terminalreporter.write_line(colored_line)
+
+ if harvested_fixture_data is not None:
+ terminalreporter.write_line("Amortized inference:")
+
+ results = harvested_fixture_data["results_bag"]
+
+ # Extract relevant data (method, task, metric)
+ methods = set()
+ tasks = set()
+ data = {} # (method, task) -> metric
+
+ for _, info in results.items():
+ method = info.get('method')
+ task = info.get('task_name')
+ metric = info.get('metric')
+
+ if method is not None and task is not None:
+ methods.add(method)
+ tasks.add(task)
+ data[(method, task)] = metric
+
+ methods = sorted(methods)
+ tasks = sorted(tasks)
+
+ if not methods or not tasks:
+ terminalreporter.write_line("No methods or tasks found.")
+ return
+
+ # Determine column widths
+ method_col_width = max(len(m) for m in methods)
+ task_col_widths = {t: max(len(t), 10) for t in tasks}
+
+ # Print the header row
+ header = " " * (method_col_width + 2)
+ for t in tasks:
+ header += t.center(task_col_widths[t] + 2)
+ terminalreporter.write_line(header)
+
+ # Print separator line
+ sep_line = "-" * len(header)
+ terminalreporter.write_line(sep_line)
+
+ # Calculate min and max for each task
+ min_max_per_task = {}
+ for t in tasks:
+ task_metrics = [data.get((m, t), float('inf')) for m in methods]
+ min_max_per_task[t] = (min(task_metrics), max(task_metrics))
+
+ # Print each row with colored values
+ for m in methods:
+ row = m.ljust(method_col_width + 2)
+ for t in tasks:
+ val = data.get((m, t), "N/A")
+ if val == "N/A":
+ val_str = "N/A"
+ row += val_str.center(task_col_widths[t] + 2)
+ else:
+ val = float(val)
+ min_val, max_val = min_max_per_task[t]
+ normalized_val = (
+ (val - min_val) / (max_val - min_val)
+ if max_val > min_val
+ else 0.5
+ )
+
+ # Determine color based on normalized value
+ if normalized_val == 0.0:
+ color = "\033[92m" # Green for best
+ elif normalized_val == 1.0:
+ color = "\033[91m" # Red for worst
+ else:
+ color = f"\033[9{int(2 + normalized_val * 3)}m"
+
+ val_str = format(val, ".3f")
+ colored_val_str = f"{color}{val_str}\033[0m"
+
+ row += center_colored_text(
+ colored_val_str, task_col_widths[t] + 2
+ )
+
+ terminalreporter.write_line(row)
+ else:
+ terminalreporter.write_line("No harvested fixture data found yet.")
+
@pytest.fixture(scope="function")
def mcmc_params_accurate() -> dict:
@@ -44,3 +212,48 @@ def mcmc_params_accurate() -> dict:
def mcmc_params_fast() -> dict:
"""Fixture for MCMC parameters for fast tests."""
return dict(num_chains=1, thin=1, warmup_steps=1)
+
+
+# Pytest harvest xdist support.
+
+
+# Define the folder in which temporary worker's results will be stored
+RESULTS_PATH = Path('./.xdist_results/')
+RESULTS_PATH.mkdir(exist_ok=True)
+
+
+def pytest_harvest_xdist_init():
+ # reset the recipient folder
+ if RESULTS_PATH.exists():
+ rmtree(RESULTS_PATH)
+ RESULTS_PATH.mkdir(exist_ok=False)
+ return True
+
+
+def pytest_harvest_xdist_worker_dump(worker_id, session_items, fixture_store):
+ # persist session_items and fixture_store in the file system
+ with open(RESULTS_PATH / ('%s.pkl' % worker_id), 'wb') as f:
+ try:
+ pickle.dump((session_items, fixture_store), f)
+ except Exception as e:
+ warning(
+ "Error while pickling worker %s's harvested results: [%s] %s",
+ (worker_id, e.__class__, e),
+ )
+ return True
+
+
+def pytest_harvest_xdist_load():
+ # restore the saved objects from file system
+ workers_saved_material = dict()
+ for pkl_file in RESULTS_PATH.glob('*.pkl'):
+ wid = pkl_file.stem
+ with pkl_file.open('rb') as f:
+ workers_saved_material[wid] = pickle.load(f)
+ return workers_saved_material
+
+
+def pytest_harvest_xdist_cleanup():
+ # delete all temporary pickle files
+ rmtree(RESULTS_PATH)
+ return True
diff --git a/tests/mini_sbibm/__init__.py b/tests/mini_sbibm/__init__.py
new file mode 100644
index 000000000..92b408c75
--- /dev/null
+++ b/tests/mini_sbibm/__init__.py
@@ -0,0 +1,37 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+# NOTE: This is inspired by the sbibm-package
+
+from .base_task import Task
+from .gaussian_linear import GaussianLinear
+from .linear_mvg import LinearMVG2d
+from .slcp import Slcp
+from .two_moons import TwoMoons
+
+TASKS = {
+ "two_moons": TwoMoons,
+ "linear_mvg_2d": LinearMVG2d,
+ "gaussian_linear": GaussianLinear,
+ "slcp": Slcp,
+}
+
+
+def get_task(name: str, *args, **kwargs) -> Task:
+ """
+ Retrieve a task instance based on the given name.
+
+ Args:
+ name (str): The name of the task to retrieve.
+ Possible values are "two_moons", "linear_mvg_2d",
+ "gaussian_linear", and "slcp".
+
+ Returns:
+ object: An instance of the corresponding task class.
+
+ Raises:
+ ValueError: If the provided task name is unknown.
+ """
+ try:
+ return TASKS[name](*args, **kwargs)
+ except KeyError as err:
+ raise ValueError(f"Unknown task {name}") from err
diff --git a/tests/mini_sbibm/base_task.py b/tests/mini_sbibm/base_task.py
new file mode 100644
index 000000000..3a18a6633
--- /dev/null
+++ b/tests/mini_sbibm/base_task.py
@@ -0,0 +1,123 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+# NOTE: This is inspired by the sbibm-package
+
+import os
+from abc import ABC, abstractmethod
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution
+
+PATH = os.path.dirname(__file__)
+
+
+class Task(ABC):
+ """
+ Abstract base class for a task in the SBI benchmark.
+
+ Args:
+ name (str): The name of the task.
+ """
+
+ def __init__(self, name: str):
+ self.name = name
+
+ @abstractmethod
+ def theta_dim(self) -> int:
+ """
+ Returns the dimensionality of the parameter space.
+
+ Returns:
+ int: The dimensionality of the parameter space.
+ """
+ pass
+
+ @abstractmethod
+ def x_dim(self) -> int:
+ """
+ Returns the dimensionality of the observation space.
+
+ Returns:
+ int: The dimensionality of the observation space.
+ """
+ pass
+
+ @abstractmethod
+ def get_prior(self) -> Distribution:
+ """
+ Returns the prior distribution over parameters.
+
+ Returns:
+ Distribution: The prior distribution.
+ """
+ pass
+
+ @abstractmethod
+ def get_simulator(self) -> Callable:
+ """
+ Returns the simulator function.
+
+ Returns:
+ Callable: The simulator function.
+ """
+ pass
+
+ def get_data(self, num_sims: int):
+ """
+ Generates data by sampling from the prior and simulating observations.
+
+ Args:
+ num_sims (int): The number of simulations to run.
+
+ Returns:
+ tuple: A tuple containing the sampled parameters and simulated observations.
+ """
+ thetas = self.get_prior().sample((num_sims,))
+ xs = self.get_simulator()(thetas)
+ return thetas, xs
+
+ def get_observation(self, idx: int) -> torch.Tensor:
+ """
+ Loads a specific observation from file.
+
+ Args:
+ idx (int): The index of the observation to load.
+
+ Returns:
+ torch.Tensor: The loaded observation.
+ """
+ x_o = torch.load(
+ PATH + os.sep + "files" + os.sep + f"{self.name}{os.sep}x_o_{idx}.pt"
+ )
+ return x_o
+
+ def get_true_parameters(self, idx: int) -> torch.Tensor:
+ """
+ Loads the true parameters for a specific observation from file.
+
+ Args:
+ idx (int): The index of the parameters to load.
+
+ Returns:
+ torch.Tensor: The loaded true parameters.
+ """
+ theta = torch.load(
+ PATH + os.sep + "files" + os.sep + f"{self.name}{os.sep}theta_{idx}.pt"
+ )
+ return theta
+
+ def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
+ """
+ Loads reference posterior samples for a specific observation from file.
+
+ Args:
+ idx (int): The index of the posterior samples to load.
+
+ Returns:
+ torch.Tensor: The loaded posterior samples.
+ """
+ posterior_samples = torch.load(
+ PATH + os.sep + "files" + os.sep + f"{self.name}{os.sep}samples_{idx}.pt"
+ )
+ return posterior_samples
diff --git a/tests/mini_sbibm/files/slcp/samples_1.pt b/tests/mini_sbibm/files/slcp/samples_1.pt
new file mode 100644
index 000000000..d0a1bd771
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_1.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_10.pt b/tests/mini_sbibm/files/slcp/samples_10.pt
new file mode 100644
index 000000000..0de7efad6
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_10.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_2.pt b/tests/mini_sbibm/files/slcp/samples_2.pt
new file mode 100644
index 000000000..f642597d3
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_2.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_3.pt b/tests/mini_sbibm/files/slcp/samples_3.pt
new file mode 100644
index 000000000..640bcd1dc
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_3.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_4.pt b/tests/mini_sbibm/files/slcp/samples_4.pt
new file mode 100644
index 000000000..1397dc02b
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_4.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_5.pt b/tests/mini_sbibm/files/slcp/samples_5.pt
new file mode 100644
index 000000000..f2e8c35f1
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_5.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_6.pt b/tests/mini_sbibm/files/slcp/samples_6.pt
new file mode 100644
index 000000000..091bb8143
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_6.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_7.pt b/tests/mini_sbibm/files/slcp/samples_7.pt
new file mode 100644
index 000000000..edcbe6596
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_7.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_8.pt b/tests/mini_sbibm/files/slcp/samples_8.pt
new file mode 100644
index 000000000..ec1cd0392
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_8.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_9.pt b/tests/mini_sbibm/files/slcp/samples_9.pt
new file mode 100644
index 000000000..38b14665e
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_9.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_1.pt b/tests/mini_sbibm/files/slcp/theta_o_1.pt
new file mode 100644
index 000000000..ab4dc7bae
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_1.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_10.pt b/tests/mini_sbibm/files/slcp/theta_o_10.pt
new file mode 100644
index 000000000..4056ae9de
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_10.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_2.pt b/tests/mini_sbibm/files/slcp/theta_o_2.pt
new file mode 100644
index 000000000..529388c43
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_2.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_3.pt b/tests/mini_sbibm/files/slcp/theta_o_3.pt
new file mode 100644
index 000000000..97e333ced
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_3.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_4.pt b/tests/mini_sbibm/files/slcp/theta_o_4.pt
new file mode 100644
index 000000000..27f22b885
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_4.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_5.pt b/tests/mini_sbibm/files/slcp/theta_o_5.pt
new file mode 100644
index 000000000..64c3c77d4
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_5.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_6.pt b/tests/mini_sbibm/files/slcp/theta_o_6.pt
new file mode 100644
index 000000000..607ced052
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_6.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_7.pt b/tests/mini_sbibm/files/slcp/theta_o_7.pt
new file mode 100644
index 000000000..7da5b387b
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_7.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_8.pt b/tests/mini_sbibm/files/slcp/theta_o_8.pt
new file mode 100644
index 000000000..d4ccf87d0
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_8.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_9.pt b/tests/mini_sbibm/files/slcp/theta_o_9.pt
new file mode 100644
index 000000000..38b6ecd7b
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_9.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_1.pt b/tests/mini_sbibm/files/slcp/x_o_1.pt
new file mode 100644
index 000000000..f806b3ff8
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_1.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_10.pt b/tests/mini_sbibm/files/slcp/x_o_10.pt
new file mode 100644
index 000000000..ddfe6c805
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_10.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_2.pt b/tests/mini_sbibm/files/slcp/x_o_2.pt
new file mode 100644
index 000000000..7c7f92274
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_2.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_3.pt b/tests/mini_sbibm/files/slcp/x_o_3.pt
new file mode 100644
index 000000000..de8576ce5
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_3.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_4.pt b/tests/mini_sbibm/files/slcp/x_o_4.pt
new file mode 100644
index 000000000..a60abaf24
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_4.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_5.pt b/tests/mini_sbibm/files/slcp/x_o_5.pt
new file mode 100644
index 000000000..a8f169e89
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_5.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_6.pt b/tests/mini_sbibm/files/slcp/x_o_6.pt
new file mode 100644
index 000000000..56d43231f
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_6.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_7.pt b/tests/mini_sbibm/files/slcp/x_o_7.pt
new file mode 100644
index 000000000..d88572ab9
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_7.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_8.pt b/tests/mini_sbibm/files/slcp/x_o_8.pt
new file mode 100644
index 000000000..35e6f5d21
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_8.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_9.pt b/tests/mini_sbibm/files/slcp/x_o_9.pt
new file mode 100644
index 000000000..89211253f
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_9.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_1.pt b/tests/mini_sbibm/files/two_moons/samples_1.pt
new file mode 100644
index 000000000..4758ad613
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_1.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_10.pt b/tests/mini_sbibm/files/two_moons/samples_10.pt
new file mode 100644
index 000000000..22c9fb59a
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_10.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_2.pt b/tests/mini_sbibm/files/two_moons/samples_2.pt
new file mode 100644
index 000000000..79c27d20f
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_2.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_3.pt b/tests/mini_sbibm/files/two_moons/samples_3.pt
new file mode 100644
index 000000000..41300b513
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_3.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_4.pt b/tests/mini_sbibm/files/two_moons/samples_4.pt
new file mode 100644
index 000000000..99164ddf8
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_4.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_5.pt b/tests/mini_sbibm/files/two_moons/samples_5.pt
new file mode 100644
index 000000000..6c2d0a4bc
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_5.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_6.pt b/tests/mini_sbibm/files/two_moons/samples_6.pt
new file mode 100644
index 000000000..7269844cb
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_6.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_7.pt b/tests/mini_sbibm/files/two_moons/samples_7.pt
new file mode 100644
index 000000000..a75fddabb
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_7.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_8.pt b/tests/mini_sbibm/files/two_moons/samples_8.pt
new file mode 100644
index 000000000..6f9f30df3
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_8.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/samples_9.pt b/tests/mini_sbibm/files/two_moons/samples_9.pt
new file mode 100644
index 000000000..a8fe22661
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_9.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_1.pt b/tests/mini_sbibm/files/two_moons/theta_o_1.pt
new file mode 100644
index 000000000..5aa284381
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_1.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_10.pt b/tests/mini_sbibm/files/two_moons/theta_o_10.pt
new file mode 100644
index 000000000..4d3c22045
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_10.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_2.pt b/tests/mini_sbibm/files/two_moons/theta_o_2.pt
new file mode 100644
index 000000000..33a83c3c4
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_2.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_3.pt b/tests/mini_sbibm/files/two_moons/theta_o_3.pt
new file mode 100644
index 000000000..fb783e50b
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_3.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_4.pt b/tests/mini_sbibm/files/two_moons/theta_o_4.pt
new file mode 100644
index 000000000..3bc059efc
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_4.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_5.pt b/tests/mini_sbibm/files/two_moons/theta_o_5.pt
new file mode 100644
index 000000000..c3aceb4dc
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_5.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_6.pt b/tests/mini_sbibm/files/two_moons/theta_o_6.pt
new file mode 100644
index 000000000..fc93bd5ab
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_6.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_7.pt b/tests/mini_sbibm/files/two_moons/theta_o_7.pt
new file mode 100644
index 000000000..16eb771e2
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_7.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_8.pt b/tests/mini_sbibm/files/two_moons/theta_o_8.pt
new file mode 100644
index 000000000..a91ed39f6
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_8.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/theta_o_9.pt b/tests/mini_sbibm/files/two_moons/theta_o_9.pt
new file mode 100644
index 000000000..9ce7a7821
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_9.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_1.pt b/tests/mini_sbibm/files/two_moons/x_o_1.pt
new file mode 100644
index 000000000..f9e20060e
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_1.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_10.pt b/tests/mini_sbibm/files/two_moons/x_o_10.pt
new file mode 100644
index 000000000..fef3e1310
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_10.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_2.pt b/tests/mini_sbibm/files/two_moons/x_o_2.pt
new file mode 100644
index 000000000..f55bf4cba
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_2.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_3.pt b/tests/mini_sbibm/files/two_moons/x_o_3.pt
new file mode 100644
index 000000000..1b9b4e572
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_3.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_4.pt b/tests/mini_sbibm/files/two_moons/x_o_4.pt
new file mode 100644
index 000000000..e0abd7c64
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_4.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_5.pt b/tests/mini_sbibm/files/two_moons/x_o_5.pt
new file mode 100644
index 000000000..015495b4f
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_5.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_6.pt b/tests/mini_sbibm/files/two_moons/x_o_6.pt
new file mode 100644
index 000000000..450b8e26f
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_6.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_7.pt b/tests/mini_sbibm/files/two_moons/x_o_7.pt
new file mode 100644
index 000000000..10c547b71
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_7.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_8.pt b/tests/mini_sbibm/files/two_moons/x_o_8.pt
new file mode 100644
index 000000000..b752885b7
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_8.pt differ
diff --git a/tests/mini_sbibm/files/two_moons/x_o_9.pt b/tests/mini_sbibm/files/two_moons/x_o_9.pt
new file mode 100644
index 000000000..f9c9ad79f
Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_9.pt differ
diff --git a/tests/mini_sbibm/gaussian_linear.py b/tests/mini_sbibm/gaussian_linear.py
new file mode 100644
index 000000000..13b5f7dd4
--- /dev/null
+++ b/tests/mini_sbibm/gaussian_linear.py
@@ -0,0 +1,119 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+# NOTE: This is inspired by the sbibm-package
+
+from functools import partial
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, MultivariateNormal
+
+from sbi.simulators.linear_gaussian import (
+ diagonal_linear_gaussian,
+ true_posterior_linear_gaussian_mvn_prior,
+)
+
+from .base_task import Task
+
+
+class GaussianLinear(Task):
+ """
+ Task for the Gaussian Linear model.
+
+ This task uses a linear Gaussian model with a multivariate normal prior.
+ """
+
+ def __init__(self):
+ """
+ Initializes the GaussianLinear task.
+ """
+ self.simulator_scale = 0.1
+ self.dim = 5
+ super().__init__("gaussian_linear")
+
+ def theta_dim(self) -> int:
+ """
+ Returns the dimensionality of the parameter space.
+
+ Returns:
+ int: The dimensionality of the parameter space.
+ """
+ return self.dim
+
+ def x_dim(self) -> int:
+ """
+ Returns the dimensionality of the observation space.
+
+ Returns:
+ int: The dimensionality of the observation space.
+ """
+ return self.dim
+
+ def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
+ """
+ Generates reference posterior samples for a specific observation.
+
+ Args:
+ idx (int): The index of the observation.
+
+ Returns:
+ torch.Tensor: The reference posterior samples.
+ """
+ x_o = self.get_observation(idx)
+ posterior = true_posterior_linear_gaussian_mvn_prior(
+ x_o,
+ torch.zeros(self.dim),
+ self.simulator_scale * torch.eye(self.dim),
+ torch.zeros(self.dim),
+ torch.eye(self.dim),
+ )
+
+ return posterior.sample((10_000,))
+
+ def get_true_parameters(self, idx: int) -> torch.Tensor:
+ """
+ Generates the true parameters for a specific observation.
+
+ Args:
+ idx (int): The index of the observation.
+
+ Returns:
+ torch.Tensor: The true parameters.
+ """
+ torch.manual_seed(idx)
+ return self.get_prior().sample()
+
+ def get_observation(self, idx: int) -> torch.Tensor:
+ """
+ Generates an observation for a specific set of true parameters.
+
+ Args:
+ idx (int): The index of the observation.
+
+ Returns:
+ torch.Tensor: The observation.
+ """
+ theta_o = self.get_true_parameters(idx)
+ x_o = self.get_simulator()(theta_o[None, :])[0]
+ return x_o
+
+ def get_prior(self) -> Distribution:
+ """
+ Returns the prior distribution over parameters.
+
+ Returns:
+ Distribution: The prior distribution.
+ """
+ return MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim))
+
+ def get_simulator(self) -> Callable:
+ """
+ Returns the simulator function.
+
+ Returns:
+ Callable: The simulator function.
+ """
+ return partial(
+ diagonal_linear_gaussian,
+ std=self.simulator_scale,
+ )
diff --git a/tests/mini_sbibm/linear_mvg.py b/tests/mini_sbibm/linear_mvg.py
new file mode 100644
index 000000000..cb2f26153
--- /dev/null
+++ b/tests/mini_sbibm/linear_mvg.py
@@ -0,0 +1,120 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+# NOTE: This is inspired by the sbibm-package
+
+from functools import partial
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, MultivariateNormal
+
+from sbi.simulators.linear_gaussian import (
+ linear_gaussian,
+ true_posterior_linear_gaussian_mvn_prior,
+)
+
+from .base_task import Task
+
+
+class LinearMVG2d(Task):
+ """
+ Task for the Linear Multivariate Gaussian (MVG) model in 2D.
+
+ This task uses a linear Gaussian model with a multivariate normal prior.
+ """
+
+ def __init__(self):
+ """
+ Initializes the LinearMVG2d task.
+ """
+ self.likelihood_shift = torch.tensor([-1.0, 1.0])
+ self.likelihood_cov = torch.tensor([[0.6, 0.5], [0.5, 0.6]])
+ super().__init__("linear_mvg_2d")
+
+ def theta_dim(self) -> int:
+ """
+ Returns the dimensionality of the parameter space.
+
+ Returns:
+ int: The dimensionality of the parameter space.
+ """
+ return 2
+
+ def x_dim(self) -> int:
+ """
+ Returns the dimensionality of the observation space.
+
+ Returns:
+ int: The dimensionality of the observation space.
+ """
+ return 2
+
+ def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
+ """
+ Generates reference posterior samples for a specific observation.
+
+ Args:
+ idx (int): The index of the observation.
+
+ Returns:
+ torch.Tensor: The reference posterior samples.
+ """
+ x_o = self.get_observation(idx)
+ posterior = true_posterior_linear_gaussian_mvn_prior(
+ x_o,
+ self.likelihood_shift,
+ self.likelihood_cov,
+ torch.zeros(2),
+ torch.eye(2),
+ )
+
+ return posterior.sample((10_000,))
+
+ def get_true_parameters(self, idx: int) -> torch.Tensor:
+ """
+ Generates the true parameters for a specific observation.
+
+ Args:
+ idx (int): The index of the observation.
+
+ Returns:
+ torch.Tensor: The true parameters.
+ """
+ torch.manual_seed(idx)
+ return self.get_prior().sample()
+
+ def get_observation(self, idx: int) -> torch.Tensor:
+ """
+ Generates an observation for a specific set of true parameters.
+
+ Args:
+ idx (int): The index of the observation.
+
+ Returns:
+ torch.Tensor: The observation.
+ """
+ theta_o = self.get_true_parameters(idx)
+ x_o = self.get_simulator()(theta_o[None, :])[0]
+ return x_o
+
+ def get_prior(self) -> Distribution:
+ """
+ Returns the prior distribution over parameters.
+
+ Returns:
+ Distribution: The prior distribution.
+ """
+ return MultivariateNormal(torch.zeros(2), torch.eye(2))
+
+ def get_simulator(self) -> Callable:
+ """
+ Returns the simulator function.
+
+ Returns:
+ Callable: The simulator function.
+ """
+ return partial(
+ linear_gaussian,
+ likelihood_shift=self.likelihood_shift,
+ likelihood_cov=self.likelihood_cov,
+ )
diff --git a/tests/mini_sbibm/slcp.py b/tests/mini_sbibm/slcp.py
new file mode 100644
index 000000000..aea4d8314
--- /dev/null
+++ b/tests/mini_sbibm/slcp.py
@@ -0,0 +1,56 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+# NOTE: This is inspired by the sbibm-package
+
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, Independent, MultivariateNormal, Uniform
+
+from .base_task import Task
+
+
+def simulator(theta, num_data=4):
+ num_samples = theta.shape[0]
+
+ m = torch.stack((theta[:, [0]].squeeze(), theta[:, [1]].squeeze())).T
+ if m.dim() == 1:
+ m.unsqueeze_(0)
+
+ s1 = theta[:, [2]].squeeze() ** 2
+ s2 = theta[:, [3]].squeeze() ** 2
+ rho = torch.nn.Tanh()(theta[:, [4]]).squeeze()
+
+ S = torch.empty((num_samples, 2, 2))
+ S[:, 0, 0] = s1**2
+ S[:, 0, 1] = rho * s1 * s2
+ S[:, 1, 0] = rho * s1 * s2
+ S[:, 1, 1] = s2**2
+
+ # Add eps to diagonal to ensure PSD
+ eps = 0.000001
+ S[:, 0, 0] += eps
+ S[:, 1, 1] += eps
+
+ data_dist = MultivariateNormal(m, S)
+ xs = data_dist.sample((num_data,))
+ xs = xs.permute(1, 0, 2)
+
+ return xs.reshape(num_samples, num_data * 2)
+
+
+class Slcp(Task):
+ def __init__(self):
+ super().__init__("slcp")
+
+ def theta_dim(self) -> int:
+ return 5
+
+ def x_dim(self) -> int:
+ return 8
+
+ def get_prior(self) -> Distribution:
+ return Independent(Uniform(-3 * torch.ones(5), 3 * torch.ones(5)), 1)
+
+ def get_simulator(self) -> Callable:
+ return simulator
diff --git a/tests/mini_sbibm/two_moons.py b/tests/mini_sbibm/two_moons.py
new file mode 100644
index 000000000..1316cc128
--- /dev/null
+++ b/tests/mini_sbibm/two_moons.py
@@ -0,0 +1,127 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+# NOTE: This is inspired by the sbibm-package
+
+import math
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, Independent, Normal, Uniform
+
+from .base_task import Task
+
+
+def _map_fun(parameters: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+ """
+ Maps the parameters and points to a new space using a rotation.
+
+ Args:
+ parameters (torch.Tensor): The parameters for the mapping.
+ p (torch.Tensor): The points to be mapped.
+
+ Returns:
+ torch.Tensor: The mapped points.
+ """
+ ang = torch.tensor([-math.pi / 4.0])
+ c = torch.cos(ang)
+ s = torch.sin(ang)
+ z0 = (c * parameters[:, 0] - s * parameters[:, 1]).reshape(-1, 1)
+ z1 = (s * parameters[:, 0] + c * parameters[:, 1]).reshape(-1, 1)
+ return p + torch.cat((-torch.abs(z0), z1), dim=1)
+
+
+def simulator(
+ parameters: torch.Tensor,
+ r_loc: float = 0.1,
+ r_scale: float = 0.01,
+ a_low: float = -math.pi / 2.0,
+ a_high: float = math.pi / 2.0,
+ base_offset: float = 0.25,
+) -> torch.Tensor:
+ """
+ Simulator function for the Two Moons task.
+
+ Args:
+ parameters (torch.Tensor): The parameters for the simulator.
+ r_loc (float, optional): The mean of the radius distribution. Defaults to 0.1.
+ r_scale (float, optional): The standard deviation of the radius distribution.
+ Defaults to 0.01.
+ a_low (float, optional): The lower bound of the angle distribution. Defaults to
+ -math.pi / 2.0.
+ a_high (float, optional): The upper bound of the angle distribution. Defaults to
+ math.pi / 2.0.
+ base_offset (float, optional): The base offset for the points. Defaults to 0.25.
+
+ Returns:
+ torch.Tensor: The simulated data.
+ """
+ num_samples = parameters.shape[0]
+
+ a_dist = Uniform(
+ low=a_low,
+ high=a_high,
+ )
+ a = a_dist.sample((num_samples, 1))
+
+ r_dist = Normal(r_loc, r_scale)
+ r = r_dist.sample((num_samples, 1))
+
+ p = torch.cat(
+ (
+ torch.cos(a) * r + base_offset,
+ torch.sin(a) * r,
+ ),
+ dim=1,
+ )
+
+ return _map_fun(parameters, p)
+
+
+class TwoMoons(Task):
+ """
+ Task for the Two Moons model.
+
+ This task uses a uniform prior and a custom simulator.
+ """
+
+ def __init__(self):
+ """
+ Initializes the TwoMoons task.
+ """
+ super().__init__("two_moons")
+
+ def theta_dim(self) -> int:
+ """
+ Returns the dimensionality of the parameter space.
+
+ Returns:
+ int: The dimensionality of the parameter space.
+ """
+ return 2
+
+ def x_dim(self) -> int:
+ """
+ Returns the dimensionality of the observation space.
+
+ Returns:
+ int: The dimensionality of the observation space.
+ """
+ return 2
+
+ def get_prior(self) -> Distribution:
+ """
+ Returns the prior distribution over parameters.
+
+ Returns:
+ Distribution: The prior distribution.
+ """
+ return Independent(Uniform(-torch.ones(2), torch.ones(2)), 1)
+
+ def get_simulator(self) -> Callable:
+ """
+ Returns the simulator function.
+
+ Returns:
+ Callable: The simulator function.
+ """
+ return simulator
diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py
index cd275a7b0..3b330a91c 100644
--- a/tests/user_input_checks_test.py
+++ b/tests/user_input_checks_test.py
@@ -85,7 +85,7 @@ def matrix_simulator(theta):
# Set default tensor locally to reach tensors in fixtures.
-torch.set_default_tensor_type(torch.FloatTensor)
+torch.set_default_dtype(torch.float32)
@pytest.mark.parametrize(