bayesflow.trainers module#

class bayesflow.trainers.Trainer(amortizer, generative_model=None, configurator=None, checkpoint_path=None, max_to_keep=3, default_lr=0.0005, skip_checks=False, memory=False, **kwargs)[source]#

Bases: object

This class connects a generative model (or already simulated data from a model) with a configurator and a neural inference architecture for amortized inference (amortizer). A Trainer instance is responsible for optimizing the amortizer via various forms of simulation-based training.

At the very minimum, the trainer must be initialized with an amortizer instance, which is capable of processing the (configured) outputs of a generative model. A configurator will then process the outputs of the generative model and convert them into suitable inputs for the amortizer. Users can choose from a palette of default configurators or create their own configurators, essentially building a modularized pipeline GenerativeModel -> Configurator -> Amortizer. Most complex models will require custom configurators.
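To make the pipeline concrete, here is a minimal sketch of a custom configurator and the corresponding trainer wiring. The amortizer and model objects are assumed to exist already, and the output keys (parameters, summary_conditions) are only illustrative; the exact keys depend on the amortizer being used:

>>> import numpy as np
>>> def configurator(forward_dict):
...     # Illustrative mapping from simulator outputs to amortizer inputs
...     return {
...         "parameters": forward_dict["prior_draws"].astype(np.float32),
...         "summary_conditions": forward_dict["sim_data"].astype(np.float32),
...     }
>>> trainer = Trainer(amortizer=amortizer, generative_model=model, configurator=configurator)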

Notes

Currently, the trainer supports the following simulation-based training regimes, based on efficiency considerations:

  • Online training

    >>> trainer.train_online(epochs, iterations_per_epoch, batch_size, **kwargs)
    

This training regime is optimal for fast generative models which can efficiently simulate data on-the-fly. In order for this training regime to be efficient, on-the-fly batch simulations should not take longer than 2-3 seconds.

  • Experience replay training

    >>> trainer.train_experience_replay(epochs, iterations_per_epoch, batch_size, **kwargs)
    

This training regime is also good for fast generative models capable of efficiently simulating data on-the-fly. Compared to pure online training, this regime keeps an experience replay buffer from which simulations are randomly sampled, so the networks will likely see some simulations multiple times.

  • Round-based training

    >>> trainer.train_rounds(rounds, sim_per_round, epochs, batch_size, **kwargs)
    

    This training regime is optimal for slow, but still reasonably performant generative models. In order for this training regime to be efficient, on-the-fly batch simulations should not take longer than 2-3 minutes.

    Note

Overfitting presents a danger when using small numbers of simulated data sets, so it is recommended to use some amount of regularization for the neural amortizer(s).

  • Offline training

    >>> trainer.train_offline(simulations_dict, epochs, batch_size, **kwargs)
    

This training regime is optimal for very slow, external simulators, which take several minutes for a single simulation. It assumes that all training data has already been simulated and stored on disk.

    Warning

    Overfitting presents a danger when using a small simulated data set, so it is recommended to use some amount of regularization for the neural amortizer(s).

    Note

For extremely slow simulators (i.e., more than an hour for a single simulation), the BayesFlow framework might not be the ideal choice and should probably be considered in combination with a black-box surrogate optimization method, such as Bayesian optimization.

__init__(amortizer, generative_model=None, configurator=None, checkpoint_path=None, max_to_keep=3, default_lr=0.0005, skip_checks=False, memory=False, **kwargs)[source]#

Creates a trainer which will use a generative model (or data simulated from it) to optimize a neural architecture (amortizer) for amortized posterior inference, likelihood inference, or both.

Parameters:
amortizer : bayesflow.amortizers.Amortizer

The neural architecture to be optimized.

generative_model : bayesflow.forward_inference.GenerativeModel

A generative model returning a dictionary with randomly sampled parameters, data, and optional context.

configurator : callable or None, optional, default: None

A callable object transforming and combining the outputs of the generative model into inputs for a BayesFlow amortizer.

checkpoint_path : string or None, optional, default: None

Optional file path for storing the trained amortizer, loss history, and optional memory.

max_to_keep : int, optional, default: 3

Number of checkpoints and loss history snapshots to keep.

default_lr : float, optional, default: 0.0005

The default learning rate to use for default optimizers.

skip_checks : bool, optional, default: False

If True, do not perform consistency checks, i.e., do not run the simulator and pass its outputs through the networks.

memory : bool or bayesflow.SimulationMemory, optional, default: False

If True, stores a pre-defined amount of simulations for later use (validation, etc.). If a SimulationMemory instance is provided, stores a reference to that instance. Otherwise, the corresponding attribute will be set to None.
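As a brief illustration, the following sketch creates a trainer with checkpointing and simulation memory enabled; the checkpoint path is hypothetical, and amortizer and model are assumed to be an existing amortizer and generative model:

>>> trainer = Trainer(
...     amortizer=amortizer,                          # pre-built neural amortizer
...     generative_model=model,                       # pre-built generative model
...     checkpoint_path="checkpoints/my_amortizer",   # hypothetical path; enables saving and resuming
...     max_to_keep=3,
...     memory=True,                                  # keep simulations for the diagnostic methods below
... )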

diagnose_latent2d(inputs=None, **kwargs)[source]#

Performs visual pre-inference diagnostics of the latent space on either provided validation data (new simulations) or the internal simulation memory. If inputs is not None, diagnostics will be performed on the inputs, regardless of whether the simulation memory of the trainer is empty or not. If inputs is None, the trainer will try to access its simulation memory or raise a ConfigurationError.

Parameters:
inputs : None, list, or dict, optional, default: None

The optional inputs to use

Returns:
fig : plt.Figure

The figure object which can be readily saved to disk using fig.savefig().

Other Parameters:
conf_args

optional keyword arguments passed to the configurator

net_args

optional keyword arguments passed to the amortizer

plot_args

optional keyword arguments passed to plot_latent_space_2d
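A minimal sketch, assuming the generative model is callable with a number of simulations and the figure is written to disk:

>>> new_sims = generative_model(200)   # 200 fresh simulations for diagnostics
>>> fig = trainer.diagnose_latent2d(inputs=new_sims)
>>> fig.savefig("latent_diagnostics.png", dpi=300)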

diagnose_sbc_histograms(inputs=None, n_samples=None, **kwargs)[source]#

Performs visual pre-inference diagnostics via simulation-based calibration (SBC) on either provided validation data (new simulations) or the internal simulation memory. If inputs is not None, diagnostics will be performed on the inputs, regardless of whether the simulation memory of the trainer is empty or not. If inputs is None, the trainer will try to access its simulation memory or raise a ConfigurationError.

Parameters:
inputs : None, list, or dict, optional, default: None

The optional inputs to use

n_samples : int or None, optional, default: None

The number of posterior samples to draw for each simulated data set. If None, the number will be heuristically determined so that n_sim / n_draws is approximately equal to 20.

Returns:
fig : plt.Figure

The figure object which can be readily saved to disk using fig.savefig().

Other Parameters:
conf_args

optional keyword arguments passed to the configurator

net_args

optional keyword arguments passed to the amortizer

plot_args

optional keyword arguments passed to plot_sbc()
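A minimal sketch, reusing the fresh simulations from the latent-space example above:

>>> fig = trainer.diagnose_sbc_histograms(inputs=new_sims, n_samples=100)
>>> fig.savefig("sbc_histograms.png", dpi=300)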

load_pretrained_network()[source]#

Attempts to load a pre-trained network if checkpoint path is provided and a checkpoint manager exists.

train_online(epochs, iterations_per_epoch, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, early_stopping=False, use_autograph=True, validation_sims=None, **kwargs)[source]#

Trains an amortizer via online learning. Additional keyword arguments are passed to the generative model, configurator, and amortizer.

Parameters:
epochs : int

Number of epochs (and number of times a checkpoint is stored)

iterations_per_epoch : int

Number of batch simulations to perform per epoch

batch_size : int

Number of simulations to perform at each backprop step

save_checkpoint : bool, default: True

A flag to decide whether to save checkpoints after each epoch, if a checkpoint_path was provided during initialization; otherwise ignored.

optimizer : tf.keras.optimizers.Optimizer or None

Optimizer for the neural network. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override default learning rate and schedule settings.

reuse_optimizer : bool, optional, default: False

A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.

early_stopping : bool, optional, default: False

Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided.

use_autograph : bool, optional, default: True

Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.

validation_sims : dict, int, or None, optional, default: None

Simulations used as a "validation set". If dict, will assume it's the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it's the number of sims to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.

Returns:
losses : dict or pandas.DataFrame

A dictionary or a data frame storing the losses across epochs and iterations

Other Parameters:
model_args

optional kwargs passed to the generative model

val_model_args

optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.

conf_args

optional kwargs passed to the configurator before each backprop (update) step.

val_conf_args

optional kwargs passed to the configurator when configuring the validation data.

net_args

optional kwargs passed to the amortizer

early_stopping_args

optional kwargs passed to the EarlyStopper
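A hedged sketch of a typical call, assuming a trainer with a generative model; the custom optimizer is optional and shown only to illustrate overriding the default schedule:

>>> import tensorflow as tf
>>> optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)  # optional; None uses the default cosine-decay schedule
>>> losses = trainer.train_online(
...     epochs=20,
...     iterations_per_epoch=500,
...     batch_size=64,
...     optimizer=optimizer,
...     validation_sims=300,   # generate 300 validation simulations before training starts
...     early_stopping=True,
... )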

train_offline(simulations_dict, epochs, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, early_stopping=False, validation_sims=None, use_autograph=True, **kwargs)[source]#

Trains an amortizer via offline learning. Assumes that parameters, data, and optional context have already been simulated (i.e., forward inference has been performed).

Parameters:
simulations_dict : dict

A dictionary containing the simulated data and context. If using the default keys, the method expects at least the mandatory keys sim_data and prior_draws to be present.

epochs : int

Number of epochs (and number of times a checkpoint is stored)

batch_size : int

Number of simulations to perform at each backpropagation step

save_checkpoint : bool, default: True

Determines whether to save checkpoints after each epoch, if a checkpoint_path was provided during initialization; otherwise ignored.

optimizer : tf.keras.optimizers.Optimizer or None

Optimizer for the neural network. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override default learning rate and schedule settings.

reuse_optimizer : bool, optional, default: False

A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.

early_stopping : bool, optional, default: False

Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided.

use_autograph : bool, optional, default: True

Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.

validation_sims : dict, int, or None, optional, default: None

Simulations used as a "validation set". If dict, will assume it's the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it's the number of sims to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.

Returns:
losses : dict or pandas.DataFrame

A dictionary or a data frame storing the losses across epochs and iterations

Other Parameters:
val_model_args

optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.

conf_args

optional kwargs passed to the configurator before each backprop (update) step.

val_conf_args

optional kwargs passed to the configurator when configuring the validation data.

net_args

optional kwargs passed to the amortizer

early_stopping_args

optional kwargs passed to the EarlyStopper
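An illustrative sketch of the expected data format and call; the array shapes are made up (5000 data sets, 4 parameters, 100 observations with 2 dimensions each):

>>> import numpy as np
>>> rng = np.random.default_rng(seed=42)
>>> simulations_dict = {
...     "prior_draws": rng.normal(size=(5000, 4)).astype(np.float32),    # mandatory key
...     "sim_data": rng.normal(size=(5000, 100, 2)).astype(np.float32),  # mandatory key
... }
>>> losses = trainer.train_offline(simulations_dict, epochs=50, batch_size=64)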

train_from_presimulation(presimulation_path, optimizer, save_checkpoint=True, max_epochs=None, reuse_optimizer=False, custom_loader=None, early_stopping=False, validation_sims=None, use_autograph=True, **kwargs)[source]#

Trains an amortizer via a modified form of offline training.

Like regular offline training, it assumes that parameters, data and optional context have already been simulated (i.e., forward inference has been performed).

Also like regular offline training, it is faster than online training in scenarios where simulations are slow. Unlike regular offline training, it uses each batch from the presimulated dataset only once during training, unless a higher maximal number of epochs is specified, in which case the presimulated data are reused in a cyclic manner to achieve the desired number of epochs. A larger presimulated dataset is therefore required than for offline training, and the speed gained by loading simulations instead of generating them on the fly comes at a cost: a large presimulated dataset takes up a large amount of hard drive space.

Parameters:
presimulation_path : str

File path to the folder containing the files from the precomputed simulation. Ideally generated using a GenerativeModel's presimulate_and_save method, otherwise it must match the structure produced by that method. Each file contains the data for one epoch (i.e., a number of batches) and must be compatible with the custom_loader provided. The custom_loader must read each file into a collection (either a dictionary or a list) of simulation_dict objects. This is easily achieved with the pickle library: if the files were generated from collections of simulation_dict objects using pickle.dump, the _default_loader (default for custom_loader) will load them using pickle.load. Training parameters like the number of iterations and batch size are inferred from the files during training.

optimizer : tf.keras.optimizers.Optimizer

Optimizer for the neural network training. Since for this training it is impossible to guess the number of iterations beforehand, an optimizer must be provided.

save_checkpoint : bool, optional, default: True

Determines whether to save checkpoints after each epoch, if a checkpoint_path was provided during initialization; otherwise ignored.

max_epochs : int or None, optional, default: None

An optional parameter to limit or extend the number of epochs. If the number of epochs is larger than the number of files in the dataset, presimulations will be reused.

reuse_optimizer : bool, optional, default: False

A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.

custom_loader : callable, optional, default: self._default_loader

Must take a string file_path as input and output a collection (dictionary or list) of simulation_dict objects. A simulation_dict has the keys prior_non_batchable_context, prior_batchable_context, prior_draws, sim_non_batchable_context, sim_batchable_context, and sim_data. Here, prior_draws and sim_data must have actual data as values; the rest are optional.

early_stopping : bool, optional, default: False

Whether to use early stopping during training. Could speed up training.

validation_sims : dict, int, or None, optional, default: None

Simulations used as a validation set. If dict, will assume it's the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it's the number of sims to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.

use_autograph : bool, optional, default: True

Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.

Returns:
losses : dict or pandas.DataFrame

A dictionary or a data frame storing the losses across epochs and iterations

Other Parameters:
conf_args

optional keyword arguments passed to the configurator

net_args

optional keyword arguments passed to the amortizer
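A hedged sketch of a call with a pickle-based loader, assuming the epoch files were written with pickle.dump into a hypothetical presimulations/ folder:

>>> import pickle
>>> import tensorflow as tf
>>> def pickle_loader(file_path):
...     # Reads one epoch file into a collection of simulation_dict objects
...     with open(file_path, "rb") as f:
...         return pickle.load(f)
>>> losses = trainer.train_from_presimulation(
...     presimulation_path="presimulations/",      # hypothetical folder created by presimulate_and_save
...     optimizer=tf.keras.optimizers.Adam(5e-4),  # an optimizer must be supplied for this method
...     custom_loader=pickle_loader,
...     max_epochs=30,
... )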

train_experience_replay(epochs, iterations_per_epoch, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, buffer_capacity=1000, early_stopping=False, use_autograph=True, validation_sims=None, **kwargs)[source]#

Trains the network(s) via experience replay using a memory replay buffer, as utilized in reinforcement learning. Additional keyword arguments are passed to the generative model, configurator, and amortizer. See the parameters below for the full signature.

Parameters:
epochs : int

Number of epochs (and number of times a checkpoint is stored)

iterations_per_epoch : int

Number of batch simulations to perform per epoch

batch_size : int

Number of simulations to perform at each backpropagation step.

save_checkpoint : bool, optional, default: True

A flag to decide whether to save checkpoints after each epoch, if a checkpoint_path was provided during initialization; otherwise ignored.

optimizer : tf.keras.optimizers.Optimizer or None

Optimizer for the neural network. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override default learning rate and schedule settings.

reuse_optimizer : bool, optional, default: False

A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.

buffer_capacity : int, optional, default: 1000

Max number of batches to store in the buffer. For instance, if batch_size=32 and buffer_capacity=1000, then the buffer will hold a maximum of 32 * 1000 = 32000 simulations. Be careful with memory! Note that this argument will be ignored if the buffer has previously been initialized.

early_stopping : bool, optional, default: False

Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided.

use_autograph : bool, optional, default: True

Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.

validation_sims : dict, int, or None, optional, default: None

Simulations used as a "validation set". If dict, will assume it's the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it's the number of sims to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.

Returns:
losses : dict or pandas.DataFrame

A dictionary or a data frame storing the losses across epochs and iterations.

Other Parameters:
model_args

optional kwargs passed to the generative model

val_model_args

optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.

conf_args

optional kwargs passed to the configurator before each backprop (update) step.

val_conf_args

optional kwargs passed to the configurator when configuring the validation data.

net_args

optional kwargs passed to the amortizer

early_stopping_args

optional kwargs passed to the EarlyStopper
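A minimal sketch of a call with a reduced buffer capacity (values are illustrative):

>>> losses = trainer.train_experience_replay(
...     epochs=10,
...     iterations_per_epoch=1000,
...     batch_size=32,
...     buffer_capacity=500,   # buffer holds at most 500 * 32 = 16000 simulations
... )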

train_rounds(rounds, sim_per_round, epochs, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, early_stopping=False, use_autograph=True, validation_sims=None, **kwargs)[source]#

Trains an amortizer via round-based learning. In each round, sim_per_round data sets are simulated from the generative model and added to the data sets simulated in previous rounds. Then, the networks are trained for epochs epochs on the cumulative set of data sets.

Note

Training time will increase from round to round, since the number of simulations increases correspondingly. The final round will then train the networks on rounds * sim_per_round data sets, so make sure this number does not eat up all available memory.

Parameters:
rounds : int

Number of rounds to perform (outer loop)

sim_per_round : int

Number of simulations per round

epochs : int

Number of epochs (and number of times a checkpoint is stored, inner loop) within a round.

batch_size : int

Number of simulations to use at each backpropagation step

save_checkpoint : bool, optional, default: True

A flag to decide whether to save checkpoints after each epoch, if a checkpoint_path was provided during initialization; otherwise ignored.

optimizer : tf.keras.optimizers.Optimizer or None

Optimizer for the neural network training. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override default learning rate and schedule settings.

reuse_optimizer : bool, optional, default: False

A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.

early_stopping : bool, optional, default: False

Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided. Will be performed within rounds, not between rounds!

use_autograph : bool, optional, default: True

Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.

validation_sims : dict, int, or None, optional, default: None

Simulations used as a "validation set". If dict, will assume it's the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it's the number of sims to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.

Returns:
losses : dict or pandas.DataFrame

A dictionary or a data frame storing the losses across epochs and iterations

Other Parameters:
model_args

optional kwargs passed to the generative model

val_model_args

optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.

conf_args

optional kwargs passed to the configurator before each backprop (update) step.

val_conf_args

optional kwargs passed to the configurator when configuring the validation data.

net_args

optional kwargs passed to the amortizer

early_stopping_args

optional kwargs passed to the EarlyStopper
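A minimal sketch with illustrative values; note the memory consideration from the note above:

>>> # 10 rounds x 2000 simulations: the final round trains on 20000 data sets
>>> losses = trainer.train_rounds(rounds=10, sim_per_round=2000, epochs=25, batch_size=64)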

mmd_hypothesis_test(observed_data, reference_data=None, num_reference_simulations=1000, num_null_samples=100, bootstrap=False)[source]#

Performs a sampling-based hypothesis test for detecting Out-Of-Simulation data (model misspecification).

Parameters:
observed_data : np.ndarray

Observed data, shape (num_observed, …)

reference_data : np.ndarray

Reference data representing samples from the well-specified model, shape (num_reference, …)

num_reference_simulations : int, default: 1000

Number of reference simulations (M) simulated from the trainer's generative model if no reference_data are provided.

num_null_samples : int, default: 100

Number of draws from the MMD sampling distribution under the null hypothesis "the trainer's generative model is well-specified"

bootstrap : bool, default: False

If True, the reference data (see above) are bootstrapped for each sample from the MMD sampling distribution. If False, a new data set is simulated for computing each draw from the MMD sampling distribution.

Returns:
mmd_null_samples : np.ndarray

Samples from the H0 sampling distribution ("well-specified model")

mmd_observed : float

Summary MMD estimate for the observed data sets
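A hedged sketch of how the returned quantities could be turned into a p-value; observed_data is assumed to be an existing np.ndarray of observed data sets:

>>> import numpy as np
>>> mmd_null, mmd_observed = trainer.mmd_hypothesis_test(observed_data, num_null_samples=200)
>>> p_value = np.mean(mmd_null >= mmd_observed)   # fraction of null draws at least as large as the observed MMD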