8. Saving & Loading Models#
BayesFlow models are saved and loaded through Keras 3, which handles serialization of the full model graph — architecture, weights, and configuration — into a single .keras file.
What is saved? When you save an approximator, the entire object is persisted: the inference network, the summary network (if any), and the adapter that defines how your data is preprocessed. You do not need to recreate or re-attach anything after loading.
This page covers three scenarios:
Automatic checkpointing — saving after every training epoch via the workflow API
Manual saving — explicitly saving a trained approximator to disk
Loading — restoring a saved model for inference or continued training
For custom networks, see Custom Networks & Serialization below.
8.1. Saving Models#
8.1.1. During training: automatic checkpointing#
Pass checkpoint_filepath and checkpoint_name when constructing a workflow. BayesFlow will then save the model automatically at the end of every training epoch.
import numpy as np
import bayesflow as bf
# Minimal simulator used throughout this notebook
def prior():
    return {"theta": np.random.normal(size=2).astype("float32")}

def likelihood(theta):
    mu = theta[0]
    sigma = float(np.exp(theta[1] * 0.5))
    return {"x": np.random.normal(mu, sigma, size=(20, 1)).astype("float32")}

simulator = bf.make_simulator([prior, likelihood])

adapter = (
    bf.Adapter()
    .convert_dtype("float64", "float32")
    .rename("x", "summary_variables")
    .rename("theta", "inference_variables")
)

workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
    inference_network=bf.networks.FlowMatching(),
    checkpoint_filepath="checkpoints",  # directory to write into
    checkpoint_name="my_model",         # file will be: checkpoints/my_model.keras
)
history = workflow.fit_online(epochs=5)
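You can quickly confirm that checkpoints are being written (a small sketch; the filename follows from the checkpoint settings above):

import os
print(os.listdir("checkpoints"))  # expect to see: my_model.keras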
8.1.1.1. Saving only the best checkpoint#
Set save_best_only=True to keep only the checkpoint with the lowest validation loss, discarding worse epochs automatically.
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
    inference_network=bf.networks.FlowMatching(),
    checkpoint_filepath="checkpoints",
    checkpoint_name="my_model",
    save_best_only=True,  # keeps the epoch with the lowest validation loss
)
When to avoid save_best_only: Some loss functions — such as flow matching — produce inherently noisy per-epoch estimates. In those cases the “best” checkpoint may simply be a lucky noise fluctuation rather than the best-trained model. Prefer save_best_only=False (the default) unless you have a reliable validation signal.
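If you do rely on save_best_only, make sure the workflow actually receives a validation signal. A minimal sketch, assuming your BayesFlow version's fit_offline method accepts a validation_data argument:

# Simulate a fixed training and validation set, then fit offline
training_data = simulator.sample(5000)
validation_data = simulator.sample(500)
history = workflow.fit_offline(
    training_data, epochs=50, batch_size=64, validation_data=validation_data
)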
8.1.2. After training: manual save#
If you did not configure checkpointing upfront, or if you are working with the lower-level approximator API directly, you can save at any point:
# Save via the workflow
workflow.approximator.save("checkpoints/my_model.keras")
# Or call save on the approximator object itself
# (the same call works if you built the approximator without a workflow)
approximator = workflow.approximator
approximator.save("checkpoints/my_model.keras")
8.2. Loading Models#
Use keras.saving.load_model to restore a saved approximator. Always import bayesflow before loading — BayesFlow registers its custom objects with Keras on import, and without this the load will fail.
# bayesflow was imported above; it must be imported before calling
# load_model so that its custom objects are registered with Keras
import keras

approximator = keras.saving.load_model("checkpoints/my_model.keras")
# Held-out observations used in the loading examples
new_data = simulator.sample(4)
# The approximator is fully ready for inference — adapter included
samples = approximator.sample(num_samples=1000, conditions=new_data)
8.2.1. Attaching a loaded model to a workflow#
If you want to use the high-level workflow utilities (e.g., workflow.sample, workflow.plot_posterior), attach the loaded approximator to an existing workflow instance:
# Restore the approximator
workflow.approximator = keras.saving.load_model("checkpoints/my_model.keras")
# All workflow-level methods work as normal
samples = workflow.sample(num_samples=1000, conditions=new_data)
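Continued training is then just another call to the workflow's fit methods (a sketch, reusing the simulator and adapter already attached to the workflow):

# Continue training the restored approximator for a few more epochs
history = workflow.fit_online(epochs=2)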
Sanity check after loading: Changes to model architectures across BayesFlow versions can occasionally cause a model to load without error but produce incorrect outputs. After loading, always run a quick check — for example, verify output shapes or compare a few samples against known-good results from before saving.
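For example, a minimal post-load check might look like this (a sketch; the exact output keys and shapes depend on your parameter names and dimensions):

# Sanity check: sample from the restored model and inspect the output
check_data = simulator.sample(2)
check_samples = approximator.sample(num_samples=10, conditions=check_data)
print({k: v.shape for k, v in check_samples.items()})  # e.g. {"theta": (2, 10, 2)}
assert all(np.all(np.isfinite(v)) for v in check_samples.values())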
8.3. What Gets Saved?#
The .keras format serializes the full approximator object: its architecture (layer graph and configuration), trained weights, and the adapter (your data-preprocessing pipeline). When you load a model, you get back an identical, ready-to-use object — no need to rebuild the adapter or rewire anything.
Under the hood, Keras converts the object to a JSON config and stores it alongside the weights in a single archive (the .keras file is a zip containing the config and an HDF5 weights file). For this to work, every component of your approximator (networks, layers, the adapter) must be serializable — meaning it can be fully described by its class plus a dictionary of constructor arguments.
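You can inspect a saved file to see this structure for yourself (a quick sketch; the exact archive contents may vary across Keras versions):

import zipfile
with zipfile.ZipFile("checkpoints/my_model.keras") as archive:
    print(archive.namelist())  # typically config.json, metadata.json, model.weights.h5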
All built-in BayesFlow networks and adapters are serializable by default. If you write a custom network, you need to take one extra step to make it serializable — see Custom Networks & Serialization below.
8.3.1. Version pinning#
BayesFlow is under active development. Occasionally, a change to a model’s internal architecture can break loading of old checkpoints — sometimes silently, without raising an error. To protect against this:
Pin your BayesFlow version per project (pip install bayesflow==x.y.z).
Always run a sanity check on model outputs immediately after loading.
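It also helps to record the exact library versions next to the checkpoint, so you know which environment produced it (a small sketch):

import json
versions = {"bayesflow": bf.__version__, "keras": keras.__version__}
with open("checkpoints/versions.json", "w") as f:
    json.dump(versions, f)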
8.4. Custom Networks & Serialization#
If you implement a custom summary network (or any custom layer), you must make it serializable so BayesFlow can save and load it. This requires two things:
Decorate the class with @serializable — registers the class under a unique name in Keras’s object registry.
Override get_config() — returns the constructor arguments needed to reconstruct the object from scratch.
8.4.1. Minimal example#
from bayesflow.utils.serialization import serializable, serialize

@serializable("my_project")  # unique namespace, e.g. your project name
class GRUBottleneck(bf.networks.SummaryNetwork):
    def __init__(self, summary_dim=8, **kwargs):
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.gru = keras.layers.GRU(64)
        self.dense = keras.layers.Dense(summary_dim)

    def call(self, time_series, **kwargs):
        """Compress (batch, T, d) time series into (batch, summary_dim) summaries."""
        h = self.gru(time_series, **kwargs)
        return self.dense(h)

    def get_config(self):
        base = super().get_config()
        config = {"summary_dim": self.summary_dim}
        return base | serialize(config)
After this, approximator.save(...) and keras.saving.load_model(...) will work exactly as with built-in networks — as long as import bayesflow (which triggers registration) appears before the load call.
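For instance, the custom network plugs into a workflow like any built-in summary network (a sketch reusing the simulator and adapter defined earlier):

workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=GRUBottleneck(summary_dim=8),
    inference_network=bf.networks.FlowMatching(),
    checkpoint_filepath="checkpoints",
    checkpoint_name="custom_model",  # saved to checkpoints/custom_model.keras
)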
8.4.2. The @serializable decorator#
@serializable(name) registers the class in Keras’s global object registry under the given name. Choose a namespaced string — something like "my_project.networks" — to avoid collisions with other packages that might register a class under the same name.
8.4.3. The get_config() method#
get_config() must return a plain dictionary that contains all constructor arguments needed to recreate the object. Call serialize(config) on any values that are themselves Keras objects or BayesFlow components (e.g., a sub-layer passed as __init__ argument) — this recursively converts them to their config dicts so they can be round-tripped through JSON.
def get_config(self):
    base = super().get_config()  # picks up parent class arguments
    config = {"summary_dim": self.summary_dim}
    return base | serialize(config)  # serialize() handles nested Keras objects
If all your constructor arguments are plain Python primitives (int, float, str, bool, None), you do not need to call serialize — just merge the dicts directly.
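For constructor arguments that are themselves networks or layers, you also need to rebuild them when loading. The sketch below is illustrative: the ProjectedSummary class and its projector argument are hypothetical, and it assumes the deserialize helper from bayesflow.utils.serialization together with an overridden from_config():

from bayesflow.utils.serialization import deserialize, serializable, serialize

@serializable("my_project")
class ProjectedSummary(bf.networks.SummaryNetwork):
    def __init__(self, projector, summary_dim=8, **kwargs):
        super().__init__(**kwargs)
        self.projector = projector  # a Keras layer supplied by the caller
        self.summary_dim = summary_dim
        self.pool = keras.layers.GlobalAveragePooling1D()
        self.dense = keras.layers.Dense(summary_dim)

    def call(self, time_series, **kwargs):
        return self.dense(self.pool(self.projector(time_series)))

    def get_config(self):
        base = super().get_config()
        # `projector` is a Keras object, so serialize() converts it to a config dict
        config = {"projector": self.projector, "summary_dim": self.summary_dim}
        return base | serialize(config)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # deserialize() turns nested config dicts back into objects
        return cls(**deserialize(config, custom_objects=custom_objects))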
8.5. Further Reading#
Keras: Serialization & Saving guide — in-depth coverage of the .keras format, custom objects, and SavedModel.