diff --git a/pyproject.toml b/pyproject.toml index 71c270459..cf180362d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dev = [ "pytest", "pytest-cov", "pytest-xdist", + "pytest-harvest", "torchtestcase", ] @@ -131,7 +132,8 @@ testpaths = ["tests"] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "gpu: marks tests that require a gpu (deselect with '-m \"not gpu\"')", - "mcmc: marks tests that require MCMC sampling (deselect with '-m \"not mcmc\"')" + "mcmc: marks tests that require MCMC sampling (deselect with '-m \"not mcmc\"')", + "benchmark: marks test that are soley for benchmarking purposes" ] xfail_strict = true diff --git a/tests/bm_test.py b/tests/bm_test.py new file mode 100644 index 000000000..e38fd9659 --- /dev/null +++ b/tests/bm_test.py @@ -0,0 +1,272 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see + +import pytest +import torch +from pytest_harvest import ResultsBag + +from sbi.inference import FMPE, NLE, NPE, NPSE, NRE +from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.trainers.npe import NPE_C +from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C +from sbi.utils.metrics import c2st + +from .mini_sbibm import get_task +from .mini_sbibm.base_task import Task + +# Global settings +SEED = 0 +TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"] +NUM_SIMULATIONS = 2000 +NUM_EVALUATION_OBS = 3 # Currently only 3 observation tested for speed +NUM_ROUNDS_SEQUENTIAL = 2 +NUM_EVALUATION_OBS_SEQ = 1 +TRAIN_KWARGS = {} + +# Density estimators to test +DENSITY_ESTIMATORS = ["mdn", "made", "maf", "nsf", "maf_rqs"] # "Kinda exhaustive" +CLASSIFIERS = ["mlp", "resnet"] +NNS = ["mlp", "resnet"] +SCORE_ESTIMATORS = ["mlp", "ada_mlp"] + +# Benchmarking method groups i.e. what to run for different --bm-mode +METHOD_GROUPS = { + "none": [NPE, NRE, NLE, FMPE, NPSE], + "npe": [NPE], + "nle": [NLE], + "nre": [NRE_A, NRE_B, NRE_C, BNRE], + "fmpe": [FMPE], + "npse": [NPSE], + "snpe": [NPE_C], # NPE_B not implemented, NPE_A need Gaussian prior + "snle": [NLE], + "snre": [NRE_A, NRE_B, NRE_C, BNRE], +} +METHOD_PARAMS = { + "none": [{}], + "npe": [{"density_estimator": de} for de in DENSITY_ESTIMATORS], + "nle": [{"density_estimator": de} for de in ["maf", "nsf"]], + "nre": [{"classifier": cl} for cl in CLASSIFIERS], + "fmpe": [{"density_estimator": nn} for nn in NNS], + "npse": [ + {"score_estimator": nn, "sde_type": sde} + for nn in SCORE_ESTIMATORS + for sde in ["ve", "vp"] + ], + "snpe": [{}], + "snle": [{}], + "snre": [{}], +} + + +@pytest.fixture +def method_list(benchmark_mode: str) -> list: + """ + Fixture to get the list of methods based on the benchmark mode. + + Args: + benchmark_mode (str): The benchmark mode. + + Returns: + list: List of methods for the given benchmark mode. + """ + name = str(benchmark_mode).lower() + if name not in METHOD_GROUPS: + raise ValueError(f"Benchmark mode '{benchmark_mode}' is not supported.") + return METHOD_GROUPS[name] + + +@pytest.fixture +def kwargs_list(benchmark_mode: str) -> list: + """ + Fixture to get the list of kwargs based on the benchmark mode. + + Args: + benchmark_mode (str): The benchmark mode. + + Returns: + list: List of kwargs for the given benchmark mode. + """ + name = str(benchmark_mode).lower() + if name not in METHOD_PARAMS: + raise ValueError(f"Benchmark mode '{benchmark_mode}' is not supported.") + return METHOD_PARAMS[name] + + +# Use pytest.mark.parametrize dynamically +# Generates a list of methods to test based on the benchmark mode +def pytest_generate_tests(metafunc): + """ + Dynamically generates a list of methods to test based on the benchmark mode. + + Args: + metafunc: The metafunc object from pytest. + """ + if "inference_class" in metafunc.fixturenames: + method_list = metafunc.config.getoption("--bm-mode") + name = str(method_list).lower() + method_group = METHOD_GROUPS.get(name, []) + metafunc.parametrize("inference_class", method_group) + if "extra_kwargs" in metafunc.fixturenames: + kwargs_list = metafunc.config.getoption("--bm-mode") + name = str(kwargs_list).lower() + kwargs_group = METHOD_PARAMS.get(name, []) + metafunc.parametrize("extra_kwargs", kwargs_group) + + +def standard_eval_c2st_loop(posterior: NeuralPosterior, task: Task) -> float: + """ + Evaluates the C2ST metric for the given posterior and task. + + Args: + posterior: The posterior distribution. + task: The task object. + + Returns: + float: The mean C2ST value. + """ + c2st_scores = [] + for i in range(1, NUM_EVALUATION_OBS + 1): + c2st_val = eval_c2st(posterior, task, i) + c2st_scores.append(c2st_val) + + mean_c2st = sum(c2st_scores) / len(c2st_scores) + # Convert to float rounded to 3 decimal places + mean_c2st = float(f"{mean_c2st:.3f}") + return mean_c2st + + +def eval_c2st( + posterior: NeuralPosterior, + task: Task, + idx_observation: int, + num_samples: int = 1000, +) -> float: + """ + Evaluates the C2ST metric for a specific observation. + + Args: + posterior: The posterior distribution. + task: The task object. + i (int): The observation index. + + Returns: + float: The C2ST value. + """ + x_o = task.get_observation(idx_observation) + posterior_samples = task.get_reference_posterior_samples(idx_observation) + approx_posterior_samples = posterior.sample((num_samples,), x=x_o) + if isinstance(approx_posterior_samples, tuple): + approx_posterior_samples = approx_posterior_samples[0] + assert posterior_samples.shape[0] >= num_samples, "Not enough reference samples" + c2st_val = c2st(posterior_samples[:num_samples], approx_posterior_samples) + return float(c2st_val) + + +def train_and_eval_amortized_inference( + inference_class, task_name: str, extra_kwargs: dict, results_bag: ResultsBag +) -> None: + """ + Performs amortized inference evaluation. + + Args: + method: The inference method. + task_name: The name of the task. + extra_kwargs: Additional keyword arguments for the method. + results_bag: The results bag to store evaluation results. Subclass of dict, but + allows item assignment with dot notation. + """ + torch.manual_seed(SEED) + task = get_task(task_name) + thetas, xs = task.get_data(NUM_SIMULATIONS) + prior = task.get_prior() + + inference = inference_class(prior, **extra_kwargs) + _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS) + + posterior = inference.build_posterior() + + mean_c2st = standard_eval_c2st_loop(posterior, task) + + # Cache results + results_bag.metric = mean_c2st + results_bag.num_simulations = NUM_SIMULATIONS + results_bag.task_name = task_name + results_bag.method = inference_class.__name__ + str(extra_kwargs) + + +def train_and_eval_sequential_inference( + inference_class, task_name: str, extra_kwargs: dict, results_bag: ResultsBag +) -> None: + """ + Performs sequential inference evaluation. + + Args: + method: The inference method. + task_name (str): The name of the task. + extra_kwargs (dict): Additional keyword arguments for the method. + results_bag: The results bag to store evaluation results. + """ + torch.manual_seed(SEED) + task = get_task(task_name) + num_simulations = NUM_SIMULATIONS // NUM_ROUNDS_SEQUENTIAL + thetas, xs = task.get_data(num_simulations) + prior = task.get_prior() + idx_eval = NUM_EVALUATION_OBS_SEQ + x_o = task.get_observation(idx_eval) + simulator = task.get_simulator() + + # Round 1 + inference = inference_class(prior, **extra_kwargs) + _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS) + + for _ in range(NUM_ROUNDS_SEQUENTIAL - 1): + proposal = inference.build_posterior().set_default_x(x_o) + thetas_i = proposal.sample((num_simulations,)) + xs_i = simulator(thetas_i) + if "npe" in inference_class.__name__.lower(): + # NPE_C requires a Gaussian prior + _ = inference.append_simulations(thetas_i, xs_i, proposal=proposal).train( + **TRAIN_KWARGS + ) + else: + inference.append_simulations(thetas_i, xs_i).train(**TRAIN_KWARGS) + + posterior = inference.build_posterior() + + c2st_val = eval_c2st(posterior, task, idx_eval) + + # Cache results + results_bag.metric = c2st_val + results_bag.num_simulations = NUM_SIMULATIONS + results_bag.task_name = task_name + results_bag.method = inference_class.__name__ + str(extra_kwargs) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("task_name", TASKS, ids=str) +def test_run_benchmark( + inference_class, + task_name: str, + results_bag, + extra_kwargs: dict, + benchmark_mode: str, +) -> None: + """ + Benchmark test for amortized and sequential inference methods. + + Args: + inference_class: The inference class to test i.e. NPE, NLE, NRE ... + task_name: The name of the task. + results_bag: The results bag to store evaluation results. + extra_kwargs: Additional keyword arguments for the method. + benchmark_mode: The benchmark mode. This is a fixture which based on user + input, determines which type of methods should be run. + """ + if benchmark_mode in ["snpe", "snle", "snre"]: + train_and_eval_sequential_inference( + inference_class, task_name, extra_kwargs, results_bag + ) + else: + train_and_eval_amortized_inference( + inference_class, task_name, extra_kwargs, results_bag + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9cedd97e9..04ac14069 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,11 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +import pickle +import re +import shutil +from logging import warning +from pathlib import Path +from shutil import rmtree import pytest import torch @@ -8,6 +14,7 @@ # Seed for `set_seed` fixture. Change to random state of all seeded tests. seed = 1 +harvested_fixture_data = None # Use seed automatically for every test function. @@ -29,10 +36,171 @@ def pytest_collection_modifyitems(config, items): ) if not gpu_device_available: skip_gpu = pytest.mark.skip(reason="No devices available") + for item in items: if "gpu" in item.keywords: item.add_marker(skip_gpu) + if not config.getoption("--bm"): + # Skip marked benchmarking tests + skip_bm = pytest.mark.skip(reason="Benchmarking disabled") + for item in items: + if "benchmark" in item.keywords: + item.add_marker(skip_bm) + else: + # Filter tests to only those with the 'benchmark' marker + filtered_items = [] + for item in items: + if item.get_closest_marker("benchmark"): + filtered_items.append(item) + + items[:] = filtered_items # Inplace! + + +# Run mini-benchmark tests with `pytest --print-harvest` +def pytest_addoption(parser): + parser.addoption( + "--bm", + action="store_true", + default=False, + help="Run mini-benchmark tests with specified mode", + ) + parser.addoption( + "--bm-mode", + action="store", + default=None, + help="Run mini-benchmark tests with specified mode", + ) + + +@pytest.fixture +def benchmark_mode(request): + """Fixture to access the --bm value in test files.""" + return request.config.getoption("--bm-mode") + + +@pytest.fixture(scope="session", autouse=True) +def finalize_fixture_store(request, fixture_store): + # The code before `yield` runs at the start of the session (before tests). + yield + # The code after `yield` runs after all tests have completed. + # At this point, fixture_store should have all the harvested data. + global harvested_fixture_data + harvested_fixture_data = dict(fixture_store) + + +def strip_ansi_escape_codes(text): + ansi_escape = re.compile(r'\x1b\[.*?m') + return ansi_escape.sub('', text) + + +# Function to center text with ANSI colors, adjusting for escape codes +def center_colored_text(text, width): + visible_length = len(strip_ansi_escape_codes(text)) + padding = max(0, (width - visible_length) // 2) + return " " * padding + text + " " * (width - visible_length - padding) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """ + Custom pytest terminal summary to display mini SBIBM results with relative coloring + per task. + + This function is called after the test session ends and generates a summary + of the results if the `--bm` option is specified. It displays the results + in a formatted table with methods as rows and tasks as columns, applying + relative coloring to metrics based on their performance within each task. + """ + if config.getoption("--bm"): + terminal_width = shutil.get_terminal_size().columns + summary_text = " mini SBIBM results " + centered_line = summary_text.center(terminal_width, '=') + colored_line = f"\033[96m{centered_line}\033[0m" + terminalreporter.write_line(colored_line) + + if harvested_fixture_data is not None: + terminalreporter.write_line("Amortized inference:") + + results = harvested_fixture_data["results_bag"] + + # Extract relevant data (method, task, metric) + methods = set() + tasks = set() + data = {} # (method, task) -> metric + + for _, info in results.items(): + method = info.get('method') + task = info.get('task_name') + metric = info.get('metric') + + if method is not None and task is not None: + methods.add(method) + tasks.add(task) + data[(method, task)] = metric + + methods = sorted(methods) + tasks = sorted(tasks) + + if not methods or not tasks: + terminalreporter.write_line("No methods or tasks found.") + return + + # Determine column widths + method_col_width = max(len(m) for m in methods) + task_col_widths = {t: max(len(t), 10) for t in tasks} + + # Print the header row + header = " " * (method_col_width + 2) + for t in tasks: + header += t.center(task_col_widths[t] + 2) + terminalreporter.write_line(header) + + # Print separator line + sep_line = "-" * len(header) + terminalreporter.write_line(sep_line) + + # Calculate min and max for each task + min_max_per_task = {} + for t in tasks: + task_metrics = [data.get((m, t), float('inf')) for m in methods] + min_max_per_task[t] = (min(task_metrics), max(task_metrics)) + + # Print each row with colored values + for m in methods: + row = m.ljust(method_col_width + 2) + for t in tasks: + val = data.get((m, t), "N/A") + if val == "N/A": + val_str = "N/A" + row += val_str.center(task_col_widths[t] + 2) + else: + val = float(val) + min_val, max_val = min_max_per_task[t] + normalized_val = ( + (val - min_val) / (max_val - min_val) + if max_val > min_val + else 0.5 + ) + + # Determine color based on normalized value + if normalized_val == 0.0: + color = "\033[92m" # Green for best + elif normalized_val == 1.0: + color = "\033[91m" # Red for worst + else: + color = f"\033[9{int(2 + normalized_val * 3)}m" + + val_str = format(val, ".3f") + colored_val_str = f"{color}{val_str}\033[0m" + + row += center_colored_text( + colored_val_str, task_col_widths[t] + 2 + ) + + terminalreporter.write_line(row) + else: + terminalreporter.write_line("No harvested fixture data found yet.") + @pytest.fixture(scope="function") def mcmc_params_accurate() -> dict: @@ -44,3 +212,48 @@ def mcmc_params_accurate() -> dict: def mcmc_params_fast() -> dict: """Fixture for MCMC parameters for fast tests.""" return dict(num_chains=1, thin=1, warmup_steps=1) + + +# Pytest harvest xdist support. + + +# Define the folder in which temporary worker's results will be stored +RESULTS_PATH = Path('./.xdist_results/') +RESULTS_PATH.mkdir(exist_ok=True) + + +def pytest_harvest_xdist_init(): + # reset the recipient folder + if RESULTS_PATH.exists(): + rmtree(RESULTS_PATH) + RESULTS_PATH.mkdir(exist_ok=False) + return True + + +def pytest_harvest_xdist_worker_dump(worker_id, session_items, fixture_store): + # persist session_items and fixture_store in the file system + with open(RESULTS_PATH / ('%s.pkl' % worker_id), 'wb') as f: + try: + pickle.dump((session_items, fixture_store), f) + except Exception as e: + warning( + "Error while pickling worker %s's harvested results: [%s] %s", + (worker_id, e.__class__, e), + ) + return True + + +def pytest_harvest_xdist_load(): + # restore the saved objects from file system + workers_saved_material = dict() + for pkl_file in RESULTS_PATH.glob('*.pkl'): + wid = pkl_file.stem + with pkl_file.open('rb') as f: + workers_saved_material[wid] = pickle.load(f) + return workers_saved_material + + +def pytest_harvest_xdist_cleanup(): + # delete all temporary pickle files + rmtree(RESULTS_PATH) + return True diff --git a/tests/mini_sbibm/__init__.py b/tests/mini_sbibm/__init__.py new file mode 100644 index 000000000..92b408c75 --- /dev/null +++ b/tests/mini_sbibm/__init__.py @@ -0,0 +1,37 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see +# NOTE: This is inspired by the sbibm-package + +from .base_task import Task +from .gaussian_linear import GaussianLinear +from .linear_mvg import LinearMVG2d +from .slcp import Slcp +from .two_moons import TwoMoons + +TASKS = { + "two_moons": TwoMoons, + "linear_mvg_2d": LinearMVG2d, + "gaussian_linear": GaussianLinear, + "slcp": Slcp, +} + + +def get_task(name: str, *args, **kwargs) -> Task: + """ + Retrieve a task instance based on the given name. + + Args: + name (str): The name of the task to retrieve. + Possible values are "two_moons", "linear_mvg_2d", + "gaussian_linear", and "slcp". + + Returns: + object: An instance of the corresponding task class. + + Raises: + ValueError: If the provided task name is unknown. + """ + try: + return TASKS[name](*args, **kwargs) + except KeyError as err: + raise ValueError(f"Unknown task {name}") from err diff --git a/tests/mini_sbibm/base_task.py b/tests/mini_sbibm/base_task.py new file mode 100644 index 000000000..3a18a6633 --- /dev/null +++ b/tests/mini_sbibm/base_task.py @@ -0,0 +1,123 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see +# NOTE: This is inspired by the sbibm-package + +import os +from abc import ABC, abstractmethod +from typing import Callable + +import torch +from torch.distributions import Distribution + +PATH = os.path.dirname(__file__) + + +class Task(ABC): + """ + Abstract base class for a task in the SBI benchmark. + + Args: + name (str): The name of the task. + """ + + def __init__(self, name: str): + self.name = name + + @abstractmethod + def theta_dim(self) -> int: + """ + Returns the dimensionality of the parameter space. + + Returns: + int: The dimensionality of the parameter space. + """ + pass + + @abstractmethod + def x_dim(self) -> int: + """ + Returns the dimensionality of the observation space. + + Returns: + int: The dimensionality of the observation space. + """ + pass + + @abstractmethod + def get_prior(self) -> Distribution: + """ + Returns the prior distribution over parameters. + + Returns: + Distribution: The prior distribution. + """ + pass + + @abstractmethod + def get_simulator(self) -> Callable: + """ + Returns the simulator function. + + Returns: + Callable: The simulator function. + """ + pass + + def get_data(self, num_sims: int): + """ + Generates data by sampling from the prior and simulating observations. + + Args: + num_sims (int): The number of simulations to run. + + Returns: + tuple: A tuple containing the sampled parameters and simulated observations. + """ + thetas = self.get_prior().sample((num_sims,)) + xs = self.get_simulator()(thetas) + return thetas, xs + + def get_observation(self, idx: int) -> torch.Tensor: + """ + Loads a specific observation from file. + + Args: + idx (int): The index of the observation to load. + + Returns: + torch.Tensor: The loaded observation. + """ + x_o = torch.load( + PATH + os.sep + "files" + os.sep + f"{self.name}{os.sep}x_o_{idx}.pt" + ) + return x_o + + def get_true_parameters(self, idx: int) -> torch.Tensor: + """ + Loads the true parameters for a specific observation from file. + + Args: + idx (int): The index of the parameters to load. + + Returns: + torch.Tensor: The loaded true parameters. + """ + theta = torch.load( + PATH + os.sep + "files" + os.sep + f"{self.name}{os.sep}theta_{idx}.pt" + ) + return theta + + def get_reference_posterior_samples(self, idx: int) -> torch.Tensor: + """ + Loads reference posterior samples for a specific observation from file. + + Args: + idx (int): The index of the posterior samples to load. + + Returns: + torch.Tensor: The loaded posterior samples. + """ + posterior_samples = torch.load( + PATH + os.sep + "files" + os.sep + f"{self.name}{os.sep}samples_{idx}.pt" + ) + return posterior_samples diff --git a/tests/mini_sbibm/files/slcp/samples_1.pt b/tests/mini_sbibm/files/slcp/samples_1.pt new file mode 100644 index 000000000..d0a1bd771 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_1.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_10.pt b/tests/mini_sbibm/files/slcp/samples_10.pt new file mode 100644 index 000000000..0de7efad6 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_10.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_2.pt b/tests/mini_sbibm/files/slcp/samples_2.pt new file mode 100644 index 000000000..f642597d3 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_2.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_3.pt b/tests/mini_sbibm/files/slcp/samples_3.pt new file mode 100644 index 000000000..640bcd1dc Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_3.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_4.pt b/tests/mini_sbibm/files/slcp/samples_4.pt new file mode 100644 index 000000000..1397dc02b Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_4.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_5.pt b/tests/mini_sbibm/files/slcp/samples_5.pt new file mode 100644 index 000000000..f2e8c35f1 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_5.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_6.pt b/tests/mini_sbibm/files/slcp/samples_6.pt new file mode 100644 index 000000000..091bb8143 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_6.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_7.pt b/tests/mini_sbibm/files/slcp/samples_7.pt new file mode 100644 index 000000000..edcbe6596 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_7.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_8.pt b/tests/mini_sbibm/files/slcp/samples_8.pt new file mode 100644 index 000000000..ec1cd0392 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_8.pt differ diff --git a/tests/mini_sbibm/files/slcp/samples_9.pt b/tests/mini_sbibm/files/slcp/samples_9.pt new file mode 100644 index 000000000..38b14665e Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_9.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_1.pt b/tests/mini_sbibm/files/slcp/theta_o_1.pt new file mode 100644 index 000000000..ab4dc7bae Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_1.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_10.pt b/tests/mini_sbibm/files/slcp/theta_o_10.pt new file mode 100644 index 000000000..4056ae9de Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_10.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_2.pt b/tests/mini_sbibm/files/slcp/theta_o_2.pt new file mode 100644 index 000000000..529388c43 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_2.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_3.pt b/tests/mini_sbibm/files/slcp/theta_o_3.pt new file mode 100644 index 000000000..97e333ced Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_3.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_4.pt b/tests/mini_sbibm/files/slcp/theta_o_4.pt new file mode 100644 index 000000000..27f22b885 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_4.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_5.pt b/tests/mini_sbibm/files/slcp/theta_o_5.pt new file mode 100644 index 000000000..64c3c77d4 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_5.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_6.pt b/tests/mini_sbibm/files/slcp/theta_o_6.pt new file mode 100644 index 000000000..607ced052 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_6.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_7.pt b/tests/mini_sbibm/files/slcp/theta_o_7.pt new file mode 100644 index 000000000..7da5b387b Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_7.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_8.pt b/tests/mini_sbibm/files/slcp/theta_o_8.pt new file mode 100644 index 000000000..d4ccf87d0 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_8.pt differ diff --git a/tests/mini_sbibm/files/slcp/theta_o_9.pt b/tests/mini_sbibm/files/slcp/theta_o_9.pt new file mode 100644 index 000000000..38b6ecd7b Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_9.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_1.pt b/tests/mini_sbibm/files/slcp/x_o_1.pt new file mode 100644 index 000000000..f806b3ff8 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_1.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_10.pt b/tests/mini_sbibm/files/slcp/x_o_10.pt new file mode 100644 index 000000000..ddfe6c805 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_10.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_2.pt b/tests/mini_sbibm/files/slcp/x_o_2.pt new file mode 100644 index 000000000..7c7f92274 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_2.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_3.pt b/tests/mini_sbibm/files/slcp/x_o_3.pt new file mode 100644 index 000000000..de8576ce5 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_3.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_4.pt b/tests/mini_sbibm/files/slcp/x_o_4.pt new file mode 100644 index 000000000..a60abaf24 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_4.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_5.pt b/tests/mini_sbibm/files/slcp/x_o_5.pt new file mode 100644 index 000000000..a8f169e89 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_5.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_6.pt b/tests/mini_sbibm/files/slcp/x_o_6.pt new file mode 100644 index 000000000..56d43231f Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_6.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_7.pt b/tests/mini_sbibm/files/slcp/x_o_7.pt new file mode 100644 index 000000000..d88572ab9 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_7.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_8.pt b/tests/mini_sbibm/files/slcp/x_o_8.pt new file mode 100644 index 000000000..35e6f5d21 Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_8.pt differ diff --git a/tests/mini_sbibm/files/slcp/x_o_9.pt b/tests/mini_sbibm/files/slcp/x_o_9.pt new file mode 100644 index 000000000..89211253f Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_9.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_1.pt b/tests/mini_sbibm/files/two_moons/samples_1.pt new file mode 100644 index 000000000..4758ad613 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_1.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_10.pt b/tests/mini_sbibm/files/two_moons/samples_10.pt new file mode 100644 index 000000000..22c9fb59a Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_10.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_2.pt b/tests/mini_sbibm/files/two_moons/samples_2.pt new file mode 100644 index 000000000..79c27d20f Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_2.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_3.pt b/tests/mini_sbibm/files/two_moons/samples_3.pt new file mode 100644 index 000000000..41300b513 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_3.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_4.pt b/tests/mini_sbibm/files/two_moons/samples_4.pt new file mode 100644 index 000000000..99164ddf8 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_4.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_5.pt b/tests/mini_sbibm/files/two_moons/samples_5.pt new file mode 100644 index 000000000..6c2d0a4bc Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_5.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_6.pt b/tests/mini_sbibm/files/two_moons/samples_6.pt new file mode 100644 index 000000000..7269844cb Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_6.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_7.pt b/tests/mini_sbibm/files/two_moons/samples_7.pt new file mode 100644 index 000000000..a75fddabb Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_7.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_8.pt b/tests/mini_sbibm/files/two_moons/samples_8.pt new file mode 100644 index 000000000..6f9f30df3 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_8.pt differ diff --git a/tests/mini_sbibm/files/two_moons/samples_9.pt b/tests/mini_sbibm/files/two_moons/samples_9.pt new file mode 100644 index 000000000..a8fe22661 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/samples_9.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_1.pt b/tests/mini_sbibm/files/two_moons/theta_o_1.pt new file mode 100644 index 000000000..5aa284381 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_1.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_10.pt b/tests/mini_sbibm/files/two_moons/theta_o_10.pt new file mode 100644 index 000000000..4d3c22045 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_10.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_2.pt b/tests/mini_sbibm/files/two_moons/theta_o_2.pt new file mode 100644 index 000000000..33a83c3c4 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_2.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_3.pt b/tests/mini_sbibm/files/two_moons/theta_o_3.pt new file mode 100644 index 000000000..fb783e50b Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_3.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_4.pt b/tests/mini_sbibm/files/two_moons/theta_o_4.pt new file mode 100644 index 000000000..3bc059efc Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_4.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_5.pt b/tests/mini_sbibm/files/two_moons/theta_o_5.pt new file mode 100644 index 000000000..c3aceb4dc Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_5.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_6.pt b/tests/mini_sbibm/files/two_moons/theta_o_6.pt new file mode 100644 index 000000000..fc93bd5ab Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_6.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_7.pt b/tests/mini_sbibm/files/two_moons/theta_o_7.pt new file mode 100644 index 000000000..16eb771e2 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_7.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_8.pt b/tests/mini_sbibm/files/two_moons/theta_o_8.pt new file mode 100644 index 000000000..a91ed39f6 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_8.pt differ diff --git a/tests/mini_sbibm/files/two_moons/theta_o_9.pt b/tests/mini_sbibm/files/two_moons/theta_o_9.pt new file mode 100644 index 000000000..9ce7a7821 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/theta_o_9.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_1.pt b/tests/mini_sbibm/files/two_moons/x_o_1.pt new file mode 100644 index 000000000..f9e20060e Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_1.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_10.pt b/tests/mini_sbibm/files/two_moons/x_o_10.pt new file mode 100644 index 000000000..fef3e1310 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_10.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_2.pt b/tests/mini_sbibm/files/two_moons/x_o_2.pt new file mode 100644 index 000000000..f55bf4cba Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_2.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_3.pt b/tests/mini_sbibm/files/two_moons/x_o_3.pt new file mode 100644 index 000000000..1b9b4e572 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_3.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_4.pt b/tests/mini_sbibm/files/two_moons/x_o_4.pt new file mode 100644 index 000000000..e0abd7c64 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_4.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_5.pt b/tests/mini_sbibm/files/two_moons/x_o_5.pt new file mode 100644 index 000000000..015495b4f Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_5.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_6.pt b/tests/mini_sbibm/files/two_moons/x_o_6.pt new file mode 100644 index 000000000..450b8e26f Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_6.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_7.pt b/tests/mini_sbibm/files/two_moons/x_o_7.pt new file mode 100644 index 000000000..10c547b71 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_7.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_8.pt b/tests/mini_sbibm/files/two_moons/x_o_8.pt new file mode 100644 index 000000000..b752885b7 Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_8.pt differ diff --git a/tests/mini_sbibm/files/two_moons/x_o_9.pt b/tests/mini_sbibm/files/two_moons/x_o_9.pt new file mode 100644 index 000000000..f9c9ad79f Binary files /dev/null and b/tests/mini_sbibm/files/two_moons/x_o_9.pt differ diff --git a/tests/mini_sbibm/gaussian_linear.py b/tests/mini_sbibm/gaussian_linear.py new file mode 100644 index 000000000..13b5f7dd4 --- /dev/null +++ b/tests/mini_sbibm/gaussian_linear.py @@ -0,0 +1,119 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see +# NOTE: This is inspired by the sbibm-package + +from functools import partial +from typing import Callable + +import torch +from torch.distributions import Distribution, MultivariateNormal + +from sbi.simulators.linear_gaussian import ( + diagonal_linear_gaussian, + true_posterior_linear_gaussian_mvn_prior, +) + +from .base_task import Task + + +class GaussianLinear(Task): + """ + Task for the Gaussian Linear model. + + This task uses a linear Gaussian model with a multivariate normal prior. + """ + + def __init__(self): + """ + Initializes the GaussianLinear task. + """ + self.simulator_scale = 0.1 + self.dim = 5 + super().__init__("gaussian_linear") + + def theta_dim(self) -> int: + """ + Returns the dimensionality of the parameter space. + + Returns: + int: The dimensionality of the parameter space. + """ + return self.dim + + def x_dim(self) -> int: + """ + Returns the dimensionality of the observation space. + + Returns: + int: The dimensionality of the observation space. + """ + return self.dim + + def get_reference_posterior_samples(self, idx: int) -> torch.Tensor: + """ + Generates reference posterior samples for a specific observation. + + Args: + idx (int): The index of the observation. + + Returns: + torch.Tensor: The reference posterior samples. + """ + x_o = self.get_observation(idx) + posterior = true_posterior_linear_gaussian_mvn_prior( + x_o, + torch.zeros(self.dim), + self.simulator_scale * torch.eye(self.dim), + torch.zeros(self.dim), + torch.eye(self.dim), + ) + + return posterior.sample((10_000,)) + + def get_true_parameters(self, idx: int) -> torch.Tensor: + """ + Generates the true parameters for a specific observation. + + Args: + idx (int): The index of the observation. + + Returns: + torch.Tensor: The true parameters. + """ + torch.manual_seed(idx) + return self.get_prior().sample() + + def get_observation(self, idx: int) -> torch.Tensor: + """ + Generates an observation for a specific set of true parameters. + + Args: + idx (int): The index of the observation. + + Returns: + torch.Tensor: The observation. + """ + theta_o = self.get_true_parameters(idx) + x_o = self.get_simulator()(theta_o[None, :])[0] + return x_o + + def get_prior(self) -> Distribution: + """ + Returns the prior distribution over parameters. + + Returns: + Distribution: The prior distribution. + """ + return MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim)) + + def get_simulator(self) -> Callable: + """ + Returns the simulator function. + + Returns: + Callable: The simulator function. + """ + return partial( + diagonal_linear_gaussian, + std=self.simulator_scale, + ) diff --git a/tests/mini_sbibm/linear_mvg.py b/tests/mini_sbibm/linear_mvg.py new file mode 100644 index 000000000..cb2f26153 --- /dev/null +++ b/tests/mini_sbibm/linear_mvg.py @@ -0,0 +1,120 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see +# NOTE: This is inspired by the sbibm-package + +from functools import partial +from typing import Callable + +import torch +from torch.distributions import Distribution, MultivariateNormal + +from sbi.simulators.linear_gaussian import ( + linear_gaussian, + true_posterior_linear_gaussian_mvn_prior, +) + +from .base_task import Task + + +class LinearMVG2d(Task): + """ + Task for the Linear Multivariate Gaussian (MVG) model in 2D. + + This task uses a linear Gaussian model with a multivariate normal prior. + """ + + def __init__(self): + """ + Initializes the LinearMVG2d task. + """ + self.likelihood_shift = torch.tensor([-1.0, 1.0]) + self.likelihood_cov = torch.tensor([[0.6, 0.5], [0.5, 0.6]]) + super().__init__("linear_mvg_2d") + + def theta_dim(self) -> int: + """ + Returns the dimensionality of the parameter space. + + Returns: + int: The dimensionality of the parameter space. + """ + return 2 + + def x_dim(self) -> int: + """ + Returns the dimensionality of the observation space. + + Returns: + int: The dimensionality of the observation space. + """ + return 2 + + def get_reference_posterior_samples(self, idx: int) -> torch.Tensor: + """ + Generates reference posterior samples for a specific observation. + + Args: + idx (int): The index of the observation. + + Returns: + torch.Tensor: The reference posterior samples. + """ + x_o = self.get_observation(idx) + posterior = true_posterior_linear_gaussian_mvn_prior( + x_o, + self.likelihood_shift, + self.likelihood_cov, + torch.zeros(2), + torch.eye(2), + ) + + return posterior.sample((10_000,)) + + def get_true_parameters(self, idx: int) -> torch.Tensor: + """ + Generates the true parameters for a specific observation. + + Args: + idx (int): The index of the observation. + + Returns: + torch.Tensor: The true parameters. + """ + torch.manual_seed(idx) + return self.get_prior().sample() + + def get_observation(self, idx: int) -> torch.Tensor: + """ + Generates an observation for a specific set of true parameters. + + Args: + idx (int): The index of the observation. + + Returns: + torch.Tensor: The observation. + """ + theta_o = self.get_true_parameters(idx) + x_o = self.get_simulator()(theta_o[None, :])[0] + return x_o + + def get_prior(self) -> Distribution: + """ + Returns the prior distribution over parameters. + + Returns: + Distribution: The prior distribution. + """ + return MultivariateNormal(torch.zeros(2), torch.eye(2)) + + def get_simulator(self) -> Callable: + """ + Returns the simulator function. + + Returns: + Callable: The simulator function. + """ + return partial( + linear_gaussian, + likelihood_shift=self.likelihood_shift, + likelihood_cov=self.likelihood_cov, + ) diff --git a/tests/mini_sbibm/slcp.py b/tests/mini_sbibm/slcp.py new file mode 100644 index 000000000..aea4d8314 --- /dev/null +++ b/tests/mini_sbibm/slcp.py @@ -0,0 +1,56 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see +# NOTE: This is inspired by the sbibm-package + +from typing import Callable + +import torch +from torch.distributions import Distribution, Independent, MultivariateNormal, Uniform + +from .base_task import Task + + +def simulator(theta, num_data=4): + num_samples = theta.shape[0] + + m = torch.stack((theta[:, [0]].squeeze(), theta[:, [1]].squeeze())).T + if m.dim() == 1: + m.unsqueeze_(0) + + s1 = theta[:, [2]].squeeze() ** 2 + s2 = theta[:, [3]].squeeze() ** 2 + rho = torch.nn.Tanh()(theta[:, [4]]).squeeze() + + S = torch.empty((num_samples, 2, 2)) + S[:, 0, 0] = s1**2 + S[:, 0, 1] = rho * s1 * s2 + S[:, 1, 0] = rho * s1 * s2 + S[:, 1, 1] = s2**2 + + # Add eps to diagonal to ensure PSD + eps = 0.000001 + S[:, 0, 0] += eps + S[:, 1, 1] += eps + + data_dist = MultivariateNormal(m, S) + xs = data_dist.sample((num_data,)) + xs = xs.permute(1, 0, 2) + + return xs.reshape(num_samples, num_data * 2) + + +class Slcp(Task): + def __init__(self): + super().__init__("slcp") + + def theta_dim(self) -> int: + return 5 + + def x_dim(self) -> int: + return 8 + + def get_prior(self) -> Distribution: + return Independent(Uniform(-3 * torch.ones(5), 3 * torch.ones(5)), 1) + + def get_simulator(self) -> Callable: + return simulator diff --git a/tests/mini_sbibm/two_moons.py b/tests/mini_sbibm/two_moons.py new file mode 100644 index 000000000..1316cc128 --- /dev/null +++ b/tests/mini_sbibm/two_moons.py @@ -0,0 +1,127 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Apache License Version 2.0, see +# NOTE: This is inspired by the sbibm-package + +import math +from typing import Callable + +import torch +from torch.distributions import Distribution, Independent, Normal, Uniform + +from .base_task import Task + + +def _map_fun(parameters: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + """ + Maps the parameters and points to a new space using a rotation. + + Args: + parameters (torch.Tensor): The parameters for the mapping. + p (torch.Tensor): The points to be mapped. + + Returns: + torch.Tensor: The mapped points. + """ + ang = torch.tensor([-math.pi / 4.0]) + c = torch.cos(ang) + s = torch.sin(ang) + z0 = (c * parameters[:, 0] - s * parameters[:, 1]).reshape(-1, 1) + z1 = (s * parameters[:, 0] + c * parameters[:, 1]).reshape(-1, 1) + return p + torch.cat((-torch.abs(z0), z1), dim=1) + + +def simulator( + parameters: torch.Tensor, + r_loc: float = 0.1, + r_scale: float = 0.01, + a_low: float = -math.pi / 2.0, + a_high: float = math.pi / 2.0, + base_offset: float = 0.25, +) -> torch.Tensor: + """ + Simulator function for the Two Moons task. + + Args: + parameters (torch.Tensor): The parameters for the simulator. + r_loc (float, optional): The mean of the radius distribution. Defaults to 0.1. + r_scale (float, optional): The standard deviation of the radius distribution. + Defaults to 0.01. + a_low (float, optional): The lower bound of the angle distribution. Defaults to + -math.pi / 2.0. + a_high (float, optional): The upper bound of the angle distribution. Defaults to + math.pi / 2.0. + base_offset (float, optional): The base offset for the points. Defaults to 0.25. + + Returns: + torch.Tensor: The simulated data. + """ + num_samples = parameters.shape[0] + + a_dist = Uniform( + low=a_low, + high=a_high, + ) + a = a_dist.sample((num_samples, 1)) + + r_dist = Normal(r_loc, r_scale) + r = r_dist.sample((num_samples, 1)) + + p = torch.cat( + ( + torch.cos(a) * r + base_offset, + torch.sin(a) * r, + ), + dim=1, + ) + + return _map_fun(parameters, p) + + +class TwoMoons(Task): + """ + Task for the Two Moons model. + + This task uses a uniform prior and a custom simulator. + """ + + def __init__(self): + """ + Initializes the TwoMoons task. + """ + super().__init__("two_moons") + + def theta_dim(self) -> int: + """ + Returns the dimensionality of the parameter space. + + Returns: + int: The dimensionality of the parameter space. + """ + return 2 + + def x_dim(self) -> int: + """ + Returns the dimensionality of the observation space. + + Returns: + int: The dimensionality of the observation space. + """ + return 2 + + def get_prior(self) -> Distribution: + """ + Returns the prior distribution over parameters. + + Returns: + Distribution: The prior distribution. + """ + return Independent(Uniform(-torch.ones(2), torch.ones(2)), 1) + + def get_simulator(self) -> Callable: + """ + Returns the simulator function. + + Returns: + Callable: The simulator function. + """ + return simulator diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index cd275a7b0..3b330a91c 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -85,7 +85,7 @@ def matrix_simulator(theta): # Set default tensor locally to reach tensors in fixtures. -torch.set_default_tensor_type(torch.FloatTensor) +torch.set_default_dtype(torch.float32) @pytest.mark.parametrize(