6. Inference Networks#

Inference networks are the learnable component inside an Approximator that maps a simple base distribution (typically a standard normal) to the target posterior.

BayesFlow provides several inference networks, each with different trade-offs between expressivity, inference speed, and density evaluation.

You can find all inference networks in the networks module. Networks that extend InferenceNetwork support full posterior sampling and, where possible, exact log-density evaluation.

6.1. Overview#

| Network | Architecture | Inference speed | Exact density | Best for |
|---|---|---|---|---|
| CouplingFlow | Normalizing flow | ⚡⚡⚡ | Yes | Simple to moderately complex posteriors |
| FlowMatching | Flow matching (OT) | ⚡ | | Complex, multimodal posteriors |
| DiffusionModel | Score-based SDE/ODE | ⚡ | | Complex, multimodal; compositional inference |
| ConsistencyModel | Consistency training | ⚡⚡ | No | Fast single-step distillation training |
| StableConsistencyModel | Stable CT | ⚡⚡ | No | Continuous consistency training without discretization |
| ScoringRuleNetwork | Feed-forward + heads | ⚡⚡⚡⚡ | Parametric | Various Bayes estimators and parametric distributions |
| PointNetwork | Feed-forward + heads | ⚡⚡⚡⚡ | No | Posterior mean and quantile estimation |

All networks are instantiated with sensible defaults and can be passed directly to an approximator:

import bayesflow as bf
approximator = bf.ContinuousApproximator(
    inference_network=bf.networks.FlowMatching(),
    summary_network=bf.networks.SetTransformer() # optional summary backbone
)

6.2. CouplingFlow#

CouplingFlow is the classical choice for simulation-based inference. It chains invertible coupling layers, each of which transforms one half of the latent vector conditioned on the other half. Because the transformation is analytically invertible, the log-density can be computed exactly — the training loss (negative log-likelihood) is a direct quality indicator throughout training.
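The mechanics can be illustrated with a toy affine coupling layer in NumPy (an illustration of the principle only, not BayesFlow's implementation; the scale and shift functions below stand in for the learned subnets):

```python
import numpy as np

# Toy affine coupling layer in 2D: x2 is transformed conditioned on x1.
# The inverse and the log-determinant are available in closed form, which
# is what enables exact density evaluation via the change of variables.
def forward(x1, x2, scale_fn, shift_fn):
    log_s = scale_fn(x1)
    z2 = x2 * np.exp(log_s) + shift_fn(x1)
    return x1, z2, log_s  # log|det J| = sum of log_s entries

def inverse(z1, z2, scale_fn, shift_fn):
    log_s = scale_fn(z1)  # z1 == x1, so the same scale/shift are recovered
    x2 = (z2 - shift_fn(z1)) * np.exp(-log_s)
    return z1, x2

scale_fn = lambda x: 0.5 * np.tanh(x)  # stand-ins for learned subnets
shift_fn = lambda x: 2.0 * x

x1 = np.array([0.3])
x2 = np.array([-1.2])
z1, z2, log_det = forward(x1, x2, scale_fn, shift_fn)
x1_rec, x2_rec = inverse(z1, z2, scale_fn, shift_fn)  # recovers x1, x2 exactly
```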

When to use: simple to moderate posteriors, whenever exact density is needed (e.g. for RatioApproximator cross-checks), or when inference speed is critical.

Key parameters

| Parameter | Default | Notes |
|---|---|---|
| subnet | "mlp" | Backbone for each coupling layer |
| depth | 6 | Number of invertible layers; increase for more complex posteriors |
| transform | "affine" | "affine" (fast) or "spline" (more expressive, heavier) |
| permutation | "random" | "random", "orthogonal", "swap", or None |
| use_actnorm | True | ActNorm normalization before each coupling layer |
| base_distribution | "normal" | "normal", "student", "mixture", or a custom Distribution |
| subnet_kwargs | None | Pass widths, dropout, activation, etc. to the MLP |

# Default: fast and robust for most problems
inference_network = bf.networks.CouplingFlow()

# Deeper with spline transforms for more expressive posteriors
inference_network = bf.networks.CouplingFlow(
    depth=10,
    transform="spline",
    subnet_kwargs={"widths": [256, 256, 256], "dropout": 0.05},
)

# Heavy-tailed base for posteriors with outliers
inference_network = bf.networks.CouplingFlow(
    base_distribution="student",
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)

6.3. FlowMatching#

FlowMatching implements Optimal Transport Flow Matching (originally Rectified Flow). A neural network learns a velocity field that transports samples from the base distribution straight to the posterior. At inference time, the ODE is integrated numerically — typically requiring 10–50 network evaluations per sample. This makes it slower than coupling flows but far more expressive, as the learned trajectories are not restricted to coordinate-wise invertible transforms.

When to use: complex or multimodal posteriors where CouplingFlow under-fits; the additional inference cost (ODE integration) is usually well worth the expressivity gain.
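The transport idea can be sketched in a few lines of NumPy with an oracle velocity field (a toy illustration of straight-path transport, not the BayesFlow implementation):

```python
import numpy as np

# Toy rectified-flow transport in 1D: base samples x0 are moved to target
# samples x1 by following the straight-line velocity v = x1 - x0. In flow
# matching a network learns this velocity; here we use the oracle directly.
rng = np.random.default_rng(0)
x0 = rng.normal(size=1000)            # base distribution
x1 = rng.normal(loc=5.0, size=1000)   # stand-in "posterior" target

def integrate(x, velocity, steps):
    # Simple Euler integration of dx/dt = velocity from t=0 to t=1.
    dt = 1.0 / steps
    for _ in range(steps):
        x = x + dt * velocity
    return x

# The oracle velocity is constant in time, so Euler recovers x1 exactly.
samples = integrate(x0, x1 - x0, steps=20)
```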

Note

After an initial drop, the MSE training loss is no longer a reliable performance indicator. Use diagnostic plots (e.g. recovery()) during and after training.

Key parameters

| Parameter | Default | Notes |
|---|---|---|
| subnet | "time_mlp" | Velocity-field backbone; accepts time as an additional input |
| use_optimal_transport | False | Mini-batch OT via Sinkhorn; ~2.5× slower training but often faster convergence |
| integrate_kwargs | {} | Pass steps, method, etc. to the ODE solver; can be passed to .sample() too |
| time_power_law_alpha | 0.0 | Biases time sampling; positive values oversample late times |
| subnet_kwargs | None | widths, dropout, etc. for the MLP |
| drop_cond_prob | 0.0 | Classifier-free guidance dropout probability |
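To see what biased time sampling means conceptually, here is one standard construction via inverse-CDF sampling from a power-law density p(t) ∝ t^α on [0, 1] (illustrative only; consult the BayesFlow source for the exact parameterization it uses):

```python
import numpy as np

# Inverse-CDF sampling from p(t) ∝ t^alpha on [0, 1]: t = u^(1 / (alpha + 1)).
# alpha = 0 recovers uniform time sampling; positive alpha oversamples
# late times (t near 1), where E[t] = (alpha + 1) / (alpha + 2).
rng = np.random.default_rng(1)
u = rng.uniform(size=100_000)

t_uniform = u ** (1.0 / (0.0 + 1.0))  # alpha = 0: mean ~= 0.5
t_late = u ** (1.0 / (2.0 + 1.0))     # alpha = 2: mean ~= 0.75
```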

# Default: solid choice for most complex posteriors
inference_network = bf.networks.FlowMatching()

# With optimal transport for improved training stability
inference_network = bf.networks.FlowMatching(
    use_optimal_transport=True,
    subnet_kwargs={"widths": [256, 256, 256], "dropout": 0.05},
)

# Fewer ODE steps at inference for faster (slightly less accurate) sampling
inference_network = bf.networks.FlowMatching(
    integrate_kwargs={"steps": 20},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.TimeSeriesTransformer()
)

6.4. DiffusionModel#

DiffusionModel implements a score-based diffusion model for amortized SBI. A noise schedule gradually corrupts posterior samples during training; a neural network learns to reverse this process at inference time by estimating the score function. Inference integrates a reverse-time SDE or ODE, making it the most computationally expensive sampler — but also the richest and most flexible.
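The forward corruption can be sketched with the well-known cosine schedule (a NumPy illustration of the general principle, not BayesFlow's internal code):

```python
import numpy as np

# Variance-preserving forward corruption with a cosine schedule:
# x_t = sqrt(alpha_bar(t)) * x0 + sqrt(1 - alpha_bar(t)) * eps.
# At t = 0 samples are untouched; at t = 1 they are pure noise.
# BayesFlow's schedules are selected via the noise_schedule argument.
def alpha_bar(t):
    return np.cos(0.5 * np.pi * t) ** 2

rng = np.random.default_rng(2)
x0 = rng.normal(size=50_000)   # clean "posterior draws"
eps = rng.normal(size=50_000)  # corruption noise

t = 0.5
x_t = np.sqrt(alpha_bar(t)) * x0 + np.sqrt(1.0 - alpha_bar(t)) * eps
```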

DiffusionModel is also the only inference network that supports CompositionalWorkflow, which composes multiple independently trained posteriors at inference time (see the Workflows guide).

See also

The BayesFlow diffusion experiments site provides detailed benchmarks, worked examples, and guidance on choosing schedules, prediction types, and integration settings for diffusion models in SBI.

Key parameters

| Parameter | Default | Notes |
|---|---|---|
| subnet | "time_mlp" | Score-network backbone |
| noise_schedule | "edm" | "edm" (recommended) or "cosine" |
| prediction_type | "F" | "F", "noise", "velocity", "x", "score", "potential" |
| loss_type | "noise" | "noise", "velocity", or "F" |
| integrate_kwargs | {} | ODE/SDE integrator settings (steps, method, etc.); can be passed to sample() too |
| drop_cond_prob | 0.0 | Classifier-free guidance dropout |
| subnet_kwargs | None | widths, dropout, etc. for the MLP |

# Default: EDM schedule, recommended starting point
inference_network = bf.networks.DiffusionModel()

# Cosine schedule with velocity prediction
inference_network = bf.networks.DiffusionModel(
    noise_schedule="cosine",
    prediction_type="velocity",
    loss_type="velocity",
    subnet_kwargs={"widths": [256, 256, 256]},
)

# Reduce integration steps at inference for faster sampling
inference_network = bf.networks.DiffusionModel(
    integrate_kwargs={"steps": 20},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)

6.5. Consistency Models#

Consistency models distill the multi-step reverse-diffusion process into a mapping that can, in principle, generate a sample in a single network evaluation. BayesFlow provides two variants:

  • ConsistencyModel — implements Consistency Training (CT) with a progressive discretization schedule, as in Song et al. (2023).

  • StableConsistencyModel — implements the simple, stable, and scalable variant (sCM; Lu & Song 2024) that removes the need for a fixed total_steps count and uses a learned weighting network for a more stable loss.

See also

The BayesFlow diffusion experiments site contains benchmarks and examples for consistency models in SBI, including comparisons with DiffusionModel and FlowMatching.

6.5.1. ConsistencyModel#

ConsistencyModel needs to know total_steps = num_epochs × num_batches at construction time so that the discretization schedule can be pre-computed. Pass the exact value for your training run.

| Parameter | Default | Notes |
|---|---|---|
| total_steps | Required | num_epochs * num_batches |
| subnet | "time_mlp" | Backbone network |
| s0 / s1 | 10 / 150 | Initial / final discretization steps |
| max_time | 80 | Maximum noise level |
| drop_cond_prob | 0.0 | Classifier-free guidance dropout |
| subnet_kwargs | None | MLP widths, dropout, etc. |

num_epochs = 100
num_batches = 250  # = dataset_size / batch_size

inference_network = bf.networks.ConsistencyModel(
    total_steps=num_epochs * num_batches,
)

# Larger discretization range for better quality
inference_network = bf.networks.ConsistencyModel(
    total_steps=num_epochs * num_batches,
    s0=2,
    s1=200,
    subnet_kwargs={"widths": [256, 256, 256]},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)

6.5.2. StableConsistencyModel#

StableConsistencyModel needs no total_steps and is generally easier to configure. An auxiliary weight MLP automatically scales the loss during training.

| Parameter | Default | Notes |
|---|---|---|
| subnet | "time_mlp" | Backbone network |
| sigma | 1.0 | Noise standard deviation |
| subnet_kwargs | None | MLP widths, dropout, etc. |
| weight_mlp_kwargs | None | Kwargs for the auxiliary weighting MLP |
| drop_cond_prob | 0.0 | Classifier-free guidance dropout |

# Drop-in replacement for ConsistencyModel — no total_steps required
inference_network = bf.networks.StableConsistencyModel()

inference_network = bf.networks.StableConsistencyModel(
    subnet_kwargs={"widths": [256, 256, 256], "dropout": 0.05},
)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer(),
)

6.6. Image Generation and Spatial Outputs#

All diffusion-like networks — DiffusionModel, FlowMatching, and the consistency model variants — can generate image-valued posterior samples when paired with an image-capable subnet backbone. This is useful whenever the inferential target is itself an image or spatial field (e.g. Bayesian denoising, Gaussian random field estimation, spatial emulation), rather than a low-dimensional parameter vector.

Note

If the image is only an observed condition and the inferential target is a small parameter vector, keep the default workflow and use ConvolutionalNetwork as a summary network instead. Image generation applies only when the target is image-shaped.

6.6.1. Subnet choice#

Choose an image backbone in order of increasing capacity:

| Subnet | When to use |
|---|---|
| UNet | Default starting point; denoising, smaller images, simpler spatial structure |
| UViT | When UNet underfits or long-range spatial interactions matter |
| ResidualUViT | Hardest tasks; try when UViT diagnostics still fail |

Escalate capacity only after training has converged and held-out diagnostics still show weak recovery.

6.6.2. Shape requirements#

Image-generation subnets concatenate the condition channel-wise with the target image. If the target is (B, H, W, C), the condition must already have compatible spatial shape (B, H, W, D). A global condition (B, D) must be tiled or broadcast to spatial shape — either inside the simulator or in the adapter — before it reaches the inference network.
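For example, a global condition of shape (B, D) can be tiled to spatial shape with plain NumPy before it reaches the adapter (variable names here are illustrative):

```python
import numpy as np

# Tile a global condition (B, D) to spatial shape (B, H, W, D) so it can be
# concatenated channel-wise with an image target of shape (B, H, W, C).
B, H, W, D = 4, 8, 8, 3
global_cond = np.random.normal(size=(B, D)).astype("float32")

spatial_cond = np.broadcast_to(
    global_cond[:, None, None, :], (B, H, W, D)
).copy()  # .copy() materializes the read-only broadcast view
```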

6.6.3. Minimal example#

# A very simple Gaussian denoising model: the target is a clean image,
# the condition is a noisy version of it
import numpy as np
import bayesflow as bf

H, W = 8, 8

def prior():
    return dict(target_image=np.random.normal(size=(H, W, 1)).astype("float32"))

def likelihood(target_image):
    noise = np.random.normal(0, 0.3, size=(H, W, 1)).astype("float32")
    return dict(condition_map=(target_image + noise).astype("float32"))

simulator = bf.make_simulator([prior, likelihood])

# Image target: (B, H, W, C); condition map is a noisy version of the same spatial shape
adapter = (
    bf.Adapter()
    .rename("target_image", "inference_variables")
    .rename("condition_map", "inference_conditions")
)

# DiffusionModel with UNet backbone — recommended default for image generation
inference_network = bf.networks.DiffusionModel(
    subnet=bf.networks.UNet(),
    prediction_type="velocity",
    noise_schedule="cosine"
)
workflow = bf.BasicWorkflow(
    simulator=simulator,
    inference_network=inference_network,
    adapter=adapter,
    initial_learning_rate=1e-4
)

workflow.fit_online(epochs=5, batch_size=8, num_batches_per_epoch=10)

The same pattern works with FlowMatching:

inference_network = bf.networks.FlowMatching(subnet=bf.networks.UViT())

6.6.4. Checking quality#

Standard plot_default_diagnostics is designed for low-dimensional parameter targets. For image-valued outputs, inspect samples visually:

held_out = workflow.simulate(5)

samples = workflow.sample(
    conditions={"condition_map": held_out["condition_map"]},
    num_samples=4
)

generated = samples["target_image"]  # shape (5, 4, H, W, C)
# Plot a small grid across held-out conditions for quick inspection
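One convenient way to build such a grid is to tile the samples into a single mosaic array that any image plotting function can display (a sketch; the helper below is not part of BayesFlow):

```python
import numpy as np

# Tile posterior samples of shape (num_conditions, num_samples, H, W, C)
# into one (num_conditions * H, num_samples * W, C) mosaic: rows are
# held-out conditions, columns are posterior samples.
def make_mosaic(images):
    n_cond, n_samp, H, W, C = images.shape
    return images.transpose(0, 2, 1, 3, 4).reshape(n_cond * H, n_samp * W, C)

# Stand-in array with the shapes from the example above
demo = np.random.normal(size=(5, 4, 8, 8, 1)).astype("float32")
mosaic = make_mosaic(demo)  # shape (40, 32, 1), ready for imshow-style plotting
```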

See also

A worked example with spatial data is available in the Spatial Data and Parameters example notebook.

6.7. Point Estimation Networks#

The networks below are not generative models — they do not produce posterior samples. Instead, they minimize a Bayes risk (a scoring rule) to recover specific posterior summaries such as the posterior mean, quantiles, or the parameters of a parametric distribution. They are used with ScoringRuleApproximator and can be a fast and frugal alternative to fully Bayesian estimation.

6.7.1. ScoringRuleNetwork#

ScoringRuleNetwork gives full control over the scoring rules applied. Pass any combination of ScoringRule instances — including parametric distribution fits such as MvNormalScore.

from bayesflow.scoring_rules import MeanScore, QuantileScore, MvNormalScore

# Fit a multivariate normal approximation to the posterior
inference_network = bf.networks.ScoringRuleNetwork(
    scoring_rules={"mvn": MvNormalScore()}
)

# Mix different scoring rules
inference_network = bf.networks.ScoringRuleNetwork(
    scoring_rules={
        "mean": MeanScore(),
        "quantiles": QuantileScore(q=[0.1, 0.5, 0.9])
    }
)

approximator = bf.ScoringRuleApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer()
)

The output of .estimate() is a nested dict keyed by scoring rule name, then by the statistic (e.g. "mean", "quantiles", or distribution parameters such as "loc" and "scale").
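Illustratively, the returned structure for the mixed configuration above might look like the following mock (key names and shapes here are hypothetical, not taken from a real BayesFlow run):

```python
import numpy as np

# Mock of the nested dict described above: first-level keys are the
# scoring-rule names, second-level keys are the statistics.
estimates = {
    "mean": {"mean": np.zeros((16, 3))},               # (batch, num_params)
    "quantiles": {"quantiles": np.zeros((16, 3, 3))},  # (batch, num_q, num_params)
}

posterior_means = estimates["mean"]["mean"]             # point estimates
medians = estimates["quantiles"]["quantiles"][:, 1, :]  # middle quantile slice
```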

6.7.2. PointNetwork#

PointNetwork is a thin wrapper around ScoringRuleNetwork with a simplified interface for the two most common point estimates: the posterior mean and posterior quantiles.

# Estimate the posterior mean only
inference_network = bf.networks.PointNetwork("mean")

# Estimate several quantiles
inference_network = bf.networks.PointNetwork(
    ["mean", "quantiles"],
    q=[0.05, 0.25, 0.5, 0.75, 0.95]
)

approximator = bf.ScoringRuleApproximator(
    inference_network=inference_network,
    summary_network=bf.networks.SetTransformer()
)