bayesflow.trainers module#
- class bayesflow.trainers.Trainer(amortizer, generative_model=None, configurator=None, checkpoint_path=None, max_to_keep=3, default_lr=0.0005, skip_checks=False, memory=False, **kwargs)[source]#
Bases:
object
This class connects a generative model (or already simulated data from a model) with a configurator and a neural inference architecture for amortized inference (amortizer). A Trainer instance is responsible for optimizing the amortizer via various forms of simulation-based training.
At the very minimum, the trainer must be initialized with an amortizer instance, which is capable of processing the (configured) outputs of a generative model. A configurator will then process the outputs of the generative model and convert them into suitable inputs for the amortizer. Users can choose from a palette of default configurators or create their own configurators, essentially building a modularized pipeline GenerativeModel -> Configurator -> Amortizer. Most complex models will require custom configurators.
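As a rough, hedged sketch of this pipeline (assuming BayesFlow 1.x class names such as InvertibleNetwork and AmortizedPosterior, a toy simulator, and the default amortizer input keys; all of these are illustrative rather than prescribed by the Trainer itself):
>>> import numpy as np
>>> from bayesflow.networks import InvertibleNetwork
>>> from bayesflow.amortizers import AmortizedPosterior
>>> from bayesflow.trainers import Trainer
>>> def toy_model(batch_size):
...     # Hypothetical generative model: returns sampled parameters and corresponding data
...     theta = np.random.normal(size=(batch_size, 2))
...     x = theta + np.random.normal(size=(batch_size, 2))
...     return {"prior_draws": theta, "sim_data": x}
>>> def configurator(forward_dict):
...     # Converts raw simulator outputs into the amortizer's input dictionary
...     return {
...         "parameters": forward_dict["prior_draws"].astype(np.float32),
...         "direct_conditions": forward_dict["sim_data"].astype(np.float32),
...     }
>>> amortizer = AmortizedPosterior(InvertibleNetwork(num_params=2))
>>> trainer = Trainer(amortizer=amortizer, generative_model=toy_model, configurator=configurator)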
Notes
Currently, the trainer supports the following simulation-based training regimes, based on efficiency considerations:
Online training
>>> trainer.train_online(epochs, iterations_per_epoch, batch_size, **kwargs)
This training regime is optimal for fast generative models which can efficiently simulate data on the fly. For this training regime to be efficient, on-the-fly batch simulations should not take longer than 2-3 seconds.
Experience replay training
>>> trainer.train_experience_replay(epochs, iterations_per_epoch, batch_size, **kwargs)
This training regime is also well suited for fast generative models capable of efficiently simulating data on the fly. Compared to pure online training, this regime keeps an experience replay buffer from which simulations are randomly sampled, so the networks will likely see some simulations multiple times.
Round-based training
>>> trainer.train_rounds(rounds, sim_per_round, epochs, batch_size, **kwargs)
This training regime is optimal for slow, but still reasonably performant generative models. In order for this training regime to be efficient, on-the-fly batch simulations should not take longer than 2-3 minutes.
Note
Overfitting presents a danger when using a small number of simulated data sets, so it is recommended to use some amount of regularization for the neural amortizer(s).
Offline training
>>> trainer.train_offline(simulations_dict, epochs, batch_size, **kwargs)
This training regime is optimal for very slow, external simulators, which take several minutes for a single simulation. It assumes that all training data have already been simulated and stored on disk.
Warning
Overfitting presents a danger when using a small simulated data set, so it is recommended to use some amount of regularization for the neural amortizer(s).
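For illustration, a minimal offline data set using the default keys might look as follows (the array shapes below are hypothetical):
>>> import numpy as np
>>> simulations_dict = {
...     "prior_draws": np.random.normal(size=(5000, 2)),    # sampled parameters
...     "sim_data": np.random.normal(size=(5000, 100, 2)),  # corresponding simulated data
... }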
Note
For extremely slow simulators (i.e., taking more than an hour for a single simulation), the BayesFlow framework might not be the ideal choice and should probably be combined with a black-box surrogate optimization method, such as Bayesian optimization.
- __init__(amortizer, generative_model=None, configurator=None, checkpoint_path=None, max_to_keep=3, default_lr=0.0005, skip_checks=False, memory=False, **kwargs)[source]#
Creates a trainer which will use a generative model (or data simulated from it) to optimize a neural architecture (amortizer) for amortized posterior inference, likelihood inference, or both.
- Parameters:
- amortizer : bayesflow.amortizers.Amortizer
The neural architecture to be optimized.
- generative_model : bayesflow.forward_inference.GenerativeModel
A generative model returning a dictionary with randomly sampled parameters, data, and optional context.
- configurator : callable or None, optional, default: None
A callable object transforming and combining the outputs of the generative model into inputs for a BayesFlow amortizer.
- checkpoint_path : string or None, optional, default: None
Optional file path for storing the trained amortizer, loss history and optional memory.
- max_to_keep : int, optional, default: 3
Number of checkpoints and loss history snapshots to keep.
- default_lr : float, optional, default: 0.0005
The default learning rate to use for default optimizers.
- skip_checks : bool, optional, default: False
If True, do not perform consistency checks (i.e., test simulator runs and passes through the networks).
- memory : bool or bayesflow.SimulationMemory, optional, default: False
If True, store a pre-defined amount of simulations for later use (validation, etc.). If a SimulationMemory instance is provided, the trainer stores a reference to that instance. Otherwise, the corresponding attribute will be set to None.
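For example, checkpointing and simulation memory can be enabled at construction time (reusing the objects from the sketch above; the checkpoint path is a placeholder):
>>> trainer = Trainer(amortizer=amortizer, generative_model=toy_model,
...                   checkpoint_path="checkpoints/my_model", max_to_keep=5, memory=True)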
- diagnose_latent2d(inputs=None, **kwargs)[source]#
Performs visual pre-inference diagnostics of the latent space on either provided validation data (new simulations) or the internal simulation memory. If inputs is not None, diagnostics will be performed on the inputs, regardless of whether the simulation memory of the trainer is empty or not. If inputs is None, the trainer will try to access its memory or raise a ConfigurationError.
- Parameters:
- inputs : None, list, or dict, optional, default: None
The optional inputs to use.
- Returns:
- fig : plt.Figure
The figure object which can be readily saved to disk using fig.savefig().
- Other Parameters:
- conf_args
optional keyword arguments passed to the configurator
- net_args
optional keyword arguments passed to the amortizer
- plot_args
optional keyword arguments passed to plot_latent_space_2d
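A short usage sketch (relying on the trainer's simulation memory being non-empty):
>>> fig = trainer.diagnose_latent2d()
>>> fig.savefig("latent_diagnostics.png")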
- diagnose_sbc_histograms(inputs=None, n_samples=None, **kwargs)[source]#
Performs visual pre-inference diagnostics via simulation-based calibration (SBC) on either provided validation data (new simulations) or the internal simulation memory. If inputs is not None, diagnostics will be performed on the inputs, regardless of whether the simulation memory of the trainer is empty or not. If inputs is None, the trainer will try to access its memory or raise a ConfigurationError.
- Parameters:
- inputs : None, list, or dict, optional, default: None
The optional inputs to use.
- n_samples : int or None, optional, default: None
The number of posterior samples to draw for each simulated data set. If None, the number will be heuristically determined so that n_sim / n_draws is approximately equal to 20.
- Returns:
- fig : plt.Figure
The figure object which can be readily saved to disk using fig.savefig().
- Other Parameters:
- conf_args
optional keyword arguments passed to the configurator
- net_args
optional keyword arguments passed to the amortizer
- plot_args
optional keyword arguments passed to plot_sbc()
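A short usage sketch (the number of posterior draws is illustrative):
>>> fig = trainer.diagnose_sbc_histograms(n_samples=100)
>>> fig.savefig("sbc_histograms.png")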
- load_pretrained_network()[source]#
Attempts to load a pre-trained network if checkpoint path is provided and a checkpoint manager exists.
- train_online(epochs, iterations_per_epoch, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, early_stopping=False, use_autograph=True, validation_sims=None, **kwargs)[source]#
Trains an amortizer via online learning. Additional keyword arguments are passed to the generative model, configurator, and amortizer (see the usage sketch below).
- Parameters:
- epochs : int
Number of epochs (and number of times a checkpoint is stored)
- iterations_per_epoch : int
Number of batch simulations to perform per epoch
- batch_size : int
Number of simulations to perform at each backprop step
- save_checkpoint : bool, optional, default: True
A flag to decide whether to save checkpoints after each epoch. Only used if a checkpoint_path was provided during initialization; otherwise ignored.
- optimizer : tf.keras.optimizers.Optimizer or None
Optimizer for the neural network. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override the default learning rate and schedule settings.
- reuse_optimizer : bool, optional, default: False
A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.
- early_stopping : bool, optional, default: False
Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided.
- use_autograph : bool, optional, default: True
Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.
- validation_sims : dict, int, or None, optional, default: None
Simulations used as a “validation set”. If dict, will assume it’s the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it’s the number of simulations to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.
- Returns:
- losses : dict or pandas.DataFrame
A dictionary or a data frame storing the losses across epochs and iterations
- Other Parameters:
- model_args
optional kwargs passed to the generative model
- val_model_args
optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.
- conf_args
optional kwargs passed to the configurator before each backprop (update) step.
- val_conf_args
optional kwargs passed to the configurator when configuring the validation data.
- net_args
optional kwargs passed to the amortizer
- early_stopping_args
optional kwargs passed to the EarlyStopper
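A hedged usage sketch for online training with an integer validation set and early stopping (all settings are illustrative, not recommendations):
>>> losses = trainer.train_online(epochs=10, iterations_per_epoch=500, batch_size=32,
...                               validation_sims=200, early_stopping=True)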
- train_offline(simulations_dict, epochs, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, early_stopping=False, validation_sims=None, use_autograph=True, **kwargs)[source]#
Trains an amortizer via offline learning. Assumes that parameters, data, and optional context have already been simulated (i.e., forward inference has been performed). A usage sketch is given below.
- Parameters:
- simulations_dict : dict
A dictionary containing the simulated data / context. If using the default keys, the method expects at least the mandatory keys sim_data and prior_draws to be present.
- epochs : int
Number of epochs (and number of times a checkpoint is stored)
- batch_size : int
Number of simulations to perform at each backpropagation step
- save_checkpoint : bool, optional, default: True
Determines whether to save checkpoints after each epoch. Only used if a checkpoint_path was provided during initialization; otherwise ignored.
- optimizer : tf.keras.optimizers.Optimizer or None
Optimizer for the neural network. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override the default learning rate and schedule settings.
- reuse_optimizer : bool, optional, default: False
A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.
- early_stopping : bool, optional, default: False
Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided.
- use_autograph : bool, optional, default: True
Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.
- validation_sims : dict, int, or None, optional, default: None
Simulations used as a “validation set”. If dict, will assume it’s the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it’s the number of simulations to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.
- Returns:
- losses : dict or pandas.DataFrame
A dictionary or a data frame storing the losses across epochs and iterations
- Other Parameters:
- val_model_args
optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.
- conf_args
optional kwargs passed to the configurator before each backprop (update) step.
- val_conf_args
optional kwargs passed to the configurator when configuring the validation data.
- net_args
optional kwargs passed to the amortizer
- early_stopping_args
optional kwargs passed to the EarlyStopper
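A hedged usage sketch, reusing the simulations_dict format shown in the Notes above (all settings are illustrative):
>>> losses = trainer.train_offline(simulations_dict, epochs=50, batch_size=64,
...                                validation_sims=200, early_stopping=True)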
- train_from_presimulation(presimulation_path, optimizer, save_checkpoint=True, max_epochs=None, reuse_optimizer=False, custom_loader=None, early_stopping=False, validation_sims=None, use_autograph=True, **kwargs)[source]#
Trains an amortizer via a modified form of offline training.
Like regular offline training, it assumes that parameters, data and optional context have already been simulated (i.e., forward inference has been performed).
Also like regular offline training, it is faster than online training in scenarios where simulations are slow. Unlike regular offline training, it uses each batch from the presimulated dataset only once during training, unless a higher maximal number of epochs is specified; in that case, the presimulated data are reused in a cyclic manner to achieve the desired number of epochs. A larger presimulated dataset is therefore required than for offline training, and the increase in speed gained by loading simulations instead of generating them on the fly comes at a cost: a large presimulated dataset takes up a large amount of hard drive space.
- Parameters:
- presimulation_path : str
File path to the folder containing the files from the precomputed simulation. Ideally generated using a GenerativeModel’s presimulate_and_save method, otherwise must match the structure produced by that method. Each file contains the data for one epoch (i.e., a number of batches), and must be compatible with the custom_loader provided. The custom_loader must read each file into a collection (either a dictionary or a list) of simulation_dict objects. This is easily achieved with the pickle library: if the files were generated from collections of simulation_dict objects using pickle.dump, the _default_loader (the default for custom_loader) will load them using pickle.load. Training parameters like the number of iterations and batch size are inferred from the files during training.
- optimizer : tf.keras.optimizers.Optimizer
Optimizer for the neural network training. Since the number of iterations cannot be guessed beforehand for this training regime, an optimizer must be provided.
- save_checkpoint : bool, optional, default: True
Determines whether to save checkpoints after each epoch. Only used if a checkpoint_path was provided during initialization; otherwise ignored.
- max_epochs : int or None, optional, default: None
An optional parameter to limit or extend the number of epochs. If the number of epochs is larger than the number of files in the dataset, presimulations will be reused.
- reuse_optimizer : bool, optional, default: False
A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.
- custom_loader : callable, optional, default: self._default_loader
Must take a string file_path as an input and output a collection (dictionary or list) of simulation_dict objects. A simulation_dict has the keys prior_non_batchable_context, prior_batchable_context, prior_draws, sim_non_batchable_context, sim_batchable_context, and sim_data. Here, prior_draws and sim_data must have actual data as values, the rest are optional.
- early_stopping : bool, optional, default: False
Whether to use early stopping during training. Could speed up training.
- validation_sims : dict, int, or None, optional, default: None
Simulations used as a validation set. If dict, will assume it’s the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it’s the number of simulations to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.
- use_autograph : bool, optional, default: True
Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.
- Returns:
- losses : dict or pandas.DataFrame
A dictionary or a data frame storing the losses across epochs and iterations
- Other Parameters:
- conf_args
optional keyword arguments passed to the configurator
- net_args
optional keyword arguments passed to the amortizer
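A hedged sketch of a pickle-based custom loader, mirroring the default loader described above (the folder path is a placeholder):
>>> import pickle
>>> import tensorflow as tf
>>> def my_loader(file_path):
...     # Each presimulation file is assumed to hold a collection of simulation_dict
...     # objects written with pickle.dump (e.g., by presimulate_and_save)
...     with open(file_path, "rb") as f:
...         return pickle.load(f)
>>> optimizer = tf.keras.optimizers.Adam(5e-4)
>>> losses = trainer.train_from_presimulation("presimulations/", optimizer=optimizer,
...                                           custom_loader=my_loader)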
- train_experience_replay(epochs, iterations_per_epoch, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, buffer_capacity=1000, early_stopping=False, use_autograph=True, validation_sims=None, **kwargs)[source]#
Trains the network(s) via experience replay using a memory replay buffer, as utilized in reinforcement learning. Additional keyword arguments are passed to the generative model, configurator, and amortizer (see the parameter list and the usage sketch below).
- Parameters:
- epochs : int
Number of epochs (and number of times a checkpoint is stored)
- iterations_per_epoch : int
Number of batch simulations to perform per epoch
- batch_size : int
Number of simulations to perform at each backpropagation step.
- save_checkpoint : bool, optional, default: True
A flag to decide whether to save checkpoints after each epoch. Only used if a checkpoint_path was provided during initialization; otherwise ignored.
- optimizer : tf.keras.optimizers.Optimizer or None
Optimizer for the neural network. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override the default learning rate and schedule settings.
- reuse_optimizer : bool, optional, default: False
A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.
- buffer_capacity : int, optional, default: 1000
Max number of batches to store in the buffer. For instance, if batch_size=32 and buffer_capacity=1000, then the buffer will hold a maximum of 32 * 1000 = 32000 simulations. Be careful with memory! Important: this argument will be ignored if the buffer has previously been initialized!
- early_stopping : bool, optional, default: False
Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided.
- use_autograph : bool, optional, default: True
Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.
- validation_sims : dict, int, or None, optional, default: None
Simulations used as a “validation set”. If dict, will assume it’s the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it’s the number of simulations to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.
- Returns:
- losses : dict or pandas.DataFrame
A dictionary or a data frame storing the losses across epochs and iterations.
- Other Parameters:
- model_args
optional kwargs passed to the generative model
- val_model_args
optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.
- conf_args
optional kwargs passed to the configurator before each backprop (update) step.
- val_conf_args
optional kwargs passed to the configurator when configuring the validation data.
- net_args
optional kwargs passed to the amortizer
- early_stopping_args
optional kwargs passed to the EarlyStopper
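A hedged usage sketch for experience replay training (all settings are illustrative):
>>> losses = trainer.train_experience_replay(epochs=10, iterations_per_epoch=500,
...                                          batch_size=32, buffer_capacity=500)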
- train_rounds(rounds, sim_per_round, epochs, batch_size, save_checkpoint=True, optimizer=None, reuse_optimizer=False, early_stopping=False, use_autograph=True, validation_sims=None, **kwargs)[source]#
Trains an amortizer via round-based learning. In each round, sim_per_round data sets are simulated from the generative model and added to the data sets simulated in previous rounds. Then, the networks are trained for epochs on the augmented set of data sets.
Note
Training time will increase from round to round, since the number of simulations increases correspondingly. The final round will train the networks on rounds * sim_per_round data sets, so make sure this number does not eat up all available memory.
- Parameters:
- rounds : int
Number of rounds to perform (outer loop)
- sim_per_round : int
Number of simulations per round
- epochs : int
Number of epochs (and number of times a checkpoint is stored, inner loop) within a round.
- batch_size : int
Number of simulations to use at each backpropagation step
- save_checkpoint : bool, optional, default: True
A flag to decide whether to save checkpoints after each epoch. Only used if a checkpoint_path was provided during initialization; otherwise ignored.
- optimizer : tf.keras.optimizers.Optimizer or None
Optimizer for the neural network training. None will result in tf.keras.optimizers.Adam using a learning rate of 5e-4 and a cosine decay from 5e-4 to 0. A custom optimizer will override the default learning rate and schedule settings.
- reuse_optimizer : bool, optional, default: False
A flag indicating whether the optimizer instance should be treated as persistent or not. If False, the optimizer and its states are not stored after training has finished. Otherwise, the optimizer will be stored as self.optimizer and re-used in further training runs.
- early_stopping : bool, optional, default: False
Whether to use early stopping during training. Could speed up training. Only works if validation_sims is not None, i.e., validation data has been provided. Will be performed within rounds, not between rounds!
- use_autograph : bool, optional, default: True
Whether to use autograph for the backprop step. Could lead to enormous speed-ups but could also be harder to debug.
- validation_sims : dict, int, or None, optional, default: None
Simulations used as a “validation set”. If dict, will assume it’s the output of a generative model and try amortizer.compute_loss(configurator(validation_sims)) after each epoch. If int, will assume it’s the number of simulations to generate from the generative model before starting training. Only considered if a generative model has been provided during initialization. If None (default), no validation set will be used.
- Returns:
- losses : dict or pandas.DataFrame
A dictionary or a data frame storing the losses across epochs and iterations
- Other Parameters:
- model_args
optional kwargs passed to the generative model
- val_model_args
optional kwargs passed to the generative model for generating validation data. Only useful if type(validation_sims) is int.
- conf_args
optional kwargs passed to the configurator before each backprop (update) step.
- val_conf_args
optional kwargs passed to the configurator when configuring the validation data.
- net_args
optional kwargs passed to the amortizer
- early_stopping_args
optional kwargs passed to the EarlyStopper
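A hedged usage sketch for round-based training (all settings are illustrative; remember that the final round trains on rounds * sim_per_round data sets):
>>> losses = trainer.train_rounds(rounds=4, sim_per_round=5000, epochs=20, batch_size=32)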
- mmd_hypothesis_test(observed_data, reference_data=None, num_reference_simulations=1000, num_null_samples=100, bootstrap=False)[source]#
Performs a sampling-based hypothesis test for detecting Out-Of-Simulation (model misspecification).
- Parameters:
- observed_data : np.ndarray
Observed data, shape (num_observed, …)
- reference_data : np.ndarray or None, optional, default: None
Reference data representing samples from the well-specified model, shape (num_reference, …)
- num_reference_simulations : int, default: 1000
Number of reference simulations (M) simulated from the trainer’s generative model if no reference_data are provided.
- num_null_samples : int, default: 100
Number of draws from the MMD sampling distribution under the null hypothesis “the trainer’s generative model is well-specified”
- bootstrap : bool, default: False
If True, the reference data (see above) are bootstrapped for each sample from the MMD sampling distribution. If False, a new data set is simulated for computing each draw from the MMD sampling distribution.
- Returns:
- mmd_null_samples : np.ndarray
Samples from the H0 sampling distribution (“well-specified model”)
- mmd_observed : float
Summary MMD estimate for the observed data sets
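A hedged usage sketch (observed_data is a user-provided array; treating the two return values as a tuple and the p-value computation below are assumptions, not part of the documented API):
>>> mmd_null_samples, mmd_observed = trainer.mmd_hypothesis_test(observed_data, num_null_samples=500)
>>> p_value = (mmd_null_samples >= mmd_observed).mean()  # approximate p-value under H0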