5. Summary Networks
Many scientific simulators produce observations that cannot be cleanly flattened into a fixed-length vector:
A clinical trial might enroll a varying number of patients per study (exchangeable observations).
An electrophysiology recording might span different durations across experiments (time series of varying length).
A spatial simulator might output 2D fields of different resolutions (image data).
Passing such data directly to an inference network (e.g., FlowMatching) requires padding or truncation, both of which throw away structure the network could otherwise exploit. A summary network avoids this by learning a fixed-size representation \(h(x) \in \mathbb{R}^d\) that preserves the structure of the data — permutation invariance for sets, causal ordering for time series, spatial locality for images — while compressing it into a vector the inference network can consume.
The resulting summary dimension \(d\) (the summary_dim argument) also serves as a diagnostic handle: you can inspect the learned summary space to check for posterior collapse or out-of-distribution behavior.
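To make the idea of a structure-preserving, fixed-size summary concrete, here is a minimal NumPy sketch of a hand-written summary for exchangeable data. It is not BayesFlow code: the function name `mean_pool_summary` and the choice of pooled statistics are illustrative assumptions. A learned summary network replaces these hand-picked statistics with trainable layers.

```python
import numpy as np

def mean_pool_summary(x):
    """Map a set of shape (n_obs, n_features) to a fixed-size vector.

    Pooling with mean and standard deviation over the observation axis makes
    the result independent of both the set size and the ordering of rows.
    """
    x = np.asarray(x, dtype=np.float64)
    return np.concatenate([x.mean(axis=0), x.std(axis=0)])

rng = np.random.default_rng(0)
small_study = rng.normal(size=(12, 2))   # 12 patients, 2 measurements each
large_study = rng.normal(size=(250, 2))  # 250 patients, same measurements

# Both studies are compressed to the same dimensionality (here d = 4).
h_small = mean_pool_summary(small_study)
h_large = mean_pool_summary(large_study)

# Shuffling the rows leaves the summary unchanged up to floating-point error.
shuffled = rng.permutation(small_study, axis=0)
h_shuffled = mean_pool_summary(shuffled)
```

No padding or truncation is needed here because the pooling step absorbs the varying number of observations.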
This page covers:
Available networks: which architecture fits your data
Sizing guide: how to set summary_dim and network capacity
Hyperparameters in depth: key parameters and the time_axis argument
Usage examples: inside an approximator and inside a workflow
Custom summary networks: neural data example and a pretrained CNN backbone from Keras Hub
5.1. Available Networks
| Network | Data type | Expected input shape | Notes |
|---|---|---|---|
| SetTransformer | Exchangeable / i.i.d. observations | | Permutation-invariant by design |
| TimeSeriesTransformer | Ordered sequences (time series) | | Recommended default for time series |
| FusionTransformer | Ordered sequences with complex temporal structure | | Combines self-attention + recurrent template; slower than TimeSeriesTransformer |
| TimeSeriesNetwork | Ordered sequences (lightweight alternative) | | Conv1D-based; faster but less expressive |
| ConvolutionalNetwork | 2D images / spatial fields | | ResNet-style with residual blocks |
| DeepSet | Exchangeable data (simpler baseline) | | Faster than SetTransformer |
SetTransformer is for sets, not sequences. If your data has a meaningful ordering (e.g., time, spatial position), use TimeSeriesTransformer or FusionTransformer. Using SetTransformer on ordered data discards temporal structure.
All networks are available under bayesflow.networks and accept a summary_dim argument that controls the output dimensionality.
5.2. Sizing Guide
5.2.1. summary_dim: how many summary statistics to learn
As a starting heuristic, set summary_dim to 3× the number of parameters you are inferring. For example, if you estimate 5 parameters, start with summary_dim=15. Scale up if in-silico diagnostics (simulation-based calibration, posterior z-scores) show poor recovery.
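The heuristic above is simple enough to write down as a helper. The function `suggest_summary_dim` below is a hypothetical convenience, not part of BayesFlow; it only encodes the 3× rule of thumb as a starting value.

```python
def suggest_summary_dim(num_params, factor=3):
    """Starting value for summary_dim: `factor` times the parameter count."""
    if num_params < 1:
        raise ValueError("num_params must be a positive integer")
    return factor * num_params

# Five inferred parameters give a starting summary_dim of 15, as in the text.
start_dim = suggest_summary_dim(5)
```

Treat the result as a first guess only; the diagnostics decide whether it needs to grow.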
5.2.2. Network capacity: Small → Base → Large → XL
All transformer-based summary networks share the same family of hyperparameters: embed_dims, num_heads, mlp_depths, mlp_widths. The length of these tuples sets the number of attention blocks. Here are the recommended configurations (always start with Base):
SetTransformer / TimeSeriesTransformer / FusionTransformer
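| Size | embed_dims | num_heads | mlp_depths | mlp_widths |
|---|---|---|---|---|
| Small | | | | |
| Base | (64, 64, 64) | (4, 4, 4) | (2, 2, 2) | (128, 128, 128) |
| Large | (128, 128, 128, 128) | (8, 8, 8, 8) | (2, 2, 2, 2) | (256, 256, 256, 256) |
| XL | | | | |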
ConvolutionalNetwork
| Size | | stages |
|---|---|---|
| Small | | 2 |
| Base | | 3 |
| Large | | 4 |
| XL | | 5 |
Scale up to Large or XL only if diagnostics show poor recovery after sufficient training. Oversized networks train slower and can hurt calibration on simple problems.
5.3. TimeSeriesTransformer In Depth
The TimeSeriesTransformer is the recommended default for any ordered sequential data. It pairs self-attention blocks with a Time2Vec embedding that lets the network learn both periodic and aperiodic time patterns.
5.3.1. Key parameters
import bayesflow as bf

# Base-size configuration
summary_net = bf.networks.TimeSeriesTransformer(
    summary_dim=16,              # output dimensionality; start at 3× num_params
    embed_dims=(64, 64, 64),     # attention key/value/query dim per block
    num_heads=(4, 4, 4),         # attention heads per block
    mlp_depths=(2, 2, 2),        # MLP layers per block
    mlp_widths=(128, 128, 128),  # MLP width per block
    time_embed_dim=8,            # dimensionality of the Time2Vec embedding
    time_axis=None,              # see below
    dropout=0.05,
)
The tuples (embed_dims, num_heads, mlp_depths, mlp_widths) must all be the same length — this length controls the number of transformer blocks.
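A quick way to catch a mismatched configuration before building the network is to compare the tuple lengths directly. The helper `count_blocks` below is a hypothetical pre-flight check written for this page, not a BayesFlow function.

```python
def count_blocks(embed_dims, num_heads, mlp_depths, mlp_widths):
    """Return the number of transformer blocks implied by the four tuples."""
    lengths = {
        "embed_dims": len(embed_dims),
        "num_heads": len(num_heads),
        "mlp_depths": len(mlp_depths),
        "mlp_widths": len(mlp_widths),
    }
    # All four tuples must agree, since each position describes one block.
    if len(set(lengths.values())) != 1:
        raise ValueError(f"Tuple lengths differ: {lengths}")
    return lengths["embed_dims"]

base = dict(
    embed_dims=(64, 64, 64),
    num_heads=(4, 4, 4),
    mlp_depths=(2, 2, 2),
    mlp_widths=(128, 128, 128),
)
n_blocks = count_blocks(**base)  # three blocks for the Base configuration
```

The same dictionary can then be unpacked into the network constructor alongside summary_dim.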
5.3.2. Scaling up: Large configuration
# Large configuration: use only if Base shows poor recovery in diagnostics
# and does not show signs of overfitting.
summary_net = bf.networks.TimeSeriesTransformer(
    summary_dim=32,
    embed_dims=(128, 128, 128, 128),
    num_heads=(8, 8, 8, 8),
    mlp_depths=(2, 2, 2, 2),
    mlp_widths=(256, 256, 256, 256),
    time_embed_dim=16,
)
5.3.3. time_axis: when to set it
Set time_axis only when your simulator explicitly appends a time-index column to the data. For example, if your simulator outputs irregularly spaced observations and includes the timestamp as one of the features:
# Simulator output: (batch, T, 3) where the last column is the observation time
# The time column sits at index 2 (equivalently -1) of the last axis
summary_net = bf.networks.TimeSeriesTransformer(
    summary_dim=16,
    time_axis=-1,  # tells the network which column contains the time index
)
If your data is a plain value sequence — e.g., shape (batch, T, d) where every column is a measurement — leave time_axis=None. The network will assume uniform time spacing \([0, T]\) automatically. Passing a wrong time_axis that points to a data channel rather than actual timestamps will degrade performance.
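The NumPy sketch below illustrates the two cases described above. It mimics, rather than reproduces, what the network does with time_axis: the function `split_time_column` is a hypothetical stand-in, and the exact uniform grid used internally is an assumption.

```python
import numpy as np

def split_time_column(x, time_axis=None):
    """Separate timestamps from measurements for x of shape (batch, T, d).

    time_axis indexes a column along the last dimension. With None, a
    uniform grid on [0, T] is generated instead and x is returned unchanged.
    """
    x = np.asarray(x)
    n_steps, n_columns = x.shape[1], x.shape[2]
    if time_axis is None:
        t = np.broadcast_to(np.linspace(0.0, n_steps, n_steps), x.shape[:2])
        return t, x
    t = x[..., time_axis]
    keep = [i for i in range(n_columns) if i != time_axis % n_columns]
    return t, x[..., keep]

rng = np.random.default_rng(1)
values = rng.normal(size=(4, 50, 2))                            # two measurement columns
stamps = np.sort(rng.uniform(0, 10, size=(4, 50, 1)), axis=1)   # irregular, increasing times
with_time = np.concatenate([values, stamps], axis=-1)           # shape (4, 50, 3)

t_irregular, data = split_time_column(with_time, time_axis=-1)  # real timestamps
t_uniform, plain = split_time_column(values)                    # assumed uniform grid
```

If time_axis pointed at one of the measurement columns instead, that measurement would be consumed as a timestamp and removed from the features, which is the failure mode the text warns about.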
5.4. Usage Examples
# Here is a very simple simulator that generates spatial data
import numpy as np

def prior():
    return dict(theta=np.random.normal(size=2))  # [mu, log_sigma]

def likelihood(theta):
    mu, log_sigma = theta
    sigma = np.exp(log_sigma * 0.5)
    # 30 points per dataset, one column per coordinate
    return dict(
        x=np.random.normal(mu, sigma, size=(30, 1)),
        y=np.random.normal(mu, sigma, size=(30, 1)),
        z=np.random.normal(mu, sigma, size=(30, 1)),
    )

simulator = bf.make_simulator([prior, likelihood])
5.4.1. Inside a standalone approximator
The summary network is passed directly to the approximator. The adapter routes data to summary_variables, which the approximator feeds through the summary network before the inference network.
adapter = (
    bf.Adapter()
    .convert_dtype("float64", "float32")
    .concatenate(["x", "y", "z"], into="summary_variables")
    .rename("theta", "inference_variables")
)

approximator = bf.ContinuousApproximator(
    inference_network=bf.networks.FlowMatching(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
    adapter=adapter,
)
approximator.compile(optimizer="adam")
history = approximator.fit(simulator=simulator, epochs=5, batch_size=64, num_batches=100)
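To see what the summary network receives after the adapter has run, the dtype conversion and concatenation steps can be imitated in plain NumPy. This is a sketch under the assumption that the adapter joins the three variables along the last axis; the random arrays stand in for a batch of simulator outputs.

```python
import numpy as np

rng = np.random.default_rng(2)
batch_size, n_points = 64, 30

# Stand-ins for a batch of simulator outputs, each of shape (batch, 30, 1).
batch = {name: rng.normal(size=(batch_size, n_points, 1)) for name in ("x", "y", "z")}

# Mirror of the two adapter steps: cast to float32, then join along the last axis.
summary_variables = np.concatenate(
    [batch[name].astype(np.float32) for name in ("x", "y", "z")], axis=-1
)
```

The resulting array has shape (64, 30, 3), which is the (batch, T, d) layout that the sequence networks on this page expect.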
5.4.2. Inside a workflow
The workflow handles compilation for you. Pass summary_network directly:
workflow = bf.BasicWorkflow(
    simulator=simulator,
    inference_variables=["theta"],
    summary_variables=["x", "y", "z"],  # tells the workflow what to summarize
    inference_network=bf.networks.FlowMatching(),
    summary_network=bf.networks.TimeSeriesTransformer(summary_dim=16),
)
history = workflow.fit_online(epochs=5, batch_size=128, num_batches=100)
The workflow’s default adapter will automatically mark summary_variables correctly based on the names provided. See Data Processing for controlling the adapter configuration in more detail.
5.5. Custom Summary Networks
When none of the built-in networks fit your data modality, you can write your own. Two requirements:
Inherit from bf.networks.SummaryNetwork.
Decorate with @serializable so the network can be saved and loaded (see Saving & Loading).
5.5.1. Example: multi-channel LFP encoder
Local field potentials (LFPs) are multi-channel neural recordings, stored as a 3D array of shape (batch, time, n_channels). Standard summary networks treat channels uniformly. This custom encoder combines a causal Conv1D across time, a bidirectional GRU, and a sigmoid gate over the resulting feature channels, giving the network an inductive bias that is tailored to spatially correlated brain signals.
import bayesflow as bf
import keras
from bayesflow.utils.serialization import serializable, serialize

@serializable("my_project.networks")
class LFPEncoder(bf.networks.SummaryNetwork):
    """Compresses multi-channel LFP recordings into fixed-size summaries.

    Architecture:
    - Shared Conv1D across time (captures local temporal dynamics)
    - Bidirectional GRU (captures long-range dependencies)
    - Dense channel-attention gate (weights channels by relevance)
    - Global average pool + projection to summary_dim
    """

    def __init__(self, summary_dim: int = 32, filters: int = 64, gru_units: int = 128, **kwargs):
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.filters = filters
        self.gru_units = gru_units
        # Temporal feature extraction
        self.conv = keras.layers.Conv1D(filters, kernel_size=5, padding="causal", activation="gelu")
        self.norm = keras.layers.LayerNormalization()
        # Sequence compression
        self.gru = keras.layers.Bidirectional(keras.layers.GRU(gru_units, return_sequences=True))
        # Channel attention: squeeze across time, then gate across channels
        self.channel_gate = keras.layers.Dense(gru_units * 2, activation="sigmoid")
        # Output projection
        self.pool = keras.layers.GlobalAveragePooling1D()
        self.proj = keras.layers.Dense(summary_dim)

    def call(self, x, training=False):
        """
        Parameters
        ----------
        x : Tensor, shape (batch, time, n_channels)

        Returns
        -------
        Tensor, shape (batch, summary_dim)
        """
        h = self.conv(x)
        h = self.norm(h, training=training)
        h = self.gru(h, training=training)  # (batch, time, gru_units*2)
        # Channel attention: average over time, compute gate, broadcast back
        gate = self.channel_gate(keras.ops.mean(h, axis=1, keepdims=True))
        h = h * gate
        h = self.pool(h)      # (batch, gru_units*2)
        return self.proj(h)   # (batch, summary_dim)

    def get_config(self):
        base = super().get_config()
        return base | serialize({
            "summary_dim": self.summary_dim,
            "filters": self.filters,
            "gru_units": self.gru_units,
        })
Use it exactly like any built-in summary network:
approximator = bf.ContinuousApproximator(
    inference_network=bf.networks.FlowMatching(),
    summary_network=LFPEncoder(summary_dim=32),
    adapter=adapter,
)
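The gating step in LFPEncoder.call relies on broadcasting, which is easy to misread. The NumPy sketch below traces only the shapes of that step; the random weight matrix `w` is a stand-in for the kernel of the Dense gate, and the tensor sizes are arbitrary illustrative choices.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(3)
batch, time, features = 8, 200, 256  # features corresponds to gru_units * 2

h = rng.normal(size=(batch, time, features))       # stand-in for the GRU output
w = rng.normal(size=(features, features)) * 0.05   # stand-in for the Dense kernel
b = np.zeros(features)

squeezed = h.mean(axis=1, keepdims=True)   # (batch, 1, features): average over time
gate = sigmoid(squeezed @ w + b)           # (batch, 1, features): values in (0, 1)
gated = h * gate                           # gate broadcasts over the time axis
pooled = gated.mean(axis=1)                # (batch, features): input to the projection
```

Keeping keepdims=True in the averaging step is what allows the gate to multiply every time step without an explicit reshape.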
5.5.2. Using a pretrained backbone from Keras Hub
For image data, you can wrap any pre-trained CNN from Keras Hub as a summary network. This is useful when your simulator produces realistic-looking images (e.g., spatial maps, microscopy images, cosmological fields) and you want to leverage ImageNet-pretrained features as a starting point.
The example below wraps DenseNet-121 (model page) — a densely-connected convolutional network — as a BayesFlow summary network. The backbone extracts features; a trainable projection head maps them to summary_dim.
# pip install keras-hub
import keras_hub

@serializable("my_project.networks")
class DenseNetSummary(bf.networks.SummaryNetwork):
    """Wraps a pretrained DenseNet-121 backbone as a BayesFlow summary network.

    Input shape: (batch, height, width, 3), i.e. 3-channel images.
    Output shape: (batch, summary_dim)
    """

    def __init__(self, summary_dim: int = 32, trainable_backbone: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.trainable_backbone = trainable_backbone
        # Load pretrained backbone as a feature extractor only (no classification head)
        self.backbone = keras_hub.models.DenseNetBackbone.from_preset(
            "densenet_121_imagenet"
        )
        self.backbone.trainable = trainable_backbone
        # Global average pool + projection to summary_dim
        self.pool = keras.layers.GlobalAveragePooling2D()
        self.proj = keras.layers.Dense(summary_dim, activation="gelu")
        self.out = keras.layers.Dense(summary_dim)

    def call(self, x, training=False):
        """
        Parameters
        ----------
        x : Tensor, shape (batch, height, width, 3)
            Images normalized to [0, 1] or [-1, 1].

        Returns
        -------
        Tensor, shape (batch, summary_dim)
        """
        features = self.backbone(x, training=training)
        # backbone outputs (batch, H', W', C) feature maps; pool to a vector
        pooled = self.pool(features)
        return self.out(self.proj(pooled))

    def get_config(self):
        base = super().get_config()
        return base | serialize({
            "summary_dim": self.summary_dim,
            "trainable_backbone": self.trainable_backbone,
        })
Usage:
# Images are of shape (batch, 128, 128, 3)
approximator = bf.ContinuousApproximator(
    inference_network=bf.networks.FlowMatching(),
    summary_network=DenseNetSummary(summary_dim=32, trainable_backbone=False),
    adapter=adapter,
)
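Many simulators produce single-channel fields rather than RGB images, so some preprocessing is needed before such data matches the (batch, height, width, 3) input described above. The helper `to_backbone_input` below is a hypothetical NumPy sketch of one option: per-image min-max rescaling to [0, 1] followed by channel repetition. Both choices are assumptions, not requirements of the backbone.

```python
import numpy as np

def to_backbone_input(fields):
    """Convert (batch, H, W) or (batch, H, W, 1) fields to (batch, H, W, 3) in [0, 1]."""
    x = np.asarray(fields, dtype=np.float32)
    if x.ndim == 3:
        x = x[..., None]
    # Rescale each image separately so its values span [0, 1].
    lo = x.min(axis=(1, 2, 3), keepdims=True)
    hi = x.max(axis=(1, 2, 3), keepdims=True)
    x = (x - lo) / np.maximum(hi - lo, 1e-8)
    # Repeat the single channel three times to match the RGB input layout.
    return np.repeat(x, 3, axis=-1)

rng = np.random.default_rng(4)
fields = rng.normal(size=(16, 128, 128))  # stand-in for simulated 2D fields
images = to_backbone_input(fields)        # shape (16, 128, 128, 3)
```

Per-image rescaling removes the overall amplitude of each field. If that amplitude carries information about the parameters, a fixed global range is likely the safer choice.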