6. Inference Networks#
Inference networks are the learnable component inside an Approximator that maps a simple base distribution (typically a standard normal) to the target posterior.
BayesFlow provides several inference networks, each with different trade-offs between expressivity, inference speed, and density evaluation.
You can find all inference networks in the networks module.
Networks that extend InferenceNetwork support full posterior sampling
and, where possible, exact log-density evaluation.
6.1. Overview#
| Network | Architecture | Inference speed | Exact density | Best for |
|---|---|---|---|---|
| CouplingFlow | Normalizing flow | ⚡⚡⚡ | ✓ | Simple to moderately complex posteriors |
| FlowMatching | Flow matching (OT) | ⚡ | ✓ | Complex, multimodal posteriors |
| DiffusionModel | Score-based SDE/ODE | ⚡ | ✓ | Complex, multimodal; compositional inference |
| ConsistencyModel | Consistency training | ⚡⚡ | — | Fast single-step sampling via consistency training |
| StableConsistencyModel | Stable CT | ⚡⚡ | — | Continuous consistency training without discretization |
| ScoringRuleNetwork | Feed-forward + heads | ⚡⚡⚡⚡ | Parametric | Various Bayes estimators and parametric distributions |
| PointNetwork | Feed-forward + heads | ⚡⚡⚡⚡ | — | Posterior mean and quantile estimation |
All networks are instantiated with sensible defaults and can be passed directly to an approximator:
import bayesflow as bf

approximator = bf.ContinuousApproximator(
    inference_network=bf.networks.FlowMatching(),
    summary_network=bf.networks.SetTransformer(),  # optional summary backbone
)
6.2. CouplingFlow#
CouplingFlow is the classical choice for simulation-based inference.
It chains invertible coupling layers, each of which transforms one half of the latent vector conditioned
on the other half. Because the transformation is analytically invertible, the log-density can be
computed exactly — the training loss (negative log-likelihood) is a direct quality indicator throughout.
When to use: simple to moderate posteriors, whenever exact density is needed (e.g. for
RatioApproximator cross-checks), or when inference speed is critical.
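The invertibility that makes exact densities possible can be sketched in plain NumPy for a single affine coupling layer. The subnet below is a stand-in function, not a BayesFlow internal; the point is that the inverse and the log-determinant are analytic:

```python
import numpy as np

def subnet(x_cond):
    # Stand-in for the learned subnet: returns a log-scale and a shift
    # for the transformed half, conditioned on the untouched half.
    log_s = np.tanh(x_cond)          # bounded log-scale for numerical stability
    t = 0.5 * x_cond
    return log_s, t

def coupling_forward(x):
    x1, x2 = np.split(x, 2)
    log_s, t = subnet(x1)
    z2 = x2 * np.exp(log_s) + t      # affine transform of the second half
    log_det = np.sum(log_s)          # exact log|det J| is just a sum
    return np.concatenate([x1, z2]), log_det

def coupling_inverse(z):
    z1, z2 = np.split(z, 2)
    log_s, t = subnet(z1)            # z1 == x1, so the same subnet output is recovered
    x2 = (z2 - t) * np.exp(-log_s)   # analytic inverse
    return np.concatenate([z1, x2])

x = np.random.default_rng(0).normal(size=4)
z, log_det = coupling_forward(x)
assert np.allclose(coupling_inverse(z), x)  # invertible up to floating-point error
```

Stacking several such layers with permutations in between yields the full flow; the summed log-determinants give the exact log-density used in the negative log-likelihood loss.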
Key parameters
| Parameter | Default | Notes |
|---|---|---|
| subnet | "mlp" | Backbone for each coupling layer |
| depth | — | Number of invertible layers; increase for more complex posteriors |
| transform | — | Coupling transform; "spline" is more expressive than the default |
| use_actnorm | — | ActNorm normalization before each coupling layer |
| base_distribution | "normal" | Pass an alternative distribution for heavy-tailed posteriors |
# Default: fast and robust for most problems
inference_network = bf.networks.CouplingFlow()

# Deeper with spline transforms for more expressive posteriors
inference_network = bf.networks.CouplingFlow(
    depth=10,
    transform="spline",
    subnet_kwargs={"widths": [256, 256, 256], "dropout": 0.05},
)

# Heavy-tailed Student-t base for posteriors with outliers
inference_network = bf.networks.CouplingFlow(
    base_distribution=bf.distributions.DiagonalStudentT(df=5),
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)
6.3. FlowMatching#
FlowMatching implements Optimal Transport Flow Matching (originally introduced as Rectified Flow).
A neural network learns a velocity field that transports samples from the base distribution straight to
the posterior. At inference time, the ODE is integrated numerically — typically requiring 10–50 network
evaluations per sample. This makes it slower than coupling flows but far more expressive, as the learned
trajectories are not restricted to coordinate-wise invertible transforms.
When to use: complex or multimodal posteriors where CouplingFlow under-fits; the additional
inference cost (ODE integration) is usually well worth the expressivity gain.
Note
After an initial drop, the MSE training loss is no longer a reliable performance indicator.
Use diagnostic plots (e.g. recovery()) during and after training.
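To illustrate the mechanics rather than the BayesFlow API, here is a toy Euler integration of a known velocity field. Under linear interpolation between base and target samples, the conditional target velocity is simply x1 − x0; for a target that is the base shifted by mu, that velocity is the constant mu, and integrating the ODE transports the whole base distribution onto the target:

```python
import numpy as np

rng = np.random.default_rng(1)
mu = 3.0                          # target is N(mu, 1); base is N(0, 1)

def velocity(x, t):
    # For the linear path x_t = (1 - t) * x0 + t * x1 with x1 = x0 + mu,
    # the target velocity x1 - x0 equals the constant mu everywhere.
    return np.full_like(x, mu)

x = rng.normal(size=10_000)       # samples from the base distribution
steps = 20                        # network evaluations per sample at inference
dt = 1.0 / steps
for k in range(steps):            # forward Euler through the velocity field
    x = x + dt * velocity(x, k * dt)

print(round(float(x.mean()), 1))  # close to mu
```

In practice the velocity field is a neural network and the trajectories are curved, which is why 10–50 integration steps (or an adaptive solver) are needed.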
Key parameters
| Parameter | Default | Notes |
|---|---|---|
| subnet | "mlp" | Velocity-field backbone; accepts time as an additional input |
| use_optimal_transport | False | Mini-batch OT via Sinkhorn — ~2.5× slower training but often faster convergence |
| integrate_kwargs | — | Pass integrator settings such as {"steps": 20} |
| — | — | Biases time sampling; positive values oversample late times |
| — | — | Classifier-free guidance dropout probability |
# Default: solid choice for most complex posteriors
inference_network = bf.networks.FlowMatching()

# With optimal transport for improved training stability
inference_network = bf.networks.FlowMatching(
    use_optimal_transport=True,
    subnet_kwargs={"widths": [256, 256, 256], "dropout": 0.05},
)

# Fewer ODE steps at inference for faster (slightly less accurate) sampling
inference_network = bf.networks.FlowMatching(
    integrate_kwargs={"steps": 20},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.TimeSeriesTransformer(),
)
6.4. DiffusionModel#
DiffusionModel implements a score-based diffusion model for amortized SBI.
A noise schedule gradually corrupts posterior samples during training; a neural network learns to reverse
this process at inference time by estimating the score function. Inference integrates a reverse-time
SDE or ODE, making it the most computationally expensive sampler — but also the richest and most flexible.
DiffusionModel is also the only inference network that supports
CompositionalWorkflow, which composes multiple independently trained
posteriors at inference time (see the Workflows guide).
See also
The BayesFlow diffusion experiments site provides detailed benchmarks, worked examples, and guidance on choosing schedules, prediction types, and integration settings for diffusion models in SBI.
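The forward corruption process behind training can be sketched independently of the library. The cosine schedule values below are illustrative, not BayesFlow's internals; the sketch only shows how a variance-preserving schedule mixes signal and noise so that a network can be trained to predict the noise:

```python
import numpy as np

rng = np.random.default_rng(2)
theta = rng.normal(size=(64, 3))           # a batch of "posterior draws"

# Variance-preserving corruption at a random time per sample:
# x_t = alpha_t * theta + sigma_t * eps, with alpha_t^2 + sigma_t^2 = 1.
t = rng.uniform(size=(64, 1))
alpha_t = np.cos(0.5 * np.pi * t)          # illustrative cosine schedule
sigma_t = np.sin(0.5 * np.pi * t)
eps = rng.normal(size=theta.shape)
x_t = alpha_t * theta + sigma_t * eps

# A noise-prediction network eps_hat(x_t, t) would be trained with
# mean squared error against eps; its output defines the score used
# by the reverse-time SDE/ODE at inference.
assert x_t.shape == theta.shape
assert np.allclose(alpha_t**2 + sigma_t**2, 1.0)
```

Different prediction types (noise, velocity, denoised sample) are reparameterizations of the same underlying score estimate.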
Key parameters
| Parameter | Default | Notes |
|---|---|---|
| subnet | "mlp" | Score-network backbone |
| noise_schedule | "edm" | Noise schedule; "cosine" is a common alternative |
| integrate_kwargs | — | ODE/SDE integrator settings (steps, method, etc.) |
| — | — | Classifier-free guidance dropout |
# Default: EDM schedule, recommended starting point
inference_network = bf.networks.DiffusionModel()

# Cosine schedule with velocity prediction
inference_network = bf.networks.DiffusionModel(
    noise_schedule="cosine",
    prediction_type="velocity",
    loss_type="velocity",
    subnet_kwargs={"widths": [256, 256, 256]},
)

# Reduce integration steps at inference for faster sampling
inference_network = bf.networks.DiffusionModel(
    integrate_kwargs={"steps": 20},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)
6.5. Consistency Models#
Consistency models distill the multi-step reverse-diffusion process into a mapping that can, in principle, generate a sample in a single network evaluation. BayesFlow provides two variants:
ConsistencyModel: implements Consistency Training (CT) with a progressive discretization schedule, as in Song et al. (2023).
StableConsistencyModel: implements the simple, stable, and scalable variant (sCM; Lu & Song, 2024), which removes the need for a fixed total_steps count and uses a learned weighting network for a more stable loss.
See also
The BayesFlow diffusion experiments site
contains benchmarks and examples for consistency models in SBI, including comparisons with
DiffusionModel and FlowMatching.
6.5.1. ConsistencyModel#
ConsistencyModel requires knowing total_steps = num_epochs × num_batches at construction time
so that the discretization schedule can be pre-computed. Always pass the exact value; a mismatched
total_steps distorts the schedule.
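The progressive schedule pre-computed from total_steps can be sketched following Song et al. (2023): the number of discretization steps grows from s0 at the start of training to roughly s1 at the end. This is an illustration of the formula from the paper, not BayesFlow's internal code, and the s0/s1 defaults here are arbitrary:

```python
import math

def discretization_steps(k, total_steps, s0=10, s1=50):
    # Number of discretization steps at training step k, increasing
    # smoothly from s0 at k = 0 toward s1 at k = total_steps.
    frac = k / total_steps
    return math.ceil(math.sqrt(frac * ((s1 + 1) ** 2 - s0**2) + s0**2) - 1) + 1

total_steps = 100 * 250  # num_epochs * num_batches
schedule = [
    discretization_steps(k, total_steps) for k in range(0, total_steps + 1, 5_000)
]
print(schedule[0], schedule[-1])  # grows from s0 toward s1
```

Because the schedule is indexed by the absolute training step, an inaccurate total_steps makes the discretization finish too early or never reach its final resolution.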
| Parameter | Default | Notes |
|---|---|---|
| total_steps | — | Required. Set to num_epochs × num_batches |
| subnet | "mlp" | Backbone network |
| s0, s1 | — | Initial / final discretization steps |
| — | — | Maximum noise level |
| — | — | Classifier-free guidance dropout |
| subnet_kwargs | — | MLP widths, dropout, etc. |
num_epochs = 100
num_batches = 250  # = dataset_size / batch_size

inference_network = bf.networks.ConsistencyModel(
    total_steps=num_epochs * num_batches,
)

# Larger discretization range for better quality
inference_network = bf.networks.ConsistencyModel(
    total_steps=num_epochs * num_batches,
    s0=2,
    s1=200,
    subnet_kwargs={"widths": [256, 256, 256]},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)
6.5.2. StableConsistencyModel#
StableConsistencyModel needs no total_steps and is generally easier to configure.
An auxiliary weight MLP automatically scales the loss during training.
| Parameter | Default | Notes |
|---|---|---|
| subnet | "mlp" | Backbone network |
| — | — | Noise standard deviation |
| subnet_kwargs | — | MLP widths, dropout, etc. |
| — | — | Kwargs for the auxiliary weighting MLP |
| — | — | Classifier-free guidance dropout |
# Drop-in replacement for ConsistencyModel, no total_steps required
inference_network = bf.networks.StableConsistencyModel()

inference_network = bf.networks.StableConsistencyModel(
    subnet_kwargs={"widths": [256, 256, 256], "dropout": 0.05},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)
6.6. Image Generation and Spatial Outputs#
All diffusion-like networks — DiffusionModel,
FlowMatching, and the consistency model variants — can generate
image-valued posterior samples when paired with an image-capable subnet backbone.
This is useful whenever the inferential target is itself an image or spatial field (e.g. Bayesian
denoising, Gaussian random field estimation, spatial emulation), rather than a low-dimensional
parameter vector.
Note
If the image is only an observed condition and the inferential target is a small parameter vector,
keep the default workflow and use ConvolutionalNetwork as a
summary network instead. Image generation applies only when the target is image-shaped.
6.6.1. Subnet choice#
Choose an image backbone in order of increasing capacity:
| Subnet | When to use |
|---|---|
| UNet | Default starting point; denoising, smaller images, simpler spatial structure |
| — | When … |
| — | Hardest tasks; try when … |
Escalate capacity only after training has converged and held-out diagnostics still show weak recovery.
6.6.2. Shape requirements#
Image-generation subnets concatenate the condition channel-wise with the target image.
If the target is (B, H, W, C), the condition must already have compatible spatial shape (B, H, W, D).
A global condition (B, D) must be tiled or broadcast to spatial shape — either inside the simulator
or in the adapter — before it reaches the inference network.
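The tiling step is a one-line NumPy broadcast; the variable names below are purely illustrative, and the same logic can live in the simulator or in an adapter transform:

```python
import numpy as np

B, H, W, D = 4, 8, 8, 3
global_condition = np.random.normal(size=(B, D)).astype("float32")

# Insert singleton spatial axes, then broadcast to the target resolution.
tiled = np.broadcast_to(
    global_condition[:, None, None, :], (B, H, W, D)
).copy()  # copy() materializes the broadcast into a contiguous array

assert tiled.shape == (B, H, W, D)
# Every spatial location now carries the same D-dimensional condition vector:
assert np.allclose(tiled[:, 0, 0, :], global_condition)
```

After tiling, the condition can be concatenated channel-wise with the target image as the subnet expects.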
6.6.3. Minimal example#
# A very simple Gaussian denoising model
import numpy as np

H, W = 8, 8

def prior():
    return dict(target_image=np.random.normal(size=(H, W, 1)).astype("float32"))

def likelihood(target_image):
    noise = np.random.normal(0, 0.3, size=(H, W, 1)).astype("float32")
    return dict(condition_map=(target_image + noise).astype("float32"))

simulator = bf.make_simulator([prior, likelihood])

# Image target: (B, H, W, C); condition map is a noisy version with the same spatial shape
adapter = (
    bf.Adapter()
    .rename("target_image", "inference_variables")
    .rename("condition_map", "inference_conditions")
)

# DiffusionModel with UNet backbone (recommended default for image generation)
inference_network = bf.networks.DiffusionModel(
    subnet=bf.networks.UNet(),
    prediction_type="velocity",
    noise_schedule="cosine",
)

workflow = bf.BasicWorkflow(
    simulator=simulator,
    inference_network=inference_network,
    adapter=adapter,
    initial_learning_rate=1e-4,
)

workflow.fit_online(epochs=5, batch_size=8, num_batches_per_epoch=10)
The same pattern works with FlowMatching:
inference_network = bf.networks.FlowMatching(subnet=bf.networks.UViT())
6.6.4. Checking quality#
The standard plot_default_diagnostics function is designed for low-dimensional parameter targets.
For image-valued outputs, inspect samples visually:
held_out = workflow.simulate(5)

samples = workflow.sample(
    conditions={"condition_map": held_out["condition_map"]},
    num_samples=4,
)

generated = samples["target_image"]  # shape (5, 4, H, W, C)
# Plot a small grid across held-out conditions for quick inspection
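One way to build such a grid is to rearrange the sample tensor into a single (conditions × samples) mosaic that a single imshow-style call can display. The generated array below is a random stand-in for the samples produced above, and the shapes are illustrative:

```python
import numpy as np

# Stand-in for samples["target_image"]: 5 conditions, 4 posterior draws each
n_cond, n_draw, H, W, C = 5, 4, 8, 8, 1
generated = np.random.normal(size=(n_cond, n_draw, H, W, C))

# Rearrange (n_cond, n_draw, H, W) into one (n_cond*H, n_draw*W) image grid:
# rows index held-out conditions, columns index posterior draws.
grid = (
    generated[..., 0]        # drop the trailing channel axis
    .transpose(0, 2, 1, 3)   # (n_cond, H, n_draw, W)
    .reshape(n_cond * H, n_draw * W)
)
assert grid.shape == (n_cond * H, n_draw * W)
# Top-left tile of the mosaic is the first draw for the first condition
assert np.allclose(grid[:H, :W], generated[0, 0, :, :, 0])
```

Visible variation across a row indicates posterior uncertainty for that condition; a row of near-identical tiles suggests an overconfident posterior.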
See also
A worked example with spatial data is available in the Spatial Data and Parameters example notebook.
6.7. Point Estimation Networks#
The networks below are not generative models — they do not produce posterior samples.
Instead, they minimize a Bayes risk (a scoring rule) to recover specific posterior summaries such
as the posterior mean, quantiles, or the parameters of a parametric distribution. They are used with
ScoringRuleApproximator and can be a fast and frugal alternative to fully Bayesian estimation.
6.7.1. ScoringRuleNetwork#
ScoringRuleNetwork gives full control over the scoring rules applied.
Pass any combination of ScoringRule instances — including
parametric distribution fits such as MvNormalScore.
from bayesflow.scoring_rules import MeanScore, QuantileScore, MvNormalScore

# Fit a multivariate normal approximation to the posterior
inference_network = bf.networks.ScoringRuleNetwork(
    scoring_rules={"mvn": MvNormalScore()}
)

# Mix different scoring rules
inference_network = bf.networks.ScoringRuleNetwork(
    scoring_rules={
        "mean": MeanScore(),
        "quantiles": QuantileScore(q=[0.1, 0.5, 0.9]),
    }
)

approximator = bf.ScoringRuleApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)
The output of .estimate() is a nested dict keyed by scoring rule name, then by the statistic
(e.g. "mean", "quantiles", or distribution parameters such as "loc" and "scale").
6.7.2. PointNetwork#
PointNetwork is a thin wrapper around ScoringRuleNetwork with a
simplified interface for the two most common point estimates: the posterior mean and posterior quantiles.
# Estimate the posterior mean only
inference_network = bf.networks.PointNetwork("mean")

# Estimate the posterior mean together with several quantiles
inference_network = bf.networks.PointNetwork(
    ["mean", "quantiles"],
    q=[0.05, 0.25, 0.5, 0.75, 0.95],
)

approximator = bf.ScoringRuleApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)