8. Saving & Loading Models#

BayesFlow models are saved and loaded through Keras 3, which handles serialization of the full model graph — architecture, weights, and configuration — into a single .keras file.

What is saved? When you save an approximator, the entire object is persisted: the inference network, the summary network (if any), and the adapter that defines how your data is preprocessed. You do not need to recreate or re-attach anything after loading.

This page covers three scenarios:

  1. Automatic checkpointing — saving after every training epoch via the workflow API

  2. Manual saving — explicitly saving a trained approximator to disk

  3. Loading — restoring a saved model for inference or continued training

For custom networks, see Custom Networks & Serialization below.

8.1. Saving Models#

8.1.1. During training: automatic checkpointing#

Pass checkpoint_filepath and checkpoint_name when constructing a workflow. BayesFlow will then save the model automatically at the end of every training epoch.

import numpy as np
import bayesflow as bf

# Minimal simulator used throughout this page
def prior():
    return {"theta": np.random.normal(size=2).astype("float32")}

def likelihood(theta):
    mu = theta[0]
    sigma = float(np.exp(theta[1] * 0.5))
    return {"x": np.random.normal(mu, sigma, size=(20, 1)).astype("float32")}

simulator = bf.make_simulator([prior, likelihood])

adapter = (
    bf.Adapter()
    .convert_dtype("float64", "float32")
    .rename("x", "summary_variables")
    .rename("theta", "inference_variables")
)
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
    inference_network=bf.networks.FlowMatching(),
    checkpoint_filepath="checkpoints",  # directory to write into
    checkpoint_name="my_model",         # file will be: checkpoints/my_model.keras
)

history = workflow.fit_online(epochs=5)

8.1.1.1. Saving only the best checkpoint#

Set save_best_only=True to keep only the checkpoint with the lowest validation loss, discarding worse epochs automatically.

workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
    inference_network=bf.networks.FlowMatching(),
    checkpoint_filepath="checkpoints",
    checkpoint_name="my_model",
    save_best_only=True   # keeps the epoch with the lowest validation loss
)

When to avoid save_best_only: Some loss functions — such as flow matching — produce inherently noisy per-epoch estimates. In those cases the “best” checkpoint may simply be a lucky noise fluctuation rather than the best-trained model. Prefer save_best_only=False (the default) unless you have a reliable validation signal.

8.1.2. After training: manual save#

If you did not configure checkpointing upfront, or if you are working with the lower-level approximator API directly, you can save at any point:

# Save via the workflow
workflow.approximator.save("checkpoints/my_model.keras")

# Or call save() on an approximator object you hold directly
approximator = workflow.approximator
approximator.save("checkpoints/my_model.keras")

8.2. Loading Models#

Use keras.saving.load_model to restore a saved approximator. Always import bayesflow before loading — BayesFlow registers its custom objects with Keras on import, and without this the load will fail.

# Important: import bayesflow before loading, so that its custom
# objects are registered with Keras
import bayesflow as bf
import keras

approximator = keras.saving.load_model("checkpoints/my_model.keras")

# Held-out observations used in the loading examples
new_data = simulator.sample(4)

# The approximator is fully ready for inference — adapter included
samples = approximator.sample(num_samples=1000, conditions=new_data)

8.2.1. Attaching a loaded model to a workflow#

If you want to use the high-level workflow utilities (e.g., workflow.sample, workflow.plot_posterior), attach the loaded approximator to an existing workflow instance:

# Restore the approximator
workflow.approximator = keras.saving.load_model("checkpoints/my_model.keras")

# All workflow-level methods work as normal
samples = workflow.sample(num_samples=1000, conditions=new_data)

Sanity check after loading: Changes to model architectures across BayesFlow versions can occasionally cause a model to load without error but produce incorrect outputs. After loading, always run a quick check — for example, verify output shapes or compare a few samples against known-good results from before saving.
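One lightweight way to implement such a check is to compare summary statistics of posterior draws produced before saving and after loading. The sketch below uses only NumPy; the names `reference` and `loaded` are placeholders for the two sets of draws, and the tolerance is an assumption you should tune to your problem:

```python
import numpy as np

def check_roundtrip(reference, loaded, atol=0.1):
    """Compare posterior draws of shape (num_samples, dim) taken before
    saving (reference) and after loading (loaded)."""
    assert reference.shape == loaded.shape, "shape mismatch after loading"
    # Per-dimension means and standard deviations should agree
    # up to Monte Carlo noise
    assert np.allclose(reference.mean(axis=0), loaded.mean(axis=0), atol=atol)
    assert np.allclose(reference.std(axis=0), loaded.std(axis=0), atol=atol)
```

Run it on draws conditioned on the same held-out data, e.g. `check_roundtrip(samples_before, samples_after)`.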

8.3. What Gets Saved?#

The .keras format serializes the full approximator object: its architecture (layer graph and configuration), trained weights, and the adapter (your data-preprocessing pipeline). When you load a model, you get back an identical, ready-to-use object — no need to rebuild the adapter or rewire anything.

Under the hood, Keras converts the object to a JSON config and stores it, together with the weights, in a single .keras archive. For this to work, every component of your approximator (networks, layers, the adapter) must be serializable — meaning it can be fully described by its class plus a dictionary of constructor arguments.

All built-in BayesFlow networks and adapters are serializable by default. If you write a custom network, you need to take one extra step to make it serializable — see Custom Networks & Serialization below.

8.3.1. Version pinning#

BayesFlow is under active development. Occasionally, a change to a model’s internal architecture can break loading of old checkpoints — sometimes silently, without raising an error. To protect against this:

  • Pin your BayesFlow version per project (pip install bayesflow==x.y.z).

  • Always run a sanity check on model outputs immediately after loading.
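A small helper along these lines (hypothetical, not part of BayesFlow) can record the installed package versions next to a checkpoint, so you always know which environment produced it:

```python
import json
from importlib import metadata
from pathlib import Path

def record_versions(checkpoint_dir, packages=("bayesflow", "keras")):
    """Write installed package versions to versions.json in the
    checkpoint directory, for later reference when reloading."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None  # package not installed
    path = Path(checkpoint_dir) / "versions.json"
    path.write_text(json.dumps(versions, indent=2))
    return versions
```

Calling `record_versions("checkpoints")` right after saving leaves a `checkpoints/versions.json` you can consult before loading an old model.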

8.4. Custom Networks & Serialization#

If you implement a custom summary network (or any custom layer), you must make it serializable so BayesFlow can save and load it. This requires two things:

  1. Decorate the class with @serializable — registers the class under a unique name in Keras’s object registry.

  2. Override get_config() — returns the constructor arguments needed to reconstruct the object from scratch.

8.4.1. Minimal example#

from bayesflow.utils.serialization import serializable, serialize

@serializable("my_project")  # unique namespace, e.g. your project name
class GRUBottleneck(bf.networks.SummaryNetwork):
    def __init__(self, summary_dim=8, **kwargs):
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.gru = keras.layers.GRU(64)
        self.dense = keras.layers.Dense(summary_dim)

    def call(self, time_series, **kwargs):
        """Compress (batch, T, d) time series into (batch, summary_dim) summaries."""
        h = self.gru(time_series, **kwargs)
        return self.dense(h)

    def get_config(self):
        base = super().get_config()
        config = {"summary_dim": self.summary_dim}
        return base | serialize(config)

After this, approximator.save(...) and keras.saving.load_model(...) will work exactly as with built-in networks — as long as import bayesflow (which triggers registration) appears before the load call.

8.4.2. The @serializable decorator#

@serializable(name) registers the class in Keras’s global object registry under the given name. Choose a namespaced string — something like "my_project.networks" — to avoid collisions with other packages that might register a class under the same name.

8.4.3. The get_config() method#

get_config() must return a plain dictionary that contains all constructor arguments needed to recreate the object. Call serialize(config) on any values that are themselves Keras objects or BayesFlow components (e.g., a sub-layer passed as __init__ argument) — this recursively converts them to their config dicts so they can be round-tripped through JSON.

def get_config(self):
    base = super().get_config()          # picks up parent class arguments
    config = {"summary_dim": self.summary_dim}
    return base | serialize(config)      # serialize() handles nested Keras objects

If all your constructor arguments are plain Python primitives (int, float, str, bool, None), you do not need to call serialize — just merge the dicts directly.
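For illustration, here is a plain-Python sketch of the primitives-only case (the class and attribute names are made up, and a simple dict stands in for super().get_config()); the point is that the config dict alone must suffice to rebuild the object:

```python
class PlainConfigLayer:
    """Hypothetical layer whose __init__ takes only Python primitives."""
    def __init__(self, summary_dim=8, dropout=0.1, name="plain"):
        self.summary_dim = summary_dim
        self.dropout = dropout
        self.name = name

    def get_config(self):
        base = {"name": self.name}  # stands in for super().get_config()
        # All values are primitives, so no serialize() call is needed
        return base | {"summary_dim": self.summary_dim, "dropout": self.dropout}

# Round trip: the config alone is enough to reconstruct the object
layer = PlainConfigLayer(summary_dim=16)
rebuilt = PlainConfigLayer(**layer.get_config())
```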

8.5. Further Reading#