7. Posterior Estimation for SIR-like Models
Author: Stefan T. Radev
import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import bayesflow as bf
import keras
WARNING:root:Multiple Keras-compatible backends detected (JAX, PyTorch, TensorFlow).
Defaulting to JAX.
To override, set the KERAS_BACKEND environment variable before importing bayesflow.
See: https://keras.io/getting_started/#configuring-your-backend
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:bayesflow:Using backend 'jax'
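As the warning says, you can pin the backend by setting the `KERAS_BACKEND` environment variable before Keras or BayesFlow is imported. The sketch below is a minimal example. JAX is only an example choice; `"torch"` and `"tensorflow"` are the other backends named in the warning.

```python
import os

# Must run before the first `import keras` or `import bayesflow` in the session;
# Keras picks its backend at import time, so changing the variable later has no effect.
os.environ["KERAS_BACKEND"] = "jax"  # alternatives: "torch", "tensorflow"

# import bayesflow as bf  # import only after the variable has been set
```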
7.1. Introduction
In this tutorial, we will illustrate how to perform posterior inference on simple, stationary SIR-like models (complex models will be tackled in a later notebook). SIR-like models are suitable illustrative examples, since they generate time series and their outputs represent the results of solving a system of ordinary differential equations (ODEs).
The details for tackling stochastic epidemiological models with neural networks are described in our corresponding papers, which you can consult for a more formal exposition and a more comprehensive treatment of neural architectures:
OutbreakFlow: Model-based Bayesian inference of disease outbreak dynamics with invertible neural networks and its application to the COVID-19 pandemics in Germany. https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1009472
Integrating artificial intelligence with mechanistic epidemiological modeling: a scoping review of opportunities and challenges. https://www.nature.com/articles/s41467-024-55461-x
7.2. Defining the Simulator
RNG = np.random.default_rng(2026)
As described in our very first notebook, a generative model consists of a prior (encoding suitable parameter ranges) and a simulator (generating data given parameters). Our underlying model distinguishes between susceptible, \(S\), infected, \(I\), and recovered, \(R\), individuals, with infection and recovery occurring at a constant transmission rate \(\lambda\) and a constant recovery rate \(\mu\), respectively. The model dynamics are governed by the following system of ODEs:
\[
\begin{aligned}
\frac{dS}{dt} &= -\lambda \, \frac{S I}{N}, \\
\frac{dI}{dt} &= \lambda \, \frac{S I}{N} - \mu I, \\
\frac{dR}{dt} &= \mu I,
\end{aligned}
\]
with \(N = S + I + R\) denoting the total population size. For forward simulation, we use a time step of \(dt = 1\), corresponding to daily case reports. In addition to the ODE parameters \(\lambda\) and \(\mu\), we consider a reporting delay parameter \(D\) and a dispersion parameter \(\psi\), which affect the number of reported infected individuals via a negative binomial distribution (https://en.wikipedia.org/wiki/Negative_binomial_distribution):
\[
I^{(\text{obs})}_t \sim \textrm{NegBinomial}\left(I^{(\text{new})}_{t-D}, \psi\right),
\]
where the new infections \(I^{(\text{new})}_{t-D}\) serve as the mean and \(\psi\) as the dispersion.
In this way, we connect the latent disease model to an observation model, which renders the relationship between parameters and data a stochastic one. Note that the dispersion parameter \(\psi\) belongs to the observation model, not to the ODE system. Finally, we will also treat the number of initially infected individuals, \(I_0\), as an unknown parameter (having its own prior distribution).
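The negative binomial above is parameterized by a mean and a dispersion \(\psi\). In contrast, numpy's `negative_binomial(n, p)` expects a number of successes \(n\) and a success probability \(p\). Choosing \(n = \psi\) and \(p = \psi / (\psi + \text{mean})\) preserves the mean and yields variance \(\text{mean} + \text{mean}^2/\psi\), so smaller \(\psi\) means more overdispersion. The Monte Carlo check below is a sketch. The helper name `nb_params` is ours; the simulator later in this notebook uses an equivalent helper, `convert_params`.

```python
import numpy as np

def nb_params(mean, psi):
    """Map (mean, dispersion) to numpy's (n, p); equivalent to convert_params in this notebook."""
    return psi, psi / (psi + mean)

rng = np.random.default_rng(0)
mean, psi = 50.0, 5.0
n, p = nb_params(mean, psi)
draws = rng.negative_binomial(n, p, size=200_000)

print(draws.mean())  # close to the target mean of 50
print(draws.var())   # close to mean + mean**2 / psi = 550
```

As \(\psi\) grows, the variance approaches the mean, which recovers the Poisson limit.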
7.2.1. Prior
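We will place the following prior distributions over the five model parameters, summarized in the table below (matching the implementation that follows):

| Parameter | Meaning | Prior |
|---|---|---|
| \(\lambda\) | transmission rate | \(\textrm{LogNormal}(\log(0.4), 0.5)\) |
| \(\mu\) | recovery rate | \(\textrm{LogNormal}(\log(1/8), 0.2)\) |
| \(D\) | reporting delay (days) | \(\textrm{LogNormal}(\log(8), 0.2)\) |
| \(I_0\) | initially infected individuals | \(\textrm{Gamma}(2, 20)\) (shape, scale) |
| \(\psi\) | dispersion of the observation noise | \(\textrm{Exponential}(5)\) (scale, i.e., mean 5) |

How did we come up with these priors? In this case, we rely on domain expertise and previous research (https://www.science.org/doi/10.1126/science.abb9789). In addition, the new parameter \(\psi\) follows an exponential distribution, which restricts it to positive numbers. Below is the implementation of these priors: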
def prior():
    """Generates a random draw from the joint prior."""
    lambd = RNG.lognormal(mean=np.log(0.4), sigma=0.5)
    mu = RNG.lognormal(mean=np.log(1 / 8), sigma=0.2)
    D = RNG.lognormal(mean=np.log(8), sigma=0.2)
    I0 = RNG.gamma(shape=2, scale=20)
    psi = RNG.exponential(5)
    return {"lambd": lambd, "mu": mu, "D": D, "I0": I0, "psi": psi}
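The next sketch is a quick Monte Carlo check that these priors encode what the table says. It also shows that numpy's `exponential` takes a scale (the mean), not a rate. The sketch is standalone and uses its own generator rather than the notebook's `RNG`.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 100_000

# Same distributions as in prior(), vectorized over n draws
lambd = rng.lognormal(mean=np.log(0.4), sigma=0.5, size=n)
I0 = rng.gamma(shape=2, scale=20, size=n)
psi = rng.exponential(5, size=n)  # first argument is the scale (mean), not the rate

# A lognormal's median is exp(mean), so the median transmission rate should be near 0.4
print(np.median(lambd))
# Gamma(shape=2, scale=20) has mean 2 * 20 = 40 initially infected individuals
print(I0.mean())
# An exponential with scale 5 has mean 5
print(psi.mean())
```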
7.2.2. Observation Model (Implicit Likelihood Function)
def convert_params(mu, phi):
    """Helper function to convert mean/dispersion parameterization of a negative binomial to N and p,
    as expected by numpy's negative_binomial.
    See https://en.wikipedia.org/wiki/Negative_binomial_distribution#Alternative_formulations
    """
    r = phi
    var = mu + 1 / r * mu**2
    p = (var - mu) / var
    return r, 1 - p
def stationary_SIR(lambd, mu, D, I0, psi, N=83e6, T=14, eps=1e-5):
    """Performs a forward simulation from the stationary SIR model given a random draw from the prior."""
    # Round I0 up to a whole number of individuals and D to a whole number of days
    I0 = np.ceil(I0)
    D = int(round(D))
    # Initial conditions
    S, I, R = [N - I0], [I0], [0]
    # Reported new cases
    C = [I0]
    # Simulate T + D - 1 time steps, so that T values remain after discarding the first D
    for t in range(1, T + D):
        # Calculate new cases
        I_new = lambd * (I[-1] * S[-1] / N)
        # SIR equations
        S_t = S[-1] - I_new
        I_t = np.clip(I[-1] + I_new - mu * I[-1], 0.0, N)
        R_t = np.clip(R[-1] + mu * I[-1], 0.0, N)
        # Track
        S.append(S_t)
        I.append(I_t)
        R.append(R_t)
        C.append(I_new)
    # Shift by the reporting delay and add negative binomial observation noise
    reparam = convert_params(np.clip(np.array(C[D:]), 0, N) + eps, psi)
    C_obs = RNG.negative_binomial(reparam[0], reparam[1])
    return dict(cases=C_obs)
As you can see, in addition to the parameters, our simulator requires two further arguments (besides the numerical safeguard eps): the total population size \(N\) and the time horizon \(T\). These are quantities over which we could also amortize (i.e., context variables), but for this example, we will just use the population of Germany and the first two weeks of the pandemic (i.e., \(T=14\)), in the same vein as https://www.science.org/doi/10.1126/science.abb9789.
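To make the time indexing concrete, here is a noise-free sketch of the latent part of `stationary_SIR`. It is our own simplified version: it omits clipping and the negative binomial noise, and the name `latent_sir_cases` is hypothetical. The loop runs \(T + D - 1\) times, so exactly \(T\) values remain after the first \(D\) entries are discarded. Every transfer between compartments is subtracted from one compartment and added to another, so \(S + I + R\) stays equal to \(N\).

```python
import numpy as np

def latent_sir_cases(lambd, mu, D, I0, N=83e6, T=14):
    """Noise-free, unclipped version of the latent dynamics in stationary_SIR."""
    I0 = np.ceil(I0)
    D = int(round(D))
    S, I, R = [N - I0], [I0], [0.0]
    C = [I0]  # daily new cases, starting with the initially infected
    for _ in range(1, T + D):
        s, i, r = S[-1], I[-1], R[-1]
        new_infections = lambd * i * s / N
        S.append(s - new_infections)
        I.append(i + new_infections - mu * i)
        R.append(r + mu * i)
        C.append(new_infections)
    # Shift by the reporting delay: keep only the last T daily counts
    totals = np.array(S) + np.array(I) + np.array(R)
    return np.array(C[D:]), totals

cases, totals = latent_sir_cases(lambd=0.4, mu=1 / 8, D=8, I0=40)
print(cases.shape)                # (14,)
print(np.allclose(totals, 83e6))  # True
```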
7.2.3. Loading Real Data
We will define a simple helper function to load the reported cases for the first two weeks of the Covid-19 pandemic in Germany (cumulative counts from March 1 to March 15, 2020, differenced into 14 daily new-case counts).
def load_data():
    """Helper function to load cumulative cases and transform them to new cases."""
    confirmed_cases_url = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv"
    confirmed_cases = pd.read_csv(confirmed_cases_url, sep=",")
    date_data_begin = datetime.date(2020, 3, 1)
    date_data_end = datetime.date(2020, 3, 15)
    format_date = lambda date_py: f"{date_py.month}/{date_py.day}/{str(date_py.year)[2:4]}"
    date_formatted_begin = format_date(date_data_begin)
    date_formatted_end = format_date(date_data_end)
    cases_obs = np.array(
        confirmed_cases.loc[confirmed_cases["Country/Region"] == "Germany", date_formatted_begin:date_formatted_end]
    )[0]
    new_cases_obs = np.diff(cases_obs)
    return new_cases_obs
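You can check the column labels and the final series length offline. The JHU CSV uses month/day/two-digit-year headers, and a pandas `.loc` label slice includes both endpoints, so 15 cumulative values become 14 new-case counts, which matches \(T = 14\). The cumulative values below are placeholders, not real counts.

```python
import datetime
import numpy as np

# Same formatting as in load_data()
format_date = lambda date_py: f"{date_py.month}/{date_py.day}/{str(date_py.year)[2:4]}"

begin, end = datetime.date(2020, 3, 1), datetime.date(2020, 3, 15)
print(format_date(begin), format_date(end))  # 3/1/20 3/15/20

# Inclusive label slice -> 15 cumulative columns; differencing yields daily new cases
num_columns = (end - begin).days + 1
cumulative = np.cumsum(np.ones(num_columns))  # placeholder values, not real data
new_cases = np.diff(cumulative)
print(len(new_cases))  # 14
```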
7.2.4. Stitching Things Together
We can combine the prior \(p(\theta)\) and the observation model \(p(x_{1:T}\mid\theta)\) into a joint model \(p(\theta, x_{1:T}) = p(\theta) \; p(x_{1:T}\mid\theta)\) using the make_simulator builder.
The resulting object can now generate batches of simulations.
simulator = bf.make_simulator([prior, stationary_SIR])
test_sims = simulator.sample(batch_size=2)
print(test_sims["lambd"].shape)
print(test_sims["D"].shape)
print(test_sims["cases"].shape)
(2, 1)
(2, 1)
(2, 14)
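Conceptually, the joint model draws from `prior()`, passes the resulting dictionary as keyword arguments to `stationary_SIR`, merges both outputs, and stacks the results along a leading batch axis. That is why scalar parameters come out with shape `(batch_size, 1)`. The sketch below reproduces only this shape logic with a toy prior and simulator. `sample_joint`, `toy_prior`, and `toy_simulator` are hypothetical names, and BayesFlow's actual `make_simulator` is more general and may batch differently.

```python
import numpy as np

rng = np.random.default_rng(0)

def toy_prior():
    return {"rate": rng.exponential(1.0)}

def toy_simulator(rate, T=14):
    return {"counts": rng.poisson(rate, size=T)}

def sample_joint(prior_fn, simulator_fn, batch_size):
    """Draw parameters, simulate from them, and stack everything along a batch axis."""
    draws = []
    for _ in range(batch_size):
        params = prior_fn()
        data = simulator_fn(**params)
        draws.append({**params, **data})
    # np.atleast_1d turns scalars into length-1 arrays, hence the trailing dimension of 1
    return {key: np.stack([np.atleast_1d(d[key]) for d in draws]) for key in draws[0]}

sims = sample_joint(toy_prior, toy_simulator, batch_size=2)
print(sims["rate"].shape)    # (2, 1)
print(sims["counts"].shape)  # (2, 14)
```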
7.3. Prior Checking
Any principled Bayesian workflow requires some prior predictive or prior pushforward checks to ensure that the prior specification is consistent with domain expertise (see https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html). The BayesFlow library provides some rudimentary visual tools for performing prior checking. For instance, we can visually inspect the joint prior in the form of bivariate plots. We can focus on particular parameter combinations, such as \(\lambda\), \(\mu\), and \(D\):
prior_samples = simulator.simulators[0].sample(1000)
grid = bf.diagnostics.plots.pairs_samples(
prior_samples, variable_keys=["lambd", "mu", "D"]
)
7.4. Defining the Adapter
We need to ensure that the outputs of the forward model are suitable for processing with neural networks. Currently, they are not, since our data \(x_{1:T}\) consist of large integer (count) values, whereas neural networks train best on inputs of moderate, comparable magnitude. Furthermore, our parameters \(\theta\) exhibit widely different scales due to their prior specification and role in the simulator. Finally, BayesFlow needs to know which variables are to be inferred and which ones are to be processed by the summary network before being passed to the inference network. We handle all of these steps using an Adapter.
Since all of our parameters and observables are non-negative, we will apply a \(\log(1 + x)\) transform to all quantities. Note that BayesFlow expects the following keys to be present in the final outputs of your configured simulations:
- inference_variables: These are the variables we are inferring.
- summary_variables: These are the variables that are compressed through a summary network and used for inferring the inference variables.
Thus, what our approximators are learning is \(p(\text{inference variables} \mid t(\text{summary variables}))\), where \(t\) is the summary network.
adapter = (
bf.adapters.Adapter()
.convert_dtype("float64", "float32")
.as_time_series("cases")
.concatenate(["lambd", "mu", "D", "I0", "psi"], into="inference_variables")
.rename("cases", "summary_variables")
# since all our variables are non-negative (zero or larger), log(1 + x) compresses their large
# range of magnitudes; the transform can be back-transformed under the hood
.log(["inference_variables", "summary_variables"], p1=True)
)
adapter
Adapter([0: ConvertDType -> 1: AsTimeSeries -> 2: Concatenate(['lambd', 'mu', 'D', 'I0', 'psi'] -> 'inference_variables') -> 3: Rename('cases' -> 'summary_variables') -> 4: Log])
# Let's check out the new shapes
adapted_sims = adapter(simulator.sample(2))
print(adapted_sims["summary_variables"].shape)
print(adapted_sims["inference_variables"].shape)
(2, 14, 1)
(2, 5)
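For intuition, here is a rough numpy analogue of what the adapter chain does to one batch. It is not BayesFlow's implementation. The trailing feature axis added by `as_time_series` is inferred from the `(2, 14, 1)` shape printed above. The example also shows that `expm1` undoes `log1p`, which is the back-transformation applied to posterior samples.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, T = 2, 14

# Toy stand-ins for simulator output: one (batch, 1) array per parameter, (batch, T) counts
params = {k: rng.gamma(2.0, 1.0, size=(batch_size, 1)) for k in ["lambd", "mu", "D", "I0", "psi"]}
cases = rng.poisson(1000, size=(batch_size, T)).astype("float64")

# concatenate + convert_dtype: stack the five parameters into one (batch, 5) array
inference_variables = np.concatenate(list(params.values()), axis=-1).astype("float32")
# as_time_series + rename: add a feature axis -> (batch, T, 1)
summary_variables = cases[..., None].astype("float32")
# log with p1=True: log(1 + x), well defined for x >= 0
inference_variables = np.log1p(inference_variables)
summary_variables = np.log1p(summary_variables)

print(inference_variables.shape, summary_variables.shape)  # (2, 5) (2, 14, 1)
print(np.allclose(np.expm1(summary_variables)[..., 0], cases))  # True
```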
7.5. Defining the Neural Approximator
We can now proceed to define our BayesFlow neural architecture, that is, combine a summary network with an inference network.
7.5.1. Summary Network
After adaptation, our data are 3D tensors of shape (batch_size, T = 14, 1), which we need to reduce to 2D tensors of shape (batch_size, summary_dim). Our model outputs are actually so simple that we could have just removed the trailing dimension of the raw outputs and fed the data directly to the inference network.
However, we demonstrate the use of a simple Gated Recurrent Unit (GRU) summary network. Any keras model can interact with BayesFlow by inheriting from SummaryNetwork, which accepts an additional stage argument indicating the mode the network is currently operating in (i.e., training vs. inference).
The custom network is decorated with the serializable decorator, which is necessary for saving / loading the full model later on.
@bf.utils.serialization.serializable("custom")
class GRU(bf.networks.SummaryNetwork):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.gru = keras.layers.GRU(64)
        self.summary_stats = keras.layers.Dense(8)
    def call(self, time_series, **kwargs):
        """Compresses time_series of shape (batch_size, T, 1) into summaries of shape (batch_size, 8)."""
        summary = self.gru(time_series, **kwargs)
        summary = self.summary_stats(summary)
        return summary
summary_net = GRU()
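To see how a GRU turns a variable-length series into a fixed-size vector, here is a plain numpy forward pass of one GRU layer followed by a dense projection. The weights are random and untrained, so the sketch only illustrates shapes and gating. It uses the original "reset before" candidate update. Keras' default (`reset_after=True`) applies the reset gate after the recurrent matrix product, which is a minor difference. The update rule `h = z * h + (1 - z) * candidate` matches Keras' convention.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_summary(x, params, W_out, b_out):
    """Run a single-layer GRU over x of shape (batch, T, features), then project the last state."""
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    batch, T, _ = x.shape
    h = np.zeros((batch, Uz.shape[0]))
    for t in range(T):
        x_t = x[:, t, :]
        z = sigmoid(x_t @ Wz + h @ Uz + bz)  # update gate
        r = sigmoid(x_t @ Wr + h @ Ur + br)  # reset gate
        candidate = np.tanh(x_t @ Wh + (r * h) @ Uh + bh)
        h = z * h + (1.0 - z) * candidate
    return h @ W_out + b_out  # dense layer: (batch, units) -> (batch, summary_dim)

rng = np.random.default_rng(0)
units, summary_dim = 64, 8
shapes = [(1, units), (units, units), (units,)] * 3  # input, recurrent, bias for z, r, h
params = [rng.normal(scale=0.1, size=s) for s in shapes]
W_out = rng.normal(scale=0.1, size=(units, summary_dim))
b_out = np.zeros(summary_dim)

x = np.log1p(rng.poisson(100, size=(32, 14, 1)))  # toy batch shaped like the adapted data
print(gru_summary(x, params, W_out, b_out).shape)  # (32, 8)
```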
7.5.2. Inference Network
As a backbone inference network, we choose the all-time classic coupling flow (i.e., a type of normalizing flow).
inference_net = bf.networks.CouplingFlow(depth=2, transform="spline")
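A coupling layer splits the parameter vector into two halves. It leaves one half unchanged and transforms the other half with an invertible map whose parameters are computed from the unchanged half. This keeps both inversion and the Jacobian determinant cheap. For brevity, the sketch below uses an *affine* coupling; the network above uses spline transforms. The "conditioner" is a hypothetical random linear map, not a trained network. In conditional flows used for posterior estimation, the conditioner would additionally receive the summary statistics.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 5        # e.g. the five log-transformed SIR parameters
d1 = dim // 2  # the first d1 entries pass through unchanged
W_s = rng.normal(scale=0.1, size=(d1, dim - d1))  # hypothetical conditioner weights
W_t = rng.normal(scale=0.1, size=(d1, dim - d1))

def conditioner(x1):
    # Log-scale and shift for the second half, computed from the first half only
    return np.tanh(x1 @ W_s), x1 @ W_t

def forward(x):
    x1, x2 = x[:, :d1], x[:, d1:]
    log_s, t = conditioner(x1)
    y2 = x2 * np.exp(log_s) + t
    log_det = log_s.sum(axis=1)  # triangular Jacobian: log|det| is the sum of log-scales
    return np.concatenate([x1, y2], axis=1), log_det

def inverse(y):
    y1, y2 = y[:, :d1], y[:, d1:]
    log_s, t = conditioner(y1)  # y1 equals x1, so the same scale and shift are recovered
    x2 = (y2 - t) * np.exp(-log_s)
    return np.concatenate([y1, x2], axis=1)

x = rng.normal(size=(4, dim))
y, log_det = forward(x)
print(np.allclose(inverse(y), x))  # True
```

Stacking several such layers, with the halves permuted in between, gives a flow whose density can be evaluated through the change-of-variables formula.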
7.5.3. Workflow
Inference with workflows is easy. Simply provide the simulator, adapter, and network objects, and have fun! If you want to save the networks automatically after training, provide a checkpoint_filepath and an optional checkpoint_name.
workflow = bf.BasicWorkflow(
simulator=simulator,
adapter=adapter,
inference_network=inference_net,
summary_network=summary_net,
standardize=None # no need to standardize due to log-transform
)
7.6. Training
Ready to train! Since our simulator is pretty fast, we can safely go with online training. Let’s measure the time taken for a batch of \(32\) simulations.
%%time
_ = workflow.simulate(32)
CPU times: user 395 μs, sys: 3.92 ms, total: 4.32 ms
Wall time: 3.95 ms
Not too bad! However, for the purpose of illustration, we will go with offline training using a fixed data set of 6000 training simulations (plus 300 for validation). This may be considered a “low simulation budget” in many settings.
7.6.1. Generating Offline Data
training_data = workflow.simulate(6000)
validation_data = workflow.simulate(300)
We are now ready to train. If not provided, the default settings use \(100\) epochs with a batch size of \(64\). The training time for this network is below 1 minute.
history = workflow.fit_offline(
data=training_data,
epochs=100,
batch_size=64,
validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
WARNING:bayesflow:searchsorted is not yet optimized for backend 'jax'
Epoch 1/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 18s 121ms/step - loss: 7.7970 - val_loss: 1.6713
Epoch 2/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.1849 - val_loss: -0.5665
Epoch 3/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -0.6809 - val_loss: -1.3038
Epoch 4/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -0.9721 - val_loss: -1.3988
Epoch 5/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -1.5364 - val_loss: -2.0807
Epoch 6/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -1.8735 - val_loss: -2.0703
Epoch 7/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.0197 - val_loss: -2.7044
Epoch 8/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.3128 - val_loss: -2.7548
Epoch 9/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.5104 - val_loss: -1.9739
Epoch 10/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.2155 - val_loss: -1.4072
Epoch 11/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.6161 - val_loss: -3.0532
Epoch 12/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.6795 - val_loss: -2.7036
Epoch 13/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.9243 - val_loss: -3.2812
Epoch 14/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.9142 - val_loss: -2.8702
Epoch 15/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.8002 - val_loss: -3.1593
Epoch 16/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.0753 - val_loss: -3.3472
Epoch 17/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -2.9499 - val_loss: -3.3610
Epoch 18/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.1287 - val_loss: -3.4320
Epoch 19/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.1943 - val_loss: -3.0502
Epoch 20/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.2655 - val_loss: -3.6025
Epoch 21/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.3146 - val_loss: -3.2089
Epoch 22/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.3703 - val_loss: -3.6042
Epoch 23/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.3030 - val_loss: -3.3886
Epoch 24/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.2497 - val_loss: -3.3860
Epoch 25/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.3968 - val_loss: -3.6420
Epoch 26/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.4399 - val_loss: -2.9520
Epoch 27/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.4130 - val_loss: -3.6961
Epoch 28/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.6362 - val_loss: -3.1072
Epoch 29/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.6487 - val_loss: -3.9550
Epoch 30/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.7541 - val_loss: -3.6187
Epoch 31/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.8452 - val_loss: -3.9453
Epoch 32/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.8117 - val_loss: -4.1507
Epoch 33/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.8746 - val_loss: -4.1911
Epoch 34/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.9347 - val_loss: -4.1658
Epoch 35/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -3.9111 - val_loss: -4.0502
Epoch 36/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.0084 - val_loss: -4.2689
Epoch 37/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.0180 - val_loss: -4.0090
Epoch 38/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.1238 - val_loss: -4.2779
Epoch 39/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.0890 - val_loss: -4.2557
Epoch 40/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.1266 - val_loss: -4.3527
Epoch 41/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.1932 - val_loss: -4.2066
Epoch 42/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.1771 - val_loss: -4.3047
Epoch 43/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.1949 - val_loss: -4.1935
Epoch 44/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.2528 - val_loss: -4.2165
Epoch 45/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.2187 - val_loss: -3.9134
Epoch 46/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.2673 - val_loss: -4.3313
Epoch 47/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.2788 - val_loss: -4.2838
Epoch 48/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.3529 - val_loss: -4.4351
Epoch 49/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.3613 - val_loss: -4.4143
Epoch 50/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.3690 - val_loss: -4.3824
Epoch 51/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.3587 - val_loss: -4.1869
Epoch 52/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.4289 - val_loss: -4.5200
Epoch 53/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.4764 - val_loss: -4.5846
Epoch 54/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.4135 - val_loss: -4.2839
Epoch 55/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.5073 - val_loss: -4.3363
Epoch 56/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.4235 - val_loss: -4.5565
Epoch 57/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.4824 - val_loss: -4.5865
Epoch 58/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.4601 - val_loss: -4.5182
Epoch 59/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.5104 - val_loss: -4.6667
Epoch 60/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.5610 - val_loss: -4.6186
Epoch 61/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.5663 - val_loss: -4.5538
Epoch 62/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.5810 - val_loss: -4.6234
Epoch 63/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6196 - val_loss: -4.7155
Epoch 64/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.5958 - val_loss: -4.7786
Epoch 65/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6107 - val_loss: -4.7404
Epoch 66/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6264 - val_loss: -4.7237
Epoch 67/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6729 - val_loss: -4.7315
Epoch 68/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6051 - val_loss: -4.5560
Epoch 69/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6257 - val_loss: -4.7303
Epoch 70/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6602 - val_loss: -4.4617
Epoch 71/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6683 - val_loss: -4.7108
Epoch 72/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6531 - val_loss: -4.6432
Epoch 73/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6975 - val_loss: -4.6503
Epoch 74/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.6945 - val_loss: -4.7091
Epoch 75/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7196 - val_loss: -4.7082
Epoch 76/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7329 - val_loss: -4.7374
Epoch 77/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7213 - val_loss: -4.5306
Epoch 78/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7274 - val_loss: -4.6645
Epoch 79/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7604 - val_loss: -4.6483
Epoch 80/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7606 - val_loss: -4.7002
Epoch 81/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7543 - val_loss: -4.6587
Epoch 82/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7740 - val_loss: -4.7329
Epoch 83/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7821 - val_loss: -4.5897
Epoch 84/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7876 - val_loss: -4.7101
Epoch 85/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7850 - val_loss: -4.6773
Epoch 86/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8080 - val_loss: -4.7336
Epoch 87/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.7957 - val_loss: -4.7046
Epoch 88/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8040 - val_loss: -4.7477
Epoch 89/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8138 - val_loss: -4.7431
Epoch 90/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8174 - val_loss: -4.7228
Epoch 91/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8163 - val_loss: -4.7227
Epoch 92/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8177 - val_loss: -4.7122
Epoch 93/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8048 - val_loss: -4.7323
Epoch 94/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8160 - val_loss: -4.7252
Epoch 95/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8167 - val_loss: -4.7314
Epoch 96/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8183 - val_loss: -4.7373
Epoch 97/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8338 - val_loss: -4.7301
Epoch 98/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8429 - val_loss: -4.7363
Epoch 99/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8292 - val_loss: -4.7355
Epoch 100/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: -4.8364 - val_loss: -4.7364
INFO:bayesflow:Training completed in 34.46 seconds.
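The "94/94" counter in the log is the number of batches per epoch. The final, smaller batch is kept rather than dropped. The loss values themselves can go below zero: for a normalizing flow, the training loss is essentially the average negative log-density of the (log-transformed) parameters, and densities of continuous variables can exceed 1.

```python
import math

num_training_sims, batch_size = 6000, 64
steps_per_epoch = math.ceil(num_training_sims / batch_size)
print(steps_per_epoch)  # 94
```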
7.6.2. Inspecting the Loss
Following our offline training, we can quickly visualize the loss trajectory using the plots.loss function from the diagnostics module.
f = bf.diagnostics.plots.loss(history)
Great, it seems that our approximator has converged! Before we get too excited and throw our networks at real data, we need to make sure that they meet our expectations in silico, that is, given the small world of simulations the networks have seen during training.
7.7. Validation Phase
When it comes to validating posterior inference, we can either deploy manual diagnostics from the diagnostics module, or use the automated functions from the BasicWorkflow object. First, we demonstrate manual validation.
# Set the number of test data sets and the number of posterior draws per data set
num_datasets = 300
num_samples = 1000
# Simulate 300 scenarios
test_sims = workflow.simulate(num_datasets)
# Obtain num_samples posterior samples per scenario
# batch_size can be adjusted based on your hardware capabilities
samples = workflow.sample(conditions=test_sims, num_samples=num_samples, batch_size=50)
INFO:bayesflow:Sampling completed in 6.98 seconds.
7.7.1. Simulation-Based Calibration - Rank Histograms
As a further small world (i.e., before real data) sanity check, we can also test the calibration of the amortizer through simulation-based calibration (SBC). See the corresponding paper for more details (https://arxiv.org/pdf/1804.06788.pdf). Accordingly, we expect to observe approximately uniform rank statistic histograms. In the present case, this is indeed what we get:
f = bf.diagnostics.plots.calibration_histogram(samples, test_sims)
WARNING:bayesflow:The ratio of simulations / posterior draws should be > 20 for reliable variance reduction, but your ratio is 0. Confidence intervals might be unreliable!
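Here is what the histogram is built from. For each test data set, the rank statistic counts how many posterior draws fall below the true parameter value. If the approximate posterior is calibrated, these ranks are uniformly distributed. The sketch below uses a hypothetical toy model with an exact posterior (normal prior, normal likelihood) instead of the SIR model, so calibration holds by construction. An overconfident posterior would instead produce a U-shaped histogram, and a biased one a skewed histogram.

```python
import numpy as np

rng = np.random.default_rng(0)
num_datasets, num_draws = 2000, 99

# Toy conjugate model: theta ~ N(0, 1), x | theta ~ N(theta, 1)
theta_true = rng.normal(size=num_datasets)
x = rng.normal(loc=theta_true)
# Exact posterior: theta | x ~ N(x / 2, 1 / 2)
post = rng.normal(loc=x[:, None] / 2, scale=np.sqrt(0.5), size=(num_datasets, num_draws))

# SBC rank = number of posterior draws below the true value (0, ..., num_draws)
ranks = (post < theta_true[:, None]).sum(axis=1)

counts, _ = np.histogram(ranks, bins=10, range=(0, num_draws + 1))
print(counts)  # roughly 200 per bin for a calibrated posterior
```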
7.7.2. Simulation-Based Calibration - Rank ECDF
For models with many parameters, inspecting many histograms can become unwieldy. Moreover, the num_bins hyperparameter for the construction of SBC rank histograms can be hard to choose. An alternative diagnostic approach for calibration is through empirical cumulative distribution functions (ECDF) of rank statistics. You can read more about this approach in the corresponding paper (https://arxiv.org/abs/2103.10522).
To inspect the ECDFs of the marginal distributions, we reuse the \(300\) pairs of simulated data and generating parameters \((\boldsymbol{x}, \boldsymbol{\theta})\) from above and call the function plots.calibration_ecdf from the diagnostics module:
f = bf.diagnostics.plots.calibration_ecdf(samples, test_sims, difference=True)
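With `difference=True`, the plot shows the ECDF of the fractional ranks minus the CDF of a uniform distribution, so a calibrated posterior hugs zero. The sketch below computes that curve on the same hypothetical toy conjugate model used as a calibrated reference (not the SIR model). It does not reproduce the simultaneous confidence bands drawn by BayesFlow.

```python
import numpy as np

rng = np.random.default_rng(0)
num_datasets, num_draws = 2000, 99

# Toy conjugate model with an exact posterior (calibrated by construction)
theta_true = rng.normal(size=num_datasets)
x = rng.normal(loc=theta_true)
post = rng.normal(loc=x[:, None] / 2, scale=np.sqrt(0.5), size=(num_datasets, num_draws))

# Fractional rank in [0, 1]
u = (post < theta_true[:, None]).mean(axis=1)

# ECDF of the fractional ranks on a grid, minus the uniform CDF
grid = np.linspace(0, 1, 101)
ecdf = (u[None, :] <= grid[:, None]).mean(axis=1)
ecdf_diff = ecdf - grid
print(np.abs(ecdf_diff).max())  # small here; systematic bumps would signal miscalibration
```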
7.7.3. Inferential Adequacy (Global)
Depending on the application, it might be interesting to see how well summaries of the full posterior (e.g., means, medians) recover the assumed true parameter values. We can test this in silico via the plots.recovery function in the diagnostics module. For instance, we can compare how well posterior means recover the true parameter (i.e., posterior z-score, https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html):
f = bf.diagnostics.plots.recovery(samples, test_sims)
Interestingly, it seems that the parameters \(\theta_1 = \mu\) and \(\theta_2 = D\) have not been learned properly, as their estimates are roughly the same for every simulated data set used during testing. For some models, this might indicate that network training had partially failed, and we would have to train longer or adjust the network architecture. For this specific model, however, the reason is different: these parameters are not identified by the provided observables, so they cannot be learned consistently, no matter which approximator we use.
7.7.4. Automatic Diagnostics
The basic workflow object wraps together a bunch of useful functions that can be called automatically. For instance, we can easily obtain numerical error estimates for the big three: normalized root mean square error (NRMSE), posterior contraction, and calibration, for the \(300\) test data sets simulated above:
metrics = workflow.compute_default_diagnostics(test_data=test_sims, samples=samples)
metrics
| | lambd | mu | D | I0 | psi |
|---|---|---|---|---|---|
| NRMSE | 0.071496 | 0.241683 | 0.208013 | 0.247244 | 0.163214 |
| Log Gamma | -1.659110 | -15.135129 | 0.538009 | -4.865294 | 1.143394 |
| Calibration Error | 0.028860 | 0.106404 | 0.066930 | 0.022807 | 0.008772 |
| Posterior Contraction | 0.949154 | 0.000000 | 0.334580 | 0.113567 | 0.704457 |
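To make two of these numbers less opaque, the sketch below computes them on a hypothetical toy conjugate model whose answer is known. Posterior contraction is taken as \(1 - \mathrm{Var}(\text{posterior}) / \mathrm{Var}(\text{prior})\), averaged over data sets. A value of 0, as for \(\mu\) above, means the data did not reduce prior uncertainty at all; values near 1 mean near-certainty. For NRMSE we normalize the root mean squared error of the posterior draws by the range of the true values. That normalization is our assumption, and BayesFlow's exact definition may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
num_datasets, num_draws = 5000, 200

# Toy conjugate model: prior variance 1, exact posterior variance 1/2 for every data set
theta_true = rng.normal(size=num_datasets)
x = rng.normal(loc=theta_true)
post = rng.normal(loc=x[:, None] / 2, scale=np.sqrt(0.5), size=(num_datasets, num_draws))

# Posterior contraction: 1 - posterior variance / prior variance, averaged over data sets
prior_var = theta_true.var()
contraction = np.mean(1.0 - post.var(axis=1, ddof=1) / prior_var)

# NRMSE: RMSE of posterior draws around the truth, normalized by the range of true values
rmse = np.sqrt(np.mean((post - theta_true[:, None]) ** 2))
nrmse = rmse / (theta_true.max() - theta_true.min())

print(contraction)  # close to 0.5: the data halve the prior variance
print(nrmse)
```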
We can also obtain the full set of graphical diagnostics. The method below lets you control nearly all display features (can take a while):
figures = workflow.plot_default_diagnostics(
test_data=test_sims,
samples=samples,
# optional kwargs for the different diagnostic plots can be passed as follows
loss_kwargs={"figsize": (15, 3), "label_fontsize": 12},
recovery_kwargs={"figsize": (15, 3), "label_fontsize": 12},
calibration_ecdf_kwargs={"figsize": (15, 3), "legend_fontsize": 8, "label_fontsize": 12},
coverage_kwargs={"figsize": (15, 3), "legend_fontsize": 8, "label_fontsize": 12},
z_score_contraction_kwargs={"figsize": (15, 3), "label_fontsize": 12}
)
7.8. Inference Phase
We can now move on to using real data. This is easy, and since we are using an adapter, the same transformations applied during training will be applied during the inference phase.
# Our real-data loader returns the time series as a 1D array
obs_cases = load_data()
# Note that we transform the 1D array into shape (1, T), indicating one time series
samples = workflow.sample(conditions={"cases": obs_cases[None, :]}, num_samples=num_samples)
# Convert the samples into a 2D data frame (one column per parameter)
samples = workflow.samples_to_data_frame(samples)
samples.head()
INFO:bayesflow:Sampling completed in 5.76 seconds.
| | lambd | mu | D | I0 | psi |
|---|---|---|---|---|---|
| 0 | 0.352149 | 0.093297 | 7.338632 | 32.597759 | 4.792953 |
| 1 | 0.362652 | 0.119265 | 7.344949 | 36.322720 | 14.025829 |
| 2 | 0.454869 | 0.174293 | 8.934675 | 18.295586 | 7.282802 |
| 3 | 0.365535 | 0.108433 | 6.161946 | 45.498047 | 5.797621 |
| 4 | 0.375026 | 0.133876 | 7.220122 | 35.323311 | 13.098855 |
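Posterior summaries follow directly from the draws. With the full data frame, pandas' `samples.quantile([0.025, 0.5, 0.975])` or `samples.describe()` summarizes all five parameters at once. The numpy equivalent for a single column is shown below. For illustration only, it uses the five \(\lambda\) draws printed above; in practice you would use all 1000.

```python
import numpy as np

# The five posterior draws of lambd shown in the table above (illustration only)
lambd_draws = np.array([0.352149, 0.362652, 0.454869, 0.365535, 0.375026])

median = np.median(lambd_draws)
lower, upper = np.quantile(lambd_draws, [0.025, 0.975])  # central 95% interval
print(median)  # 0.365535
```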
7.8.1. Posterior Retrodictive Checks
These are also called posterior predictive checks, but here we want to explicitly highlight the fact that we are not predicting future data but testing the generative (re-simulation) performance of the model. In other words, we want to test how well the simulator can reproduce the actually observed data given the parameter posterior \(p(\theta \mid t(x_{1:T}))\), where \(t\) is the summary network.
Here, we will create a custom function which plots the observed data and then overlays draws from the posterior predictive.
def plot_ppc(samples, obs_cases, logscale=True, color="#132a70", figsize=(12, 6), font_size=18):
    """
    Helper function to perform some plotting of the posterior predictive.
    """
    # Plot settings
    plt.rcParams["font.size"] = font_size
    f, ax = plt.subplots(1, 1, figsize=figsize)
    T = len(obs_cases)
    # Re-simulations
    sims = []
    for i in range(samples.shape[0]):
        # Note - the data frame columns (lambd, mu, D, I0, psi) match the simulator's positional
        # arguments, and the simulator returns a dict whose "cases" entry is a 1D array of length T
        sim_cases = stationary_SIR(*samples.values[i])
        sims.append(sim_cases["cases"])
    sims = np.array(sims)
    # Compute quantiles for each t = 1,...,T
    qs_50 = np.quantile(sims, q=[0.25, 0.75], axis=0)
    qs_90 = np.quantile(sims, q=[0.05, 0.95], axis=0)
    qs_95 = np.quantile(sims, q=[0.025, 0.975], axis=0)
    # Plot median predictions and observed data
    ax.plot(np.median(sims, axis=0), label="Median predicted cases", color=color)
    ax.plot(obs_cases, marker="o", label="Reported cases", color="black", linestyle="dashed", alpha=0.8)
    # Add compatibility intervals (also called credible intervals)
    ax.fill_between(range(T), qs_50[0], qs_50[1], color=color, alpha=0.5, label="50% CI")
    ax.fill_between(range(T), qs_90[0], qs_90[1], color=color, alpha=0.3, label="90% CI")
    ax.fill_between(range(T), qs_95[0], qs_95[1], color=color, alpha=0.1, label="95% CI")
    # Grid and schmuck
    ax.grid(color="grey", linestyle="-", linewidth=0.25, alpha=0.5)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlabel("Days since pandemic onset")
    ax.set_ylabel("Number of cases")
    ax.minorticks_off()
    if logscale:
        ax.set_yscale("log")
    ax.legend(fontsize=font_size)
    return f
We can now go on and plot the re-simulations (i.e., perform a posterior “predictive” check on the fitted data):
f = plot_ppc(samples, obs_cases)
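The visual check can be complemented by a single number: the fraction of reported days that fall inside, say, the central 90% interval of the re-simulations. The sketch below shows one way to compute it. The helper `interval_coverage` is hypothetical; inside `plot_ppc` you would apply it to `sims` and `obs_cases`. With only 14 days, treat the result as a coarse indicator.

```python
import numpy as np

def interval_coverage(sims, obs, lower_q=0.05, upper_q=0.95):
    """Fraction of observed time points inside the central predictive interval.

    sims: array of shape (num_draws, T) with re-simulated series; obs: array of shape (T,).
    """
    lower, upper = np.quantile(sims, q=[lower_q, upper_q], axis=0)
    return np.mean((obs >= lower) & (obs <= upper))

rng = np.random.default_rng(0)
sims = rng.normal(size=(1000, 14))  # toy re-simulations, not SIR output
print(interval_coverage(sims, np.zeros(14)))       # 1.0: every point is central
print(interval_coverage(sims, np.full(14, 10.0)))  # 0.0: every point is far outside
```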
That’s it for this tutorial! You now know how to use the basic building blocks of BayesFlow to create amortized neural approximators. :)