import numpy as np
import bayesflow as bf
# Minimal simulator: 2 parameters, 20 i.i.d. observations
def prior():
    return dict(theta=np.random.normal(size=2).astype("float32"))

def likelihood(theta):
    mu, sigma = theta[0], float(np.exp(theta[1] * 0.5))
    return dict(x=np.random.normal(mu, sigma, size=(20, 1)).astype("float32"))
simulator = bf.make_simulator([prior, likelihood])
7. Workflows#
A workflow is a high-level wrapper around an approximator that manages the full inference pipeline: data generation, compilation, training, checkpointing, sampling, and diagnostics. For most use cases a workflow is the most convenient entry point — it handles the boilerplate so you can focus on your model.
BayesFlow provides three workflow classes:
| Workflow | Use case |
|---|---|
| `BasicWorkflow` | Standard posterior / likelihood estimation (NPE, NLE, scoring rules) |
| `EnsembleWorkflow` | Train and query multiple approximators jointly |
| `CompositionalWorkflow` | Compositional / hierarchical inference with a `DiffusionModel` |
All three live in bayesflow.workflows and are re-exported from the top-level bayesflow namespace. This page covers each in turn, then explains the training strategies (fit_online, fit_offline, fit_disk) and the diagnostic utilities common to all workflows.
7.1. BasicWorkflow#
BasicWorkflow is the standard starting point. At a minimum you provide a simulator and tell it which variables play which role; the workflow builds a default adapter, compiles the approximator, and is ready to train.
import bayesflow as bf
workflow = bf.BasicWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_network=bf.networks.FlowMatching(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
)
history = workflow.fit_online(epochs=5, batch_size=64, num_batches_per_epoch=50)
test_batch = simulator.sample(4)
x_obs = test_batch["x"].astype("float32")
theta = test_batch["theta"].astype("float32")
basic_workflow = workflow # used later in CompositionalWorkflow.from_basic_workflow
7.1.1. What BasicWorkflow builds for you#
- A default `Adapter` that routes `inference_variables`, `summary_variables`, and `inference_conditions` correctly.
- An approximator (`ContinuousApproximator` or `ScoringRuleApproximator`, depending on the inference network).
- An `Adam` optimizer with a cosine decay schedule appropriate for the chosen training strategy.
You can override any of these by passing a custom adapter, optimizer, or a pre-built approximator via inference_network.
7.1.2. Key constructor arguments#
| Argument | Default | Description |
|---|---|---|
| `simulator` | `None` | Source of training data |
| `inference_variables` | — | Variables to infer (e.g., `["theta"]`) |
| `summary_variables` | `None` | Variables to summarize with the summary network |
| `inference_conditions` | `None` | Variables fed directly to the inference network (not summarized) |
| `inference_network` | `"coupling_flow"` | Inference network or name string |
| `summary_network` | `None` | Summary network or name string |
| `initial_learning_rate` | `5e-4` | Starting LR for the optimizer |
| `checkpoint_filepath` | `None` | Directory for automatic checkpointing |
| `checkpoint_name` | `"model"` | File stem for the saved checkpoint file |
| `save_best_only` | `False` | Keep only the lowest-loss checkpoint |
7.1.3. After training#
The trained approximator is always accessible at workflow.approximator. All approximator methods (.sample(), .log_prob(), etc.) work directly on it:
# High-level workflow interface
samples = workflow.sample(num_samples=1000, conditions={"x": x_obs})
# Equivalent — directly on the approximator
samples = workflow.approximator.sample(num_samples=1000, conditions={"x": x_obs})
To save the model, use the approximator (not the workflow):
workflow.approximator.save("my_model.keras")
7.2. EnsembleWorkflow#
EnsembleWorkflow trains multiple approximators jointly on the same data, then lets you query them individually or as a merged mixture. This is useful for uncertainty quantification over the inference network itself, or for comparing different network architectures on the same problem.
7.2.1. Construction modes#
Size mode — clone one network \(N\) times automatically:
workflow = bf.EnsembleWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_networks=bf.networks.FlowMatching(),
    ensemble_size=2,
)
Dictionary mode — give each member an explicit name and network:
workflow = bf.EnsembleWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_networks={
        "flow": bf.networks.FlowMatching(),
        "coupling": bf.networks.CouplingFlow(),
    },
    summary_networks={
        "flow": bf.networks.TimeSeriesTransformer(summary_dim=16),
        "coupling": bf.networks.TimeSeriesTransformer(summary_dim=16),
    },
)
7.2.2. Training#
EnsembleWorkflow exposes the same fit_online / fit_offline / fit_disk interface as BasicWorkflow, plus a data_reuse parameter that controls data sharing between members:
history = workflow.fit_online(
    epochs=5,
    batch_size=64,
    num_batches_per_epoch=50,
    data_reuse=0.8,
)
test_batch = simulator.sample(4)
x_obs = test_batch["x"].astype("float32")
theta = test_batch["theta"].astype("float32")
7.2.3. Inference#
# Merged mixture samples (default)
samples = workflow.sample(num_samples=500, conditions={"x": x_obs})

# Per-member samples
samples = workflow.sample(
    num_samples=500,
    conditions={"x": x_obs},
    merge_members=False,  # returns {"flow": ..., "coupling": ...}
)

# Weighted mixture
samples = workflow.sample(
    num_samples=500,
    conditions={"x": x_obs},
    member_weights={"flow": 0.6, "coupling": 0.4},
)

# Log-probability under the mixture
log_p = workflow.log_prob(data={"x": x_obs, "theta": theta})
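For intuition, merging per-member draws into a weighted mixture can be sketched in plain NumPy (the `mix_samples` helper below is hypothetical, not the library's implementation): each output slot picks a member with probability proportional to its weight and keeps that member's draw.

```python
import numpy as np

# Sketch of weighted mixture sampling over ensemble members:
# member_samples maps member name -> array of shape (num_samples, dim),
# weights maps member name -> mixture weight.
def mix_samples(member_samples, weights, seed=0):
    rng = np.random.default_rng(seed)
    names = list(member_samples)
    probs = np.array([weights[n] for n in names], dtype=float)
    probs /= probs.sum()  # normalize weights to probabilities
    stacked = np.stack([member_samples[n] for n in names])  # (n_members, num_samples, dim)
    choice = rng.choice(len(names), size=stacked.shape[1], p=probs)
    return stacked[choice, np.arange(stacked.shape[1])]
```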
test_data = simulator.sample(100)  # held-out test data for the diagnostics below
7.3. CompositionalWorkflow#
CompositionalWorkflow extends BasicWorkflow for compositional inference — settings where you want to draw posterior samples conditioned on multiple datasets simultaneously (e.g., combining evidence from several experiments or performing joint inference over a hierarchical structure). The inference network must be a DiffusionModel.
7.3.1. Basic usage#
workflow = bf.CompositionalWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x"],
    inference_network=bf.networks.DiffusionModel(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=8),
)
history = workflow.fit_online(epochs=5, batch_size=64, num_batches_per_epoch=50)
raw = simulator.sample(6)
# compositional_sample expects (n_datasets, n_compositional, ...)
x_obs = raw["x"].reshape(2, 3, 20, 1).astype("float32")  # 2 inference cases x 3 composed datasets each
7.3.2. Compositional sampling#
Conditions are expected to have shape (n_datasets, n_compositional, ...) — the second axis holds the multiple datasets being composed.
# x_obs has shape (n_datasets, n_compositional, time, features)
samples = workflow.compositional_sample(
    num_samples=100,
    conditions={"x": x_obs},
)
You can also inject a prior score function to guide compositional sampling:
prior_variance = 1.0 # theta ~ N(0, 1) so prior variance is 1
def prior_score(data, t=None):
    """Score of a Gaussian prior on theta."""
    return {"theta": -data["theta"] / prior_variance}

samples = workflow.compositional_sample(
    num_samples=100,
    conditions={"x": x_obs},
    compute_prior_score=prior_score,
)
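As a quick sanity check on the score formula (plain NumPy, independent of BayesFlow): for a standard normal prior, the score is the gradient of the log-density, \(\nabla_\theta \log p(\theta) = -\theta\), which a central finite difference confirms.

```python
import numpy as np

# Verify that -theta is the gradient of the standard normal log-density,
# i.e. the score returned by prior_score above (with prior variance 1).
def log_prior(theta):
    return -0.5 * np.sum(theta**2) - 0.5 * len(theta) * np.log(2 * np.pi)

theta = np.array([0.3, -1.2])
eps = 1e-6
numeric = np.array([
    (log_prior(theta + eps * e) - log_prior(theta - eps * e)) / (2 * eps)
    for e in np.eye(len(theta))
])
assert np.allclose(numeric, -theta, atol=1e-5)
```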
7.3.3. Building from a trained BasicWorkflow#
If you already have a trained BasicWorkflow with a DiffusionModel, you can promote it to a CompositionalWorkflow without retraining. The network weights are cloned and transferred:
# CompositionalWorkflow.from_basic_workflow() promotes a trained BasicWorkflow
# (inference_network must be DiffusionModel) into a CompositionalWorkflow.
# Example (not run here because basic_workflow uses FlowMatching):
# compositional_workflow = bf.CompositionalWorkflow.from_basic_workflow(
#     basic_workflow, simulator=simulator
# )
# Reuse the CompositionalWorkflow trained above:
compositional_workflow = workflow
samples = compositional_workflow.compositional_sample(
    num_samples=100,
    conditions={"x": x_obs},
)
See the Compositional Diffusion example for a worked end-to-end demonstration.
7.4. Training Strategies#
All three workflow classes support the same three training strategies. Choose based on how you generate data:
7.4.1. fit_online — simulate during training#
The simulator is called at each training step. This requires no pre-generated dataset and prevents overfitting to a fixed sample pool. Use this when simulation is fast (seconds per batch).
history = basic_workflow.fit_online(
    epochs=5,
    num_batches_per_epoch=50,
    batch_size=64,
    validation_data=100,  # generate 100 validation samples from the simulator
)
7.4.2. fit_offline — train on a pre-generated dataset in memory#
Simulate everything upfront and store it in a dictionary. Training is faster (no simulator overhead per step), but the model trains on a fixed finite dataset.
# Pre-generate data
data = simulator.sample(500)
validation_data = simulator.sample(100)
history = basic_workflow.fit_offline(
    data=data,
    epochs=5,
    batch_size=64,
    validation_data=validation_data,
)
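Instead of calling the simulator twice, you can also split a single simulated dictionary into training and validation parts. A small helper (the `split_data` name is hypothetical; plain NumPy) exploits the fact that all arrays in the dict share the first axis:

```python
import numpy as np

# Split a dict of arrays (shared first axis) into train/validation parts.
def split_data(data, val_fraction=0.2, seed=0):
    n = len(next(iter(data.values())))
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)  # shuffle before splitting
    n_val = int(n * val_fraction)
    train = {k: v[idx[n_val:]] for k, v in data.items()}
    val = {k: v[idx[:n_val]] for k, v in data.items()}
    return train, val
```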
7.4.3. fit_disk — stream from files on disk#
For very large simulation budgets that don’t fit in memory. Data is stored as .pkl files (or any format with a custom load_fn) and loaded lazily during training.
# Illustrative — requires files on disk at the given path.
# Each .pkl file should be a dict with keys matching the simulator output.
#
# history = basic_workflow.fit_disk(
#     root="/data/simulations",
#     pattern="*.pkl",
#     batch_size=128,
#     epochs=100,
# )
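To prepare such a directory, each batch of simulations can be pickled to its own file. A sketch under the assumption that each file holds one dict of arrays keyed like the simulator output; the `write_batches` helper and `sim_*.pkl` naming are illustrative:

```python
import pickle
from pathlib import Path
import numpy as np

# Write simulated batches as individual .pkl files for disk-backed training.
# Mirrors the toy simulator above: theta ~ N(0, I), sigma = exp(theta[1] / 2).
def write_batches(out_dir, n_files=3, batch_size=8, seed=0):
    rng = np.random.default_rng(seed)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i in range(n_files):
        theta = rng.normal(size=(batch_size, 2)).astype("float32")
        mu, sigma = theta[:, :1], np.exp(theta[:, 1:] * 0.5)
        x = rng.normal(mu[:, None], sigma[:, None], size=(batch_size, 20, 1)).astype("float32")
        with open(out / f"sim_{i:04d}.pkl", "wb") as f:
            pickle.dump({"theta": theta, "x": x}, f)
```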
7.4.4. Augmentations#
All fit methods accept an augmentations argument — functions applied to each batch before the adapter, useful for data augmentation during training only (e.g., adding noise, random scaling):
def add_noise(batch):
    batch["x"] = batch["x"] + np.random.normal(0, 0.01, batch["x"].shape).astype("float32")
    return batch

basic_workflow.fit_online(
    epochs=2,
    batch_size=64,
    num_batches_per_epoch=10,
    augmentations=add_noise,
)
7.5. Diagnostics#
All workflows expose built-in diagnostic utilities that run automatically against test data.
7.5.1. Default diagnostics#
# plot_default_diagnostics takes the raw test data and handles sampling internally
figures = basic_workflow.plot_default_diagnostics(test_data, num_samples=100)
# figures is a dict: {"losses", "recovery", "calibration_ecdf", "coverage", "z_score_contraction"}
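Each value in `figures` is a Matplotlib figure, so persisting the whole set only needs the standard `savefig` method (the `save_figures` helper below is illustrative):

```python
from pathlib import Path

# Save every diagnostic figure (name -> Matplotlib figure) as a PNG.
def save_figures(figures, out_dir="diagnostics"):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for name, fig in figures.items():
        path = out / f"{name}.png"
        fig.savefig(path, dpi=150, bbox_inches="tight")
        paths.append(path)
    return paths
```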
Individual plots:
- `losses` — training (and validation) loss over epochs
- `recovery` — posterior mean vs. ground-truth parameter (“how well does the posterior center on the truth?”)
- `calibration_ecdf` — empirical coverage: marginal posteriors should be uniformly calibrated
- `coverage` — joint credible-interval coverage across parameter dimensions
- `z_score_contraction` — how much the posterior contracts relative to the prior
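Numerically, the recovery diagnostic boils down to comparing per-dataset posterior means against the ground truth; a plain-NumPy sketch (the `recovery_r2` helper is hypothetical) computes an R² per parameter dimension:

```python
import numpy as np

# Recovery in a nutshell: R^2 between posterior means and true parameters,
# computed independently for each parameter dimension.
def recovery_r2(samples, theta_true):
    # samples: (n_test, n_samples, dim); theta_true: (n_test, dim)
    post_mean = samples.mean(axis=1)
    ss_res = ((theta_true - post_mean) ** 2).sum(axis=0)
    ss_tot = ((theta_true - theta_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot
```

An R² near 1 per dimension means the posterior centers tightly on the truth; values near 0 indicate the posterior mean carries little information about the parameter.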
`inference_variables` is required for automated diagnostics. Make sure to pass `inference_variables=["theta"]` (or your parameter names) when constructing the workflow; otherwise the diagnostic methods cannot automatically identify the targets.
7.5.2. Accessing metrics numerically#
metrics_df = basic_workflow.compute_default_diagnostics(test_data, num_samples=100, as_data_frame=True)
7.5.3. Custom diagnostics#
Pass a dictionary of {name: plot_fn} for custom visualizations:
import matplotlib.pyplot as plt
def posterior_predictive_check(samples, test_data, variable_keys=None, variable_names=None, **kwargs):
    """Overlay observed data, posterior predictive density, and true density for a few test cases."""
    n_show = min(4, len(test_data["x"]))
    theta_post = samples["theta"]    # (n_test, n_samples, 2)
    x_obs = test_data["x"]           # (n_test, 20, 1)
    theta_true = test_data["theta"]  # (n_test, 2)
    fig, axes = plt.subplots(1, n_show, figsize=(3.5 * n_show, 3.5))
    axes = [axes] if n_show == 1 else list(axes)
    for i, ax in enumerate(axes):
        obs = x_obs[i].reshape(-1)
        lo, hi = obs.min() - 1.0, obs.max() + 1.0
        x_grid = np.linspace(lo, hi, 300)
        # Observed data
        ax.hist(obs, bins=10, density=True, alpha=0.4, color="steelblue", label="Observed x")
        # Posterior predictive: average normal PDF over 50 posterior draws
        mu_s = theta_post[i, :50, 0]
        sg_s = np.exp(theta_post[i, :50, 1] * 0.5)
        ppd = np.mean(
            [np.exp(-0.5 * ((x_grid - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))
             for m, s in zip(mu_s, sg_s)],
            axis=0,
        )
        ax.plot(x_grid, ppd, color="tomato", lw=2, label="Posterior predictive")
        # True generative density
        m_t = theta_true[i, 0]
        s_t = float(np.exp(theta_true[i, 1] * 0.5))
        true_pdf = np.exp(-0.5 * ((x_grid - m_t) / s_t) ** 2) / (s_t * np.sqrt(2 * np.pi))
        ax.plot(x_grid, true_pdf, "k--", lw=1.5, label="True density")
        ax.set_title(f"Test case {i + 1}", fontsize=10)
        ax.set_xlabel("x")
    axes[0].set_ylabel("Density")
    axes[0].legend(fontsize=8)
    fig.suptitle("Posterior Predictive Check", fontweight="bold")
    fig.tight_layout()
    return fig
figures = basic_workflow.plot_custom_diagnostics(
    test_data,
    plot_fns={"posterior_predictive": posterior_predictive_check},
    num_samples=200,
)