FMPE not working even on simple tasks? #1374

Open
yangyang-pro opened this issue Jan 21, 2025 · 2 comments
Labels
question Further information is requested

Comments

@yangyang-pro

Hi,

Thanks for your work on this nice package!

I have been trying the FMPE implementation, sbi.inference.FMPE, on a few SBI tasks. With the default parameter settings, I found that its performance is terrible compared to the results reported in the paper (Flow Matching for Scalable Simulation-Based Inference).

For example, on the two-moons task from sbibm, FMPE with either the mlp or the resnet backend, trained on 10000 simulations, often reached a C2ST accuracy of about 0.9 or worse (0.5 would mean the posterior samples are indistinguishable from the reference samples).
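C2ST accuracy is the metric quoted throughout this thread: a classifier is trained to tell posterior samples from reference posterior samples, so 0.5 is the best achievable score and 1.0 the worst. Below is a minimal NumPy sketch of that idea, assuming a k-nearest-neighbour classifier in place of the MLP classifier that sbibm uses, so its numbers will not match `sbibm.metrics.c2st` exactly. `knn_c2st` is a hypothetical helper name, not part of sbi or sbibm.

```python
import numpy as np


def knn_c2st(x, y, n_folds=5, k=5, seed=1):
    """Simplified classifier two-sample test (C2ST).

    Assumption: a k-nearest-neighbour classifier stands in for sbibm's MLP
    classifier. Accuracy near 0.5 means x and y are hard to tell apart,
    accuracy near 1.0 means they are easy to separate.
    """
    rng = np.random.default_rng(seed)
    # z-score both sample sets with the statistics of the first set
    mean, std = x.mean(axis=0), x.std(axis=0)
    x, y = (x - mean) / std, (y - mean) / std

    # label x as class 0 and y as class 1, then shuffle
    data = np.concatenate([x, y])
    labels = np.concatenate([np.zeros(len(x)), np.ones(len(y))])
    order = rng.permutation(len(data))
    data, labels = data[order], labels[order]

    # k-fold cross-validation: mean held-out accuracy over the folds
    accuracies = []
    for test_idx in np.array_split(np.arange(len(data)), n_folds):
        train_mask = np.ones(len(data), dtype=bool)
        train_mask[test_idx] = False
        train_x, train_y = data[train_mask], labels[train_mask]
        test_x = data[test_idx]
        # squared Euclidean distance from each test point to each training point
        dist = ((test_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=-1)
        nearest = np.argsort(dist, axis=1)[:, :k]
        # majority vote among the k nearest training labels (k is odd, so no ties)
        pred = (train_y[nearest].mean(axis=1) > 0.5).astype(float)
        accuracies.append((pred == labels[test_idx]).mean())
    return float(np.mean(accuracies))


rng = np.random.default_rng(0)
a = rng.normal(size=(1000, 2))
b = rng.normal(size=(1000, 2))            # same distribution as a
c = rng.normal(loc=3.0, size=(1000, 2))   # shifted distribution
print(knn_c2st(a, b), knn_c2st(a, c))     # roughly 0.5 vs. close to 1.0
```

Read this way, the 0.9 reported above means a simple classifier can separate FMPE samples from the reference posterior about 90% of the time.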

These are the posterior samples of the trained FMPE model, conditioned on the first observation from sbibm:

[Image: pairplot of the FMPE posterior samples]

I am using the following code snippet. It contains some Hydra and wandb configuration, but otherwise it uses the original FMPE implementation in sbi without changing any parameters.

```python
import logging
import os

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from sbi.analysis import pairplot
from sbi.inference import FMPE

import sbibm
import wandb
from sbibm.metrics import c2st


@hydra.main(version_base=None, config_path="./configs", config_name="train_fmpe_sbibm")
def train(cfg: DictConfig):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Device: {device}")

    enable_wandb = cfg.wandb.enabled
    if enable_wandb:
        wandb.init(
            config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
            project=cfg.wandb.project,
            tags=cfg.wandb.tags,
            name=cfg.wandb.name,
            reinit=True,
        )

    # sbibm task: prior sampler, prior distribution, simulator,
    # observation 1 and the reference posterior samples for it
    task = sbibm.get_task(cfg.data.name)
    prior_gen = task.get_prior()
    prior_dist = task.get_prior_dist()
    simulator = task.get_simulator()
    observation = task.get_observation(num_observation=1)
    reference_samples = task.get_reference_posterior_samples(num_observation=1)

    # training and held-out test simulations drawn from the prior
    theta_train = prior_gen(num_samples=cfg.data.num_training_samples)
    x_train = simulator(theta_train)

    theta_test = prior_gen(num_samples=cfg.data.num_test_samples)
    x_test = simulator(theta_test)

    # FMPE with default hyperparameters; only the network backend is chosen
    model = FMPE(
        density_estimator="resnet",
        prior=prior_dist,
        device=device,
    )
    model.append_simulations(theta=theta_train.to(device), x=x_train.to(device)).train(
        force_first_round_loss=True
    )

    posterior = model.build_posterior()

    # theta has shape (sample_dim=1, batch_dim=num_test, theta_dim): one theta per x
    test_log_probs = posterior.log_prob_batched(
        theta_test.unsqueeze(0).to(device), x_test.to(device), norm_posterior=False
    )
    mean_test_log_prob = torch.mean(test_log_probs).item()

    # move samples to the CPU so they live on the same device as the reference samples
    samples = posterior.sample(
        sample_shape=(len(reference_samples),), x=observation.to(device)
    ).cpu()
    c2st_accuracy = c2st(samples, reference_samples).item()

    if enable_wandb:
        wandb.run.summary["mean test log_prob"] = mean_test_log_prob
        wandb.run.summary["c2st"] = c2st_accuracy

    logging.info(f"mean test log_prob: {mean_test_log_prob}")
    logging.info(f"c2st: {c2st_accuracy}")

    fig, _ = pairplot(samples)
    os.makedirs("./results/figures", exist_ok=True)
    fig.suptitle(f"c2st={c2st_accuracy}", fontsize=16)
    fig.savefig(
        "./results/figures/fmpe_sbibm_" + cfg.data.name + "_posterior_samples.png"
    )


if __name__ == "__main__":
    train()
```
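The script trains `FMPE` with its defaults. For readers unfamiliar with what is being trained, here is a minimal NumPy sketch of the conditional flow-matching regression target behind FMPE: points on a Gaussian path `theta_t = t * theta1 + (1 - (1 - sigma_min) * t) * theta0` between noise `theta0` and a prior draw `theta1`, with target velocity `theta1 - (1 - sigma_min) * theta0`. It is not sbi's implementation. Times are drawn uniformly as a simplifying assumption, and `flow_matching_targets`, `flow_matching_loss` and `zero_field` are hypothetical names.

```python
import numpy as np


def flow_matching_targets(theta1, t, sigma_min=1e-4, rng=None):
    """Sample points on the conditional Gaussian path and their target velocities.

    theta1: (batch, dim) parameters, e.g. prior draws (the t = 1 end of the path)
    t:      (batch,) times in [0, 1]
    Returns theta_t, the target velocity u_t = d theta_t / dt, and the noise theta0.
    """
    rng = np.random.default_rng() if rng is None else rng
    theta0 = rng.standard_normal(theta1.shape)  # base noise at t = 0
    t = t[:, None]
    theta_t = t * theta1 + (1.0 - (1.0 - sigma_min) * t) * theta0
    u_t = theta1 - (1.0 - sigma_min) * theta0
    return theta_t, u_t, theta0


def flow_matching_loss(vector_field, theta1, x, rng):
    """Monte Carlo estimate of the flow-matching loss for one batch.

    vector_field(t, theta_t, x) stands in for the network conditioned on the
    simulator output x; it must return an array shaped like theta_t.
    Assumption: times are drawn uniformly from [0, 1].
    """
    t = rng.uniform(size=len(theta1))
    theta_t, u_t, _ = flow_matching_targets(theta1, t, rng=rng)
    residual = vector_field(t, theta_t, x) - u_t
    return float(np.mean(np.sum(residual**2, axis=1)))


def zero_field(t, theta_t, x):
    """Trivial placeholder 'network' that always predicts zero velocity."""
    return np.zeros_like(theta_t)


rng = np.random.default_rng(0)
theta1 = rng.uniform(-1.0, 1.0, size=(256, 2))         # toy prior draws
x = theta1 + 0.1 * rng.standard_normal(theta1.shape)   # toy simulator output
print(flow_matching_loss(zero_field, theta1, x, rng))
```

A trained network minimises this loss; at sampling time the learned velocity field is integrated from t = 0 to t = 1, starting from Gaussian noise.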
yangyang-pro added the question label (Further information is requested) on Jan 21, 2025.
@janfb
Contributor

janfb commented Jan 23, 2025

Thanks for reporting this!

Indeed, in the mini-sbibm benchmark we are working on (see #1335), we observe similar performance on two-moons, albeit with fewer simulations.

We will look into the differences between the implementation here and the one in https://github.com/dingo-gw/flow-matching-posterior-estimation.

@yangyang-pro
Author

Thanks for your response, and good to know about your ongoing work!

For more information: I also ran the codebase in https://github.com/dingo-gw/flow-matching-posterior-estimation on two-moons. Its performance is also terrible (around 0.8 C2ST accuracy) with 10000 simulations, which is much worse than the results reported in their paper. With 100000 simulations, however, it performs very well (about 0.55 C2ST accuracy).
