Afterglow adds uncertainty estimation capabilities to your PyTorch models. It's designed to work with any PyTorch model, with a minimum of fuss, and uses SWAG (SWA-Gaussian) as its core uncertainty estimation method.
With afterglow, you can turn code that trains a point-estimating model into code that trains an uncertainty-estimating model with a single call:
```python
from afterglow import enable_swag

enable_swag(
    model,
    start_iteration=100 * len(train_dataloader),   # start tracking at epoch 100
    update_period_in_iters=len(train_dataloader),  # update posterior every epoch
    max_cols=20,
)
```
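To make the parameters above concrete, here is a minimal numpy sketch of the statistics SWAG tracks (following Maddox et al., 2019), independent of afterglow's actual internals: a running mean and second moment of the weights, plus a deviation matrix capped at `max_cols` columns. All names here are illustrative, not afterglow's API.

```python
import numpy as np

def swag_update(theta, state, max_cols=20):
    """Fold one flattened parameter snapshot into the running SWAG statistics."""
    n = state["n"]
    state["mean"] = (state["mean"] * n + theta) / (n + 1)
    state["sq_mean"] = (state["sq_mean"] * n + theta ** 2) / (n + 1)
    state["dev_cols"].append(theta - state["mean"])
    if len(state["dev_cols"]) > max_cols:
        state["dev_cols"].pop(0)  # keep only the newest max_cols deviation columns
    state["n"] = n + 1

def swag_sample(state, rng):
    """Draw one weight vector from the SWAG posterior approximation (needs >= 2 columns)."""
    mean = state["mean"]
    var = np.maximum(state["sq_mean"] - mean ** 2, 0.0)  # diagonal variance
    D = np.stack(state["dev_cols"], axis=1)              # d x K deviation matrix
    K = D.shape[1]
    z1 = rng.standard_normal(mean.shape)
    z2 = rng.standard_normal(K)
    return mean + np.sqrt(var / 2) * z1 + D @ z2 / np.sqrt(2 * (K - 1))
```

`start_iteration` controls when snapshots begin, `update_period_in_iters` how often `swag_update`-style folding happens, and `max_cols` the rank of the low-rank covariance term.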
After training your model as usual, you can obtain uncertainty estimates like so:
```python
mean, std = model.trajectory_tracker.predict_uncertainty(x, num_samples=30)
```
You can sample single instances of the model from the SWAG posterior:
```python
model.trajectory_tracker.sample_state()
sample_at_x = model(x)
```
You can efficiently predict on an entire dataloader, drawing one sample for each pass over the dataset:
```python
dataset_means, dataset_stds = model.trajectory_tracker.predict_uncertainty_on_dataloader(
    dataloader=dataloader, num_samples=30
)
```
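Conceptually, this kind of dataloader-level prediction amounts to: draw one weight sample, run one full pass over the data, repeat, then aggregate per-example mean and std across samples. The sketch below illustrates that scheme in plain numpy; `sample_weights` and `predict` are hypothetical placeholders, not afterglow functions.

```python
import numpy as np

def predict_uncertainty_on_loader(sample_weights, predict, batches, num_samples):
    """One posterior sample per full pass over the data, then aggregate (sketch)."""
    per_sample = []
    for _ in range(num_samples):
        w = sample_weights()                      # draw one weight sample
        preds = [predict(w, x) for x in batches]  # one pass over the dataset
        per_sample.append(np.concatenate(preds, axis=0))
    stacked = np.stack(per_sample)                # (num_samples, N, ...)
    return stacked.mean(axis=0), stacked.std(axis=0)
```

Drawing one sample per pass keeps memory flat: only the running stack of predictions, not one model copy per sample, is held at once.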
If you pass a dataloader to enable_swag, the SWAG batchnorm update step will be taken care of for you:
```python
from afterglow import enable_swag

enable_swag(
    model,
    start_iteration=100 * len(train_dataloader),
    update_period_in_iters=len(train_dataloader),
    max_cols=20,
    dataloader_for_batchnorm=train_dataloader,  # now we'll do the bn update when we sample
)
```
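The batchnorm update matters because after new weights are sampled, the batchnorm layers' stored running statistics no longer match the activations those weights produce, so they must be re-estimated from data. The numpy sketch below shows the kind of streaming mean/variance re-estimation involved; it is an illustration of the idea, not afterglow's implementation.

```python
import numpy as np

def refresh_bn_stats(batches):
    """Re-estimate per-feature mean and variance from a stream of batches,
    as a batchnorm running-stats refresh does after sampling new weights."""
    n, s, ss = 0, 0.0, 0.0
    for x in batches:           # x: (batch_size, num_features)
        n += x.shape[0]
        s = s + x.sum(axis=0)
        ss = ss + (x ** 2).sum(axis=0)
    mean = s / n
    var = ss / n - mean ** 2    # biased (population) variance
    return mean, var
```

Because the refresh streams over batches, its cost scales with how much data you feed it, which is why capping the number of datapoints (as shown below) speeds up sampling.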
You can speed up online inference by limiting the number of datapoints used to update the batchnorm statistics:
```python
from afterglow import enable_swag

enable_swag(
    model,
    start_iteration=100 * len(train_dataloader),
    update_period_in_iters=len(train_dataloader),
    max_cols=20,
    dataloader_for_batchnorm=train_dataloader,
    num_datapoints_for_bn_update=10,  # now we'll only use 10 examples for the bn update
)
```
See the documentation, and the example app in /example, for more!