import numpy as np
import bayesflow as bf

# Minimal simulator: 2 parameters, 20 i.i.d. observations
def prior():
    return dict(theta=np.random.normal(size=2).astype("float32"))

def likelihood(theta):
    # theta[0] is the mean; theta[1] is the log-variance, so sigma = exp(theta[1] / 2)
    mu, sigma = theta[0], float(np.exp(theta[1] * 0.5))
    return dict(x=np.random.normal(mu, sigma, size=(20, 1)).astype("float32"))

simulator = bf.make_simulator([prior, likelihood])

7. Workflows#

A workflow is a high-level wrapper around an approximator that manages the full inference pipeline: data generation, compilation, training, checkpointing, sampling, and diagnostics. For most use cases a workflow is the most convenient entry point — it handles the boilerplate so you can focus on your model.

BayesFlow provides three workflow classes:

| Workflow | Use case |
|---|---|
| BasicWorkflow | Standard posterior / likelihood estimation (NPE, NLE, scoring rules) |
| EnsembleWorkflow | Train and query multiple approximators jointly |
| CompositionalWorkflow | Compositional / hierarchical inference with a DiffusionModel |

All three live in bayesflow.workflows and are re-exported from the top-level bayesflow namespace. This page covers each in turn, then explains the training strategies (fit_online, fit_offline, fit_disk) and the diagnostic utilities common to all workflows.
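Since the classes are re-exports, both import paths refer to the same objects. A quick sanity check (not needed in practice):

import bayesflow as bf

# The top-level names are re-exports of the classes in bayesflow.workflows
assert bf.BasicWorkflow is bf.workflows.BasicWorkflow
assert bf.EnsembleWorkflow is bf.workflows.EnsembleWorkflow
assert bf.CompositionalWorkflow is bf.workflows.CompositionalWorkflow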

7.1. BasicWorkflow#

BasicWorkflow is the standard starting point. At a minimum you provide a simulator and tell it which variables play which role; the workflow builds a default adapter, compiles the approximator, and is ready to train.

import bayesflow as bf

workflow = bf.BasicWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_network=bf.networks.FlowMatching(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
)

history = workflow.fit_online(epochs=5, batch_size=64, num_batches_per_epoch=50)
test_batch = simulator.sample(4)
x_obs = test_batch["x"].astype("float32")
theta = test_batch["theta"].astype("float32")
basic_workflow = workflow  # keep a handle; reused in later sections (from_basic_workflow note, diagnostics)

7.1.1. What BasicWorkflow builds for you#

  • A default Adapter that routes inference_variables, summary_variables, and inference_conditions correctly.

  • An approximator (ContinuousApproximator or ScoringRuleApproximator depending on the inference network).

  • An Adam optimizer with a cosine decay schedule appropriate for the chosen training strategy.

You can override any of these by passing a custom adapter, optimizer, or a pre-built approximator via inference_network.
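For instance, a minimal sketch of overriding the default adapter and optimizer (this assumes BasicWorkflow accepts adapter= and optimizer= keyword arguments; adjust the names to your version):

import keras

# Custom adapter: route simulator outputs to the keys the approximator expects.
# (Illustrative; the default adapter built by BasicWorkflow does this for you.)
custom_adapter = (
    bf.adapters.Adapter()
    .convert_dtype("float64", "float32")
    .concatenate(["theta"], into="inference_variables")
    .concatenate(["x"], into="summary_variables")
)

custom_workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=custom_adapter,                                   # assumed keyword
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),      # assumed keyword
    inference_network=bf.networks.CouplingFlow(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
)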

7.1.2. Key constructor arguments#

| Argument | Default | Description |
|---|---|---|
| simulator | None | Source of training data |
| inference_variables |  | Variables to infer (e.g., ["theta"]) |
| summary_variables | None | Variables to summarize with the summary network |
| inference_conditions | None | Variables fed directly to the inference network (not summarized) |
| inference_network | "coupling_flow" | Network or name string |
| summary_network | None | Summary network or name string |
| initial_learning_rate | 5e-4 | Starting LR for the optimizer |
| checkpoint_filepath | None | Directory for automatic checkpointing |
| checkpoint_name | "model" | File stem; saved as {name}.keras |
| save_best_only | False | Keep only the lowest-loss checkpoint |
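For example, a workflow that checkpoints to disk during training could be configured as follows (illustrative values; the arguments are the ones listed in the table above):

# Checkpoints are written to checkpoints/normal_model.keras; with
# save_best_only=True only the lowest-loss model is kept.
checkpointed_workflow = bf.BasicWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_network=bf.networks.CouplingFlow(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
    initial_learning_rate=1e-3,
    checkpoint_filepath="checkpoints",
    checkpoint_name="normal_model",
    save_best_only=True,
)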

7.1.3. After training#

The trained approximator is always accessible at workflow.approximator. All approximator methods (.sample(), .log_prob(), etc.) work directly on it:

# High-level workflow interface
samples = workflow.sample(num_samples=1000, conditions={"x": x_obs})

# Equivalent — directly on the approximator
samples = workflow.approximator.sample(num_samples=1000, conditions={"x": x_obs})

To save the model, use the approximator (not the workflow):

workflow.approximator.save("my_model.keras")
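The saved file is a standard .keras artifact, so it can presumably be restored with Keras's loading utilities; a minimal sketch (assumes bayesflow is imported so its custom objects are registered):

import keras

# Restore the approximator and sample from it as before
approximator = keras.saving.load_model("my_model.keras")
samples = approximator.sample(num_samples=1000, conditions={"x": x_obs})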

7.2. EnsembleWorkflow#

EnsembleWorkflow trains multiple approximators jointly on the same data, then lets you query them individually or as a merged mixture. This is useful for uncertainty quantification over the inference network itself, or for comparing different network architectures on the same problem.

7.2.1. Construction modes#

Size mode — clone one network \(N\) times automatically:

workflow = bf.EnsembleWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_networks=bf.networks.FlowMatching(),
    ensemble_size=2,
)

Dictionary mode — give each member an explicit name and network:

workflow = bf.EnsembleWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_networks={
        "flow": bf.networks.FlowMatching(),
        "coupling": bf.networks.CouplingFlow()
    },
    summary_networks={
        "flow": bf.networks.TimeSeriesTransformer(summary_dim=16),
        "coupling": bf.networks.TimeSeriesTransformer(summary_dim=16)
    },
)

7.2.2. Training#

EnsembleWorkflow exposes the same fit_online / fit_offline / fit_disk interface as BasicWorkflow, plus a data_reuse parameter that controls data sharing between members:

history = workflow.fit_online(
    epochs=5,
    batch_size=64,
    num_batches_per_epoch=50,
    data_reuse=0.8,
)
test_batch = simulator.sample(4)
x_obs = test_batch["x"].astype("float32")
theta = test_batch["theta"].astype("float32")

7.2.3. Inference#

# Merged mixture samples (default)
samples = workflow.sample(num_samples=500, conditions={"x": x_obs})

# Per-member samples
samples = workflow.sample(
    num_samples=500,
    conditions={"x": x_obs},
    merge_members=False,   # returns {"flow": ..., "coupling": ...}
)

# Weighted mixture
samples = workflow.sample(
    num_samples=500,
    conditions={"x": x_obs},
    member_weights={"flow": 0.6, "coupling": 0.4},
)

# Log-probability under the mixture
log_p = workflow.log_prob(data={"x": x_obs, "theta": theta})

7.3. CompositionalWorkflow#

CompositionalWorkflow extends BasicWorkflow for compositional inference — settings where you want to draw posterior samples conditioned on multiple datasets simultaneously (e.g., combining evidence from several experiments or performing joint inference over a hierarchical structure). The inference network must be a DiffusionModel.

7.3.1. Basic usage#

workflow = bf.CompositionalWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_network=bf.networks.DiffusionModel(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=8),
)

history = workflow.fit_online(epochs=5, batch_size=64, num_batches_per_epoch=50)
raw = simulator.sample(6)
# compositional_sample expects conditions of shape (n_datasets, n_compositional, ...)
x_obs = raw["x"].reshape(2, 3, 20, 1).astype("float32")  # 2 inference tasks, each composing 3 datasets

7.3.2. Compositional sampling#

Conditions are expected to have shape (n_datasets, n_compositional, ...): the first axis indexes independent inference tasks, and the second axis indexes the individual datasets being composed within each task.

# x_obs has shape (n_datasets, n_compositional, time, features)
samples = workflow.compositional_sample(
    num_samples=100,
    conditions={"x": x_obs},
)

You can also inject a prior score function to guide compositional sampling. When per-dataset posterior scores are combined, the prior would otherwise enter the combination once per dataset, so the prior score is typically used to correct for this:

prior_variance = 1.0  # theta ~ N(0, 1) so prior variance is 1

def prior_score(data, t=None):
    """Score of a Gaussian prior on theta."""
    return {"theta": -data["theta"] / prior_variance}

samples = workflow.compositional_sample(
    num_samples=100,
    conditions={"x": x_obs},
    compute_prior_score=prior_score,
)

7.3.3. Building from a trained BasicWorkflow#

If you already have a trained BasicWorkflow with a DiffusionModel, you can promote it to a CompositionalWorkflow without retraining. The network weights are cloned and transferred:

# CompositionalWorkflow.from_basic_workflow() promotes a trained BasicWorkflow
# (inference_network must be DiffusionModel) into a CompositionalWorkflow.
# Example (not run here because basic_workflow uses FlowMatching):
#   compositional_workflow = bf.CompositionalWorkflow.from_basic_workflow(
#       basic_workflow, simulator=simulator
#   )

# Reuse the CompositionalWorkflow trained above:
compositional_workflow = workflow
samples = compositional_workflow.compositional_sample(
    num_samples=100,
    conditions={"x": x_obs},
)

See the Compositional Diffusion example for a worked end-to-end demonstration.

7.4. Training Strategies#

All three workflow classes support the same three training strategies. Choose based on how you generate data:

7.4.1. fit_online — simulate during training#

The simulator is called at each training step. This requires no pre-generated dataset and prevents overfitting to a fixed sample pool. Use this when simulation is fast (seconds per batch).

history = basic_workflow.fit_online(
    epochs=5,
    num_batches_per_epoch=50,
    batch_size=64,
    validation_data=100,    # generate 100 validation samples from the simulator
)

7.4.2. fit_offline — train on a pre-generated dataset in memory#

Simulate everything upfront and store it in a dictionary. Training is faster (no simulator overhead per step), but the model trains on a fixed finite dataset.

# Pre-generate data
data = simulator.sample(500)

validation_data = simulator.sample(100)

history = basic_workflow.fit_offline(
    data=data,
    epochs=5,
    batch_size=64,
    validation_data=validation_data,
)

7.4.3. fit_disk — stream from files on disk#

For very large simulation budgets that don’t fit in memory. Data is stored as .pkl files (or any format with a custom load_fn) and loaded lazily during training.

# Illustrative — requires files on disk at the given path.
# Each .pkl file should be a dict with keys matching the simulator output.
#
# history = basic_workflow.fit_disk(
#     root="/data/simulations",
#     pattern="*.pkl",
#     batch_size=128,
#     epochs=100
# )
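The files themselves can be produced ahead of time with the same simulator. A minimal sketch using pickle (the directory layout and chunk size are illustrative assumptions; point the root of your fit_disk call at whatever directory you use):

import os
import pickle

# Pre-generate simulation chunks that fit_disk can stream later.
# Each file stores one dict of simulator outputs, matching the layout above.
os.makedirs("simulations", exist_ok=True)
for i in range(10):
    chunk = simulator.sample(1_000)
    with open(f"simulations/chunk_{i:03d}.pkl", "wb") as f:
        pickle.dump(chunk, f)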

7.4.4. Augmentations#

All fit methods accept an augmentations argument — functions applied to each batch before the adapter, useful for data augmentation during training only (e.g., adding noise, random scaling):

def add_noise(batch):
    batch["x"] = batch["x"] + np.random.normal(0, 0.01, batch["x"].shape).astype("float32")
    return batch

basic_workflow.fit_online(epochs=2, batch_size=64, num_batches_per_epoch=10,
                          augmentations=add_noise)

7.5. Diagnostics#

All workflows expose built-in diagnostic utilities that draw posterior samples for held-out test data and evaluate the approximator automatically.

7.5.1. Default diagnostics#

# Generate held-out test data and run the diagnostics on the trained BasicWorkflow
test_data = simulator.sample(100)

# plot_default_diagnostics takes the raw test data and handles sampling internally
figures = basic_workflow.plot_default_diagnostics(test_data, num_samples=100)
# figures is a dict: {"losses", "recovery", "calibration_ecdf", "coverage", "z_score_contraction"}
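Each entry in figures is a regular matplotlib figure, so individual plots can be saved or tweaked on their own:

# Save a single diagnostic figure to disk
figures["recovery"].savefig("recovery.png", dpi=150, bbox_inches="tight")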

Individual plots:

  • losses — training (and validation) loss over epochs

  • recovery — posterior mean vs. ground-truth parameter (“how well does the posterior center on the truth?”)

  • calibration_ecdf — simulation-based calibration check: the ECDF of the true parameters’ posterior ranks should be approximately uniform

  • coverage — joint credible-interval coverage across parameter dimensions

  • z_score_contraction — how much the posterior contracts relative to the prior

inference_variables is required for automated diagnostics. Make sure to pass inference_variables=["theta"] (or your parameter names) when constructing the workflow; otherwise the diagnostic methods cannot automatically identify the targets.

7.5.2. Accessing metrics numerically#

metrics_df = basic_workflow.compute_default_diagnostics(test_data, num_samples=100, as_data_frame=True)
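With as_data_frame=True the result is a pandas DataFrame, so it can be inspected or exported like any other DataFrame:

print(metrics_df)                      # inspect the summary metrics
metrics_df.to_csv("diagnostics.csv")   # or persist them for a report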

7.5.3. Custom diagnostics#

Pass a dictionary of {name: plot_fn} for custom visualizations:

import matplotlib.pyplot as plt


def posterior_predictive_check(samples, test_data, variable_keys=None, variable_names=None, **kwargs):
    """Overlay observed data, posterior predictive density, and true density for a few test cases."""
    n_show = min(4, len(test_data["x"]))
    theta_post = samples["theta"]     # (n_test, n_samples, 2)
    x_obs      = test_data["x"]       # (n_test, 20, 1)
    theta_true = test_data["theta"]   # (n_test, 2)

    fig, axes = plt.subplots(1, n_show, figsize=(3.5 * n_show, 3.5))
    axes = [axes] if n_show == 1 else list(axes)

    for i, ax in enumerate(axes):
        obs = x_obs[i].reshape(-1)
        lo, hi = obs.min() - 1.0, obs.max() + 1.0
        x_grid = np.linspace(lo, hi, 300)

        # Observed data
        ax.hist(obs, bins=10, density=True, alpha=0.4, color="steelblue", label="Observed x")

        # Posterior predictive: average normal PDF over 50 posterior draws
        mu_s = theta_post[i, :50, 0]
        sg_s = np.exp(theta_post[i, :50, 1] * 0.5)
        ppd = np.mean(
            [np.exp(-0.5 * ((x_grid - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))
             for m, s in zip(mu_s, sg_s)],
            axis=0,
        )
        ax.plot(x_grid, ppd, color="tomato", lw=2, label="Posterior predictive")

        # True generative density
        m_t = theta_true[i, 0]
        s_t = float(np.exp(theta_true[i, 1] * 0.5))
        true_pdf = np.exp(-0.5 * ((x_grid - m_t) / s_t) ** 2) / (s_t * np.sqrt(2 * np.pi))
        ax.plot(x_grid, true_pdf, "k--", lw=1.5, label="True density")

        ax.set_title(f"Test case {i + 1}", fontsize=10)
        ax.set_xlabel("x")

    axes[0].set_ylabel("Density")
    axes[0].legend(fontsize=8)
    fig.suptitle("Posterior Predictive Check", fontweight="bold")
    fig.tight_layout()
    return fig


figures = basic_workflow.plot_custom_diagnostics(
    test_data,
    plot_fns={"posterior_predictive": posterior_predictive_check},
    num_samples=200,
)