4. Posterior Estimation for SIR-like Models#
Author: Stefan T. Radev
import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import bayesflow as bf
import keras
4.1. Introduction#
In this tutorial, we will illustrate how to perform posterior inference on simple, stationary SIR-like models (complex models will be tackled in a further notebook). SIR-like models comprise suitable illustrative examples, since they generate time-series and their outputs represent the results of solving a system of ordinary differential equations (ODEs).
The details for tackling stochastic epidemiological models with neural networks are described in the following papers, which you can consult for a more formal exposition and a more comprehensive treatment of neural architectures:
OutbreakFlow: Model-based Bayesian inference of disease outbreak dynamics with invertible neural networks and its application to the COVID-19 pandemics in Germany. https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1009472
Integrating artificial intelligence with mechanistic epidemiological modeling: a scoping review of opportunities and challenges. https://www.nature.com/articles/s41467-024-55461-x
4.2. Defining the Simulator #
RNG = np.random.default_rng(2025)
As described in our very first notebook, a generative model consists of a prior (encoding suitable parameter ranges) and a simulator (generating data given parameters). Our underlying model distinguishes between susceptible, \(S\), infected, \(I\), and recovered, \(R\), individuals, with infection and recovery occurring at a constant transmission rate \(\lambda\) and a constant recovery rate \(\mu\), respectively. The model dynamics are governed by the following system of ODEs:
\[
\frac{dS}{dt} = -\lambda \frac{S I}{N}, \qquad
\frac{dI}{dt} = \lambda \frac{S I}{N} - \mu I, \qquad
\frac{dR}{dt} = \mu I,
\]
with \(N = S + I + R\) denoting the total population size. For forward simulation, we will use a time step of \(dt = 1\), corresponding to daily case reports. In addition to the ODE parameters \(\lambda\) and \(\mu\), we consider a reporting delay parameter \(D\) and a dispersion parameter \(\psi\), which affect the number of reported infected individuals via a negative binomial distribution (https://en.wikipedia.org/wiki/Negative_binomial_distribution):
\[
I_t^{(\text{obs})} \sim \text{NegBinomial}\left(I_{t-D}^{(\text{new})}, \psi\right),
\]
where \(I_t^{(\text{new})}\) denotes the number of new infections on day \(t\), and the distribution is parameterized by its mean and its dispersion.
In this way, we connect the latent disease model to an observation model, which renders the relationship between parameters and data stochastic. Note that the observation model introduces the further parameter \(\psi\), responsible for the dispersion of the noise. Finally, we will also treat the number of initially infected individuals, \(I_0\), as an unknown parameter (with its own prior distribution).
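To make the discretization concrete, here is a minimal sketch that advances the deterministic SIR system with an explicit Euler step of \(dt = 1\). The names `sir_step` and `simulate_sir` and the parameter values are illustrative assumptions, not part of the tutorial's simulator, and the observation noise is deliberately left out.

```python
def sir_step(S, I, R, lambd, mu, N):
    """One explicit Euler step (dt = 1) of the SIR equations."""
    new_infections = lambd * S * I / N   # flow S -> I
    new_recoveries = mu * I              # flow I -> R
    return S - new_infections, I + new_infections - new_recoveries, R + new_recoveries


def simulate_sir(lambd, mu, I0, N, T):
    """Returns a list of (S, I, R) tuples for T daily time points."""
    S, I, R = N - I0, I0, 0.0
    trajectory = [(S, I, R)]
    for _ in range(T - 1):
        S, I, R = sir_step(S, I, R, lambd, mu, N)
        trajectory.append((S, I, R))
    return trajectory


# Illustrative values only (assumed, roughly in the range of the priors used later)
trajectory = simulate_sir(lambd=0.4, mu=0.125, I0=10.0, N=83e6, T=14)
for day, (S, I, R) in enumerate(trajectory[:3]):
    print(day, round(I, 2), round(S + I + R))
```

Since every step only moves individuals between compartments, \(S + I + R\) stays equal to \(N\) throughout, which is a handy sanity check for any hand-written simulator of this kind.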
4.2.1. Prior #
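We will place the following prior distributions over the five model parameters, summarized in the table below:
| Parameter | Symbol | Prior distribution |
|---|---|---|
| Transmission rate | \(\lambda\) | \(\text{LogNormal}(\log(0.4), 0.5)\) |
| Recovery rate | \(\mu\) | \(\text{LogNormal}(\log(1/8), 0.2)\) |
| Reporting delay | \(D\) | \(\text{LogNormal}(\log(8), 0.2)\) |
| Number of initially infected individuals | \(I_0\) | \(\text{Gamma}(2, 20)\) (shape, scale) |
| Dispersion | \(\psi\) | \(\text{Exponential}(5)\) (scale) |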
How did we come up with these priors? In this case, we rely on domain expertise and previous research (https://www.science.org/doi/10.1126/science.abb9789). In addition, the new parameter \(\psi\) follows an exponential distribution, which restricts it to positive values. Below is the implementation of these priors:
def prior():
    """Generates a random draw from the joint prior."""
    lambd = RNG.lognormal(mean=np.log(0.4), sigma=0.5)
    mu = RNG.lognormal(mean=np.log(1 / 8), sigma=0.2)
    D = RNG.lognormal(mean=np.log(8), sigma=0.2)
    I0 = RNG.gamma(shape=2, scale=20)
    psi = RNG.exponential(5)
    return {"lambd": lambd, "mu": mu, "D": D, "I0": I0, "psi": psi}
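As a quick numerical check of what these priors imply, the sketch below draws many samples from the same distributions and compares empirical summaries with their known closed forms (the median of a log-normal is \(\exp(\text{mean})\), the mean of a gamma is shape times scale, and the mean of an exponential equals its scale). It uses its own generator `rng` and sample size `n`, both of which are assumptions made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # separate generator, assumed for this sketch
n = 100_000

draws = {
    "lambd": rng.lognormal(mean=np.log(0.4), sigma=0.5, size=n),
    "mu": rng.lognormal(mean=np.log(1 / 8), sigma=0.2, size=n),
    "D": rng.lognormal(mean=np.log(8), sigma=0.2, size=n),
    "I0": rng.gamma(shape=2, scale=20, size=n),
    "psi": rng.exponential(5, size=n),
}

# Empirical median and mean per parameter
summary = {name: (float(np.median(x)), float(np.mean(x))) for name, x in draws.items()}
for name, (median, mean) in summary.items():
    print(f"{name:>5}: median = {median:.3f}, mean = {mean:.3f}")
```

Summaries like these make it easier to judge whether the priors cover plausible values before any simulation is run.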
4.2.2. Observation Model (Implicit Likelihood Function) #
def convert_params(mu, phi):
    """Helper function to convert mean/dispersion parameterization of a negative binomial to N and p,
    as expected by numpy's negative_binomial.
    See https://en.wikipedia.org/wiki/Negative_binomial_distribution#Alternative_formulations
    """
    r = phi
    var = mu + 1 / r * mu**2
    p = (var - mu) / var
    return r, 1 - p
def stationary_SIR(lambd, mu, D, I0, psi, N=83e6, T=14, eps=1e-5):
    """Performs a forward simulation from the stationary SIR model given a random draw from the prior."""
    # Extract parameters and round I0 and D
    I0 = np.ceil(I0)
    D = int(round(D))
    # Initial conditions
    S, I, R = [N - I0], [I0], [0]
    # Reported new cases
    C = [I0]
    # Simulate T + D - 1 further time steps (the first D entries are dropped below)
    for t in range(1, T + D):
        # Calculate new cases
        I_new = lambd * (I[-1] * S[-1] / N)
        # SIR equations
        S_t = S[-1] - I_new
        I_t = np.clip(I[-1] + I_new - mu * I[-1], 0.0, N)
        R_t = np.clip(R[-1] + mu * I[-1], 0.0, N)
        # Track
        S.append(S_t)
        I.append(I_t)
        R.append(R_t)
        C.append(I_new)
    # Account for the reporting delay D and add negative binomial observation noise
    reparam = convert_params(np.clip(np.array(C[D:]), 0, N) + eps, psi)
    C_obs = RNG.negative_binomial(reparam[0], reparam[1])
    return dict(cases=C_obs)
As you can see, in addition to the parameters, our simulator requires two further arguments: the total population size \(N\) and the time horizon \(T\). These are quantities over which we can amortize (i.e., context variables), but for this example, we will just use the population of Germany and the first two weeks of the pandemic (i.e., \(T=14\)), in the same vein as https://www.science.org/doi/10.1126/science.abb9789.
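The reparameterization in `convert_params` can be verified analytically. NumPy's negative binomial with parameters \((n, p)\) has mean \(n(1-p)/p\) and variance \(n(1-p)/p^2\), so plugging in the converted values should return the requested mean \(m\) and the variance \(m + m^2/\psi\). The sketch below restates the conversion under hypothetical names (`mean_dispersion_to_numpy`, `negative_binomial_moments`) and checks this for one example value.

```python
def mean_dispersion_to_numpy(mean, dispersion):
    """Same algebra as convert_params: returns (n, p) for numpy's negative_binomial."""
    variance = mean + mean**2 / dispersion
    p_success = mean / variance  # equals 1 - (variance - mean) / variance
    return dispersion, p_success


def negative_binomial_moments(n, p):
    """Closed-form mean and variance of NB(n, p) in numpy's parameterization."""
    mean = n * (1 - p) / p
    variance = mean / p
    return mean, variance


n, p = mean_dispersion_to_numpy(mean=100.0, dispersion=5.0)
nb_mean, nb_var = negative_binomial_moments(n, p)
print(nb_mean, nb_var)  # mean 100 and variance 100 + 100**2 / 5 = 2100, up to rounding
```

Smaller values of \(\psi\) therefore mean more overdispersion relative to a Poisson model, whose variance would equal the mean.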
4.2.3. Loading Real Data #
We will define a simple helper function to load the actually reported cases in 2020 for the first two weeks of the Covid-19 pandemic in Germany.
def load_data():
    """Helper function to load cumulative cases and transform them to new cases."""
    confirmed_cases_url = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv"
    confirmed_cases = pd.read_csv(confirmed_cases_url, sep=",")
    date_data_begin = datetime.date(2020, 3, 1)
    date_data_end = datetime.date(2020, 3, 15)
    # Column names in the CSV follow the month/day/two-digit-year pattern, e.g. "3/1/20"
    format_date = lambda date_py: f"{date_py.month}/{date_py.day}/{str(date_py.year)[2:4]}"
    date_formatted_begin = format_date(date_data_begin)
    date_formatted_end = format_date(date_data_end)
    cases_obs = np.array(
        confirmed_cases.loc[confirmed_cases["Country/Region"] == "Germany", date_formatted_begin:date_formatted_end]
    )[0]
    # 15 cumulative counts yield 14 daily new-case counts
    new_cases_obs = np.diff(cases_obs)
    return new_cases_obs
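The two transformations inside `load_data` can be tried without downloading anything. The sketch below uses a short, made-up cumulative series (the numbers are hypothetical, not real case counts) to show the date formatting and how `np.diff` turns \(k\) cumulative counts into \(k - 1\) daily counts.

```python
import datetime

import numpy as np


def format_date(date_py):
    """Formats a date as month/day/two-digit year, without leading zeros."""
    return f"{date_py.month}/{date_py.day}/{str(date_py.year)[2:4]}"


begin = datetime.date(2020, 3, 1)
end = datetime.date(2020, 3, 15)
num_columns = (end - begin).days + 1  # both endpoints are included in the slice

# Hypothetical cumulative counts, for illustration only
cumulative = np.array([10, 12, 17, 25, 40])
new_cases = np.diff(cumulative)

print(format_date(begin), format_date(end), num_columns)
print(new_cases)
```

With 15 selected columns, the differencing step leaves 14 values, which matches the time horizon \(T = 14\) of the simulator.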
4.2.4. Stitching Things Together #
We can combine the prior \(p(\theta)\) and the observation model \(p(x_{1:T}\mid\theta)\) into a joint model \(p(\theta, x_{1:T}) = p(\theta) \; p(x_{1:T}\mid\theta)\) using the make_simulator builder.
The resulting object can now generate batches of simulations.
simulator = bf.make_simulator([prior, stationary_SIR])
test_sims = simulator.sample(batch_size=2)
print(test_sims["lambd"].shape)
print(test_sims["D"].shape)
print(test_sims["cases"].shape)
(2, 1)
(2, 1)
(2, 14)
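Conceptually, sampling from the joint model means drawing parameters from the prior, passing them to the simulator as keyword arguments, and stacking the results along a batch axis. The sketch below imitates this with a toy prior and a toy simulator. The function `sample_joint` and the toy model are assumptions made for illustration and do not reproduce how `bf.make_simulator` is implemented.

```python
import numpy as np

rng = np.random.default_rng(1)


def toy_prior():
    """Toy prior with a single positive parameter."""
    return {"rate": rng.lognormal(mean=np.log(0.4), sigma=0.5)}


def toy_simulator(rate, T=14):
    """Toy observation model: Poisson counts with a linearly growing mean."""
    return {"counts": rng.poisson(lam=10.0 + rate * np.arange(T))}


def sample_joint(prior_fn, simulator_fn, batch_size):
    """Draws (theta, x) pairs and stacks them into arrays with a leading batch axis."""
    draws = []
    for _ in range(batch_size):
        theta = prior_fn()
        x = simulator_fn(**theta)
        draws.append({**theta, **x})
    return {key: np.stack([np.atleast_1d(d[key]) for d in draws]) for key in draws[0]}


batch = sample_joint(toy_prior, toy_simulator, batch_size=2)
print(batch["rate"].shape, batch["counts"].shape)
```

Scalars end up with shape `(batch_size, 1)` and time series with shape `(batch_size, T)`, mirroring the shapes printed above.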
4.3. Prior Checking #
Any principled Bayesian workflow requires some prior predictive or prior pushforward checks to ensure that the prior specification is consistent with domain expertise (see https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html). The BayesFlow library provides some rudimentary visual tools for performing prior checking. For instance, we can visually inspect the joint prior in the form of bivariate plots. We can focus on particular parameter combinations, such as \(\lambda\), \(\mu\), and \(D\):
prior_samples = simulator.simulators[0].sample(1000)
grid = bf.diagnostics.plots.pairs_samples(
    prior_samples, variable_keys=["lambd", "mu", "D"]
)
4.4. Defining the Adapter#
We need to ensure that the outputs of the forward model are suitable for processing with neural networks. Currently, they are not, since our data \(x_{1:T}\) consists of large integer (count) values, whereas neural networks train more reliably on inputs of moderate scale. Furthermore, our parameters \(\theta\) exhibit widely different scales due to their prior specification and role in the simulator. Finally, BayesFlow needs to know which variables are to be inferred and which ones are to be processed by the summary network before being passed to the inference network. We handle all of these steps using an Adapter.
Since all of our parameters and observables can only take on non-negative values, we will apply a log-plus-one transform, \(\log(1 + x)\), to all quantities. Note that BayesFlow expects the following keys to be present in the final outputs of your configured simulations:
inference_variables: These are the variables we are inferring.
summary_variables: These are the variables that are compressed through a summary network and used for inferring the inference variables.
Thus, what our approximators are learning is \(p(\text{inference variables} \mid t(\text{summary variables}))\), where \(t\) is the summary network.
adapter = (
    bf.adapters.Adapter()
    .convert_dtype("float64", "float32")
    .as_time_series("cases")
    .concatenate(["lambd", "mu", "D", "I0", "psi"], into="inference_variables")
    .rename("cases", "summary_variables")
    # since all our variables are non-negative (zero or larger), the next call transforms them
    # to the unconstrained real space and can be back-transformed under the hood
    .log(["inference_variables", "summary_variables"], p1=True)
)
adapter
Adapter([0: ConvertDType -> 1: AsTimeSeries -> 2: Concatenate(['lambd', 'mu', 'D', 'I0', 'psi'] -> 'inference_variables') -> 3: Rename('cases' -> 'summary_variables') -> 4: Log])
# Let's check out the new shapes
adapted_sims = adapter(simulator.sample(2))
print(adapted_sims["summary_variables"].shape)
print(adapted_sims["inference_variables"].shape)
(2, 14, 1)
(2, 5)
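To see what these steps do to the data, the sketch below reproduces the dtype conversion, the added trailing axis, and the log-plus-one transform with plain NumPy on two made-up count series. This is an assumption-laden imitation of the adapter's effect on `cases`, not BayesFlow's own code, and the numbers are hypothetical.

```python
import numpy as np

# Hypothetical raw simulator output: two series of T = 14 daily counts
cases = np.array(
    [
        [0, 1, 3, 8, 20, 45, 90, 160, 300, 520, 900, 1500, 2600, 4100],
        [2, 2, 5, 9, 14, 30, 61, 118, 240, 470, 880, 1700, 3300, 6400],
    ],
    dtype=np.float64,
)

# float64 -> float32, then add a trailing feature axis: (2, 14) -> (2, 14, 1)
summary_variables = cases.astype(np.float32)[..., None]

# log(1 + x) keeps zeros at zero and compresses large counts
summary_variables = np.log1p(summary_variables)

# The transform is invertible via exp(y) - 1
recovered = np.expm1(summary_variables[..., 0])

print(summary_variables.shape, float(summary_variables.max()))
```

Counts spanning several orders of magnitude are mapped to a range of roughly 0 to 9, and the inverse transform recovers the original values up to float32 precision.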
4.5. Defining the Neural Approximator #
We can now proceed to define our BayesFlow neural architecture, that is, combine a summary network with an inference network.
4.5.1. Summary Network #
Since our simulator outputs 3D tensors of shape (batch_size, T = 14, 1), we need to reduce this three-dimensional tensor into a two-dimensional tensor of shape (batch_size, summary_dim). Our model outputs are actually so simple that we could have just removed the trailing dimension of the raw outputs and simply fed the data directly to the inference network.
However, we demonstrate the use of a simple Gated Recurrent Unit (GRU) summary network. Any keras model can interact with BayesFlow by inheriting from SummaryNetwork, which accepts an additional stage argument indicating the mode the network is currently operating in (i.e., training vs. inference).
class GRU(bf.networks.SummaryNetwork):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.gru = keras.layers.GRU(64)
        self.summary_stats = keras.layers.Dense(8)

    def call(self, time_series, **kwargs):
        """Compresses time_series of shape (batch_size, T, 1) into summaries of shape (batch_size, 8)."""
        summary = self.gru(time_series, **kwargs)
        summary = self.summary_stats(summary)
        return summary


summary_net = GRU()
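For readers who want to see what happens to the shapes inside such a network, here is a minimal NumPy sketch of a GRU forward pass followed by a linear projection. It follows the standard GRU update equations with random, untrained weights and is an illustrative assumption, not keras' implementation (which differs in details such as bias handling and weight initialization).

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_summary(x, params, W_out, b_out):
    """Maps x of shape (batch_size, T, n_features) to (batch_size, summary_dim)."""
    batch_size, T, _ = x.shape
    hidden_dim = params["U_z"].shape[0]
    h = np.zeros((batch_size, hidden_dim))
    for t in range(T):
        x_t = x[:, t, :]
        z = sigmoid(x_t @ params["W_z"] + h @ params["U_z"] + params["b_z"])  # update gate
        r = sigmoid(x_t @ params["W_r"] + h @ params["U_r"] + params["b_r"])  # reset gate
        h_cand = np.tanh(x_t @ params["W_h"] + (r * h) @ params["U_h"] + params["b_h"])
        h = z * h + (1.0 - z) * h_cand  # blend old state and candidate state
    # Only the final hidden state is projected to the summary statistics
    return h @ W_out + b_out


rng = np.random.default_rng(0)
n_features, hidden_dim, summary_dim = 1, 64, 8
params = {}
for gate in ("z", "r", "h"):
    params[f"W_{gate}"] = rng.normal(scale=0.1, size=(n_features, hidden_dim))
    params[f"U_{gate}"] = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
    params[f"b_{gate}"] = np.zeros(hidden_dim)
W_out = rng.normal(scale=0.1, size=(hidden_dim, summary_dim))
b_out = np.zeros(summary_dim)

x = rng.normal(size=(2, 14, 1))  # stand-in for log-transformed case counts
summaries = gru_summary(x, params, W_out, b_out)
print(summaries.shape)
```

The time axis disappears because only the last hidden state is kept, which is how a sequence of any length is reduced to a fixed-size vector.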
4.5.2. Inference Network#
As a backbone inference network, we choose the all-time classic coupling flow (i.e., a type of normalizing flow).
inference_net = bf.networks.CouplingFlow(depth=2, transform="spline")
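The key property of a coupling flow is that each layer is invertible by construction, with a cheap log-determinant. The NumPy sketch below shows a single conditional affine coupling layer. Note that this swaps in the simpler affine transform instead of the spline transform selected above, uses a one-layer conditioner with random weights, and is an illustrative assumption rather than BayesFlow's implementation.

```python
import numpy as np


def conditioner(x1, condition, W, b):
    """Tiny stand-in network producing log-scale and shift for the second half."""
    h = np.tanh(np.concatenate([x1, condition], axis=-1) @ W + b)
    log_scale, shift = np.split(h, 2, axis=-1)
    return log_scale, shift


def forward(theta, condition, W, b, d1):
    """theta -> z. The first d1 dimensions pass through unchanged."""
    x1, x2 = theta[:, :d1], theta[:, d1:]
    log_scale, shift = conditioner(x1, condition, W, b)
    z2 = x2 * np.exp(log_scale) + shift
    log_det = log_scale.sum(axis=-1)  # log |det Jacobian| of the transform
    return np.concatenate([x1, z2], axis=-1), log_det


def inverse(z, condition, W, b, d1):
    """z -> theta. Uses the same conditioner, since z1 equals x1."""
    z1, z2 = z[:, :d1], z[:, d1:]
    log_scale, shift = conditioner(z1, condition, W, b)
    x2 = (z2 - shift) * np.exp(-log_scale)
    return np.concatenate([z1, x2], axis=-1)


rng = np.random.default_rng(0)
num_params, summary_dim, d1 = 5, 8, 2
theta = rng.normal(size=(4, num_params))       # stand-in for transformed parameters
condition = rng.normal(size=(4, summary_dim))  # stand-in for learned summaries
W = rng.normal(scale=0.1, size=(d1 + summary_dim, 2 * (num_params - d1)))
b = np.zeros(2 * (num_params - d1))

z, log_det = forward(theta, condition, W, b, d1)
theta_recovered = inverse(z, condition, W, b, d1)
print(np.allclose(theta, theta_recovered), log_det.shape)
```

Stacking several such layers while permuting which dimensions pass through unchanged yields a flexible yet invertible map, which is what makes both density evaluation (training) and sampling (inference) tractable.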
4.5.3. Workflow#
Inference with workflows is easy. Simply provide the simulator, adapter, and network objects, and have fun! If you want to save the networks automatically after training, provide a checkpoint_filepath and an optional checkpoint_name.
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=inference_net,
    summary_network=summary_net,
    standardize=None,  # no need to standardize due to log-transform
)
4.6. Training #
Ready to train! Since our simulator is pretty fast, we can safely go with online training. Let’s measure the time taken for a batch of \(32\) simulations.
%%time
_ = workflow.simulate(32)
CPU times: total: 0 ns
Wall time: 7.27 ms
Not too bad! However, for the purpose of illustration, we will go with offline training using a fixed data set of 6000 simulations (plus 300 simulations for validation). This may be considered a “low simulation budget” in many settings.
4.6.1. Generating Offline Data #
training_data = workflow.simulate(6000)
validation_data = workflow.simulate(300)
We are now ready to train. If not provided, the default settings use \(100\) epochs with a batch size of \(64\). Training this network takes about one minute.
history = workflow.fit_offline(
    data=training_data,
    epochs=100,
    batch_size=64,
    validation_data=validation_data,
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - loss: 8.4654 - val_loss: 1.5626
Epoch 2/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4086 - val_loss: -0.8353
Epoch 3/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -0.7898 - val_loss: -1.0914
Epoch 4/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -0.9806 - val_loss: -0.8046
Epoch 5/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -1.4445 - val_loss: -1.8668
Epoch 6/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -1.8533 - val_loss: -2.3139
Epoch 7/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -1.9075 - val_loss: -2.2328
Epoch 8/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -2.3124 - val_loss: -2.6770
Epoch 9/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -2.4624 - val_loss: -2.5321
Epoch 10/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -2.5059 - val_loss: -2.8379
Epoch 11/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -2.6303 - val_loss: -3.0210
Epoch 12/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -2.7221 - val_loss: -2.9149
Epoch 13/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -2.8668 - val_loss: -3.1688
Epoch 14/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -2.8628 - val_loss: -3.2402
Epoch 15/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -2.9847 - val_loss: -3.1865
Epoch 16/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -2.8777 - val_loss: -2.7351
Epoch 17/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.0363 - val_loss: -3.2453
Epoch 18/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.1559 - val_loss: -3.2152
Epoch 19/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.2761 - val_loss: -3.3717
Epoch 20/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -3.2941 - val_loss: -2.9347
Epoch 21/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.3390 - val_loss: -2.4774
Epoch 22/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.2548 - val_loss: -3.4285
Epoch 23/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.4629 - val_loss: -3.6532
Epoch 24/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.5553 - val_loss: -3.6732
Epoch 25/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.5633 - val_loss: -3.4503
Epoch 26/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.6465 - val_loss: -2.9345
Epoch 27/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.6869 - val_loss: -3.8239
Epoch 28/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.7057 - val_loss: -3.5314
Epoch 29/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.8235 - val_loss: -3.9725
Epoch 30/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.7690 - val_loss: -3.8369
Epoch 31/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -3.8363 - val_loss: -3.6539
Epoch 32/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.7696 - val_loss: -3.9771
Epoch 33/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.9930 - val_loss: -4.3570
Epoch 34/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.9420 - val_loss: -4.2805
Epoch 35/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -3.9715 - val_loss: -4.2154
Epoch 36/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.0096 - val_loss: -4.0003
Epoch 37/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.0512 - val_loss: -4.3531
Epoch 38/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.0466 - val_loss: -4.3966
Epoch 39/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.1284 - val_loss: -4.3681
Epoch 40/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.1455 - val_loss: -4.6147
Epoch 41/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.1957 - val_loss: -4.3947
Epoch 42/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.2301 - val_loss: -4.2612
Epoch 43/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.1711 - val_loss: -4.2295
Epoch 44/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.2400 - val_loss: -4.2500
Epoch 45/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.1973 - val_loss: -4.5687
Epoch 46/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.2754 - val_loss: -4.4925
Epoch 47/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.2942 - val_loss: -4.4212
Epoch 48/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.3051 - val_loss: -4.3167
Epoch 49/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.3382 - val_loss: -4.6363
Epoch 50/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.3795 - val_loss: -4.5311
Epoch 51/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.3426 - val_loss: -4.4652
Epoch 52/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.3575 - val_loss: -4.7164
Epoch 53/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.3901 - val_loss: -4.5578
Epoch 54/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.4426 - val_loss: -4.5307
Epoch 55/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.4596 - val_loss: -4.6627
Epoch 56/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.4224 - val_loss: -4.5603
Epoch 57/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.4945 - val_loss: -4.7749
Epoch 58/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.4606 - val_loss: -4.7459
Epoch 59/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: -4.5072 - val_loss: -4.7695
Epoch 60/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5077 - val_loss: -4.7474
Epoch 61/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5570 - val_loss: -4.7009
Epoch 62/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5535 - val_loss: -4.6684
Epoch 63/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5646 - val_loss: -4.5845
Epoch 64/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5566 - val_loss: -4.7818
Epoch 65/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5891 - val_loss: -4.7746
Epoch 66/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.5927 - val_loss: -4.7614
Epoch 67/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6112 - val_loss: -4.7501
Epoch 68/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5921 - val_loss: -4.3004
Epoch 69/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.6026 - val_loss: -4.6233
Epoch 70/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.6264 - val_loss: -4.6281
Epoch 71/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6367 - val_loss: -4.7888
Epoch 72/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.5993 - val_loss: -4.6472
Epoch 73/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.6190 - val_loss: -4.8114
Epoch 74/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6550 - val_loss: -4.8672
Epoch 75/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6514 - val_loss: -4.7688
Epoch 76/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6911 - val_loss: -4.7451
Epoch 77/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6688 - val_loss: -4.8056
Epoch 78/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6795 - val_loss: -4.8583
Epoch 79/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6882 - val_loss: -4.6955
Epoch 80/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6950 - val_loss: -4.7902
Epoch 81/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.6990 - val_loss: -4.8877
Epoch 82/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7153 - val_loss: -4.8404
Epoch 83/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7270 - val_loss: -4.7290
Epoch 84/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7503 - val_loss: -4.8032
Epoch 85/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7275 - val_loss: -4.8712
Epoch 86/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7371 - val_loss: -4.7946
Epoch 87/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7406 - val_loss: -4.8546
Epoch 88/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7504 - val_loss: -4.8674
Epoch 89/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7553 - val_loss: -4.8885
Epoch 90/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7628 - val_loss: -4.8377
Epoch 91/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7667 - val_loss: -4.8889
Epoch 92/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7712 - val_loss: -4.8980
Epoch 93/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7713 - val_loss: -4.8723
Epoch 94/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7798 - val_loss: -4.8633
Epoch 95/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7745 - val_loss: -4.8896
Epoch 96/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7865 - val_loss: -4.8828
Epoch 97/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7945 - val_loss: -4.8684
Epoch 98/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7919 - val_loss: -4.8732
Epoch 99/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: -4.7930 - val_loss: -4.8752
Epoch 100/100
94/94 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: -4.7911 - val_loss: -4.8767
INFO:bayesflow:Training completed in 1.15 minutes.
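The `94/94` shown in each progress bar is the number of gradient steps per epoch, which follows directly from the data set size and the batch size. The short calculation below reproduces it (the variable names are chosen for this sketch).

```python
import math

num_simulations = 6000
batch_size = 64
epochs = 100

# The last batch of an epoch is smaller, hence the ceiling
steps_per_epoch = math.ceil(num_simulations / batch_size)
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, total_steps)
```

In offline training, every epoch revisits the same 6000 simulations, so comparing the training and validation losses is the main safeguard against overfitting.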
4.6.2. Inspecting the Loss #
Following our offline simulation-based training, we can quickly visualize the loss trajectory using the plots.loss function from the diagnostics module.
f = bf.diagnostics.plots.loss(history)
Great, it seems that our approximator has converged! Before we get too excited and throw our networks at real data, we need to make sure that they meet our expectations in silico, that is, given the small world of simulations the networks have seen during training.
4.7. Validation Phase#
When it comes to validating posterior inference, we can either deploy manual diagnostics from the diagnostics module, or use the automated functions from the BasicWorkflow object. First, we demonstrate manual validation.
# Set the number of posterior draws you want to get
num_datasets = 300
num_samples = 1000
# Simulate 300 scenarios
test_sims = workflow.simulate(num_datasets)
# Obtain num_samples posterior samples per scenario
samples = workflow.sample(conditions=test_sims, num_samples=num_samples, batch_size=64)
INFO:bayesflow:Sampling completed in 8.60 seconds.
4.7.1. Simulation-Based Calibration - Rank Histograms#
As a further small world (i.e., before real data) sanity check, we can also test the calibration of the amortizer through simulation-based calibration (SBC). See the corresponding paper for more details (https://arxiv.org/pdf/1804.06788.pdf). Accordingly, we expect to observe approximately uniform rank statistic histograms. In the present case, this is indeed what we get:
f = bf.diagnostics.plots.calibration_histogram(samples, test_sims)
WARNING:bayesflow:The ratio of simulations / posterior draws should be > 20 for reliable variance reduction, but your ratio is 0. Confidence intervals might be unreliable!
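The rank statistic behind these histograms is simple to compute by hand: for each simulated data set, count how many posterior draws fall below the true parameter value. The sketch below does this for a toy conjugate normal model whose exact posterior is known, so the ranks should look uniform. The toy model and all variable names are assumptions for illustration, and the code is not BayesFlow's implementation.

```python
import numpy as np

rng = np.random.default_rng(42)
num_datasets, num_samples = 300, 1000

# Toy model (assumed for illustration): theta ~ N(0, 1), x | theta ~ N(theta, 1)
theta_true = rng.normal(0.0, 1.0, size=num_datasets)
x = rng.normal(theta_true, 1.0)

# Exact posterior of the toy model: theta | x ~ N(x / 2, 1 / 2)
posterior_draws = rng.normal(x[:, None] / 2.0, np.sqrt(0.5), size=(num_datasets, num_samples))

# Rank of each true value among its posterior draws, between 0 and num_samples
ranks = (posterior_draws < theta_true[:, None]).sum(axis=1)

# Histogram counts per bin should be roughly equal for a calibrated posterior
counts, _ = np.histogram(ranks, bins=10, range=(0, num_samples))
print(counts)
```

A posterior that is too narrow produces a U-shaped histogram, one that is too wide produces a hump in the middle, and a biased one piles ranks up at one end.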
4.7.2. Simulation-Based Calibration - Rank ECDF#
For models with many parameters, inspecting many histograms can become unwieldy. Moreover, the num_bins hyperparameter for the construction of SBC rank histograms can be hard to choose. An alternative diagnostic approach for calibration is through empirical cumulative distribution functions (ECDF) of rank statistics. You can read more about this approach in the corresponding paper (https://arxiv.org/abs/2103.10522).
In order to inspect the ECDFs of marginal distributions, we reuse the \(300\) pairs of simulated data and generating parameters \((\boldsymbol{x}, \boldsymbol{\theta})\) from above and call the function plots.calibration_ecdf from the diagnostics module:
f = bf.diagnostics.plots.calibration_ecdf(samples, test_sims, difference=True)
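With `difference=True`, the plot shows the ECDF of the fractional ranks minus the ECDF of a uniform distribution, so a calibrated posterior hovers around zero. The sketch below computes this difference for two tiny, hand-picked sets of fractional ranks. The function `ecdf_difference` is a hypothetical helper written for this illustration, and it omits the simultaneous confidence bands that the plot adds.

```python
import numpy as np


def ecdf_difference(fractional_ranks, grid):
    """ECDF of the fractional ranks evaluated on grid, minus the uniform CDF."""
    u = np.asarray(fractional_ranks, dtype=float)
    grid = np.asarray(grid, dtype=float)
    ecdf = (u[None, :] <= grid[:, None]).mean(axis=1)
    return ecdf - grid


grid = [0.25, 0.5, 0.75]

# Evenly spread ranks: the difference is zero everywhere on this grid
diff_calibrated = ecdf_difference([0.1, 0.4, 0.6, 0.9], grid)

# Ranks piled up near one: the ECDF lags behind the uniform CDF
diff_biased = ecdf_difference([0.6, 0.7, 0.8, 0.9], grid)

print(diff_calibrated, diff_biased)
```

Systematic departures from zero carry the same information as a skewed rank histogram, without requiring a choice of bin count.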
4.7.3. Inferential Adequacy (Global)#
Depending on the application, it might be interesting to see how well summaries of the full posterior (e.g., means, medians) recover the assumed true parameter values. We can test this in silico via the plots.recovery function in the diagnostics module. For instance, we can compare how well posterior means recover the true parameter (i.e., posterior z-score, https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html):
f = bf.diagnostics.plots.recovery(samples, test_sims)
Interestingly, it seems that the parameters \(\theta_1 = \mu\) and \(\theta_2 = D\) have not been learned properly, as they are estimated to be roughly the same for every simulated data set used during testing. For some models, this might indicate that the network training has partially failed, and we would have to train longer or adjust the network architecture. For this specific model, however, the reason is different: given the provided observables, these parameters are actually not identified and thus cannot be learned consistently, no matter which approximator we use.
4.7.4. Automatic Diagnostics#
The basic workflow object wraps together a bunch of useful functions that can be called automatically. For instance, we can easily obtain numerical error estimates for the big three: normalized root mean square error (NRMSE), posterior contraction, and calibration, for \(300\) new data sets:
metrics = workflow.compute_default_diagnostics(test_data=300)
metrics
| | lambd | mu | D | I0 | psi |
|---|---|---|---|---|---|
| NRMSE | 0.064986 | 0.196263 | 0.225614 | 0.218368 | 0.143819 |
| Log Gamma | 2.654485 | 2.377004 | 1.149616 | 1.868467 | 0.466460 |
| Calibration Error | 0.022982 | 0.009737 | 0.012105 | 0.014298 | 0.039561 |
| Posterior Contraction | 0.953850 | 0.000000 | 0.058371 | 0.391052 | 0.788854 |
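Posterior contraction is commonly defined as one minus the ratio of posterior variance to prior variance. The sketch below computes it for a toy conjugate normal model, once with the exact posterior and once with "posterior" draws that simply repeat the prior. The toy model and the helper `posterior_contraction` are assumptions for illustration, and BayesFlow's own metric may aggregate or clip values differently.

```python
import numpy as np

rng = np.random.default_rng(7)
num_datasets, num_samples = 20_000, 100

# Toy model (assumed): theta ~ N(0, 1), x | theta ~ N(theta, 1)
theta_true = rng.normal(0.0, 1.0, size=num_datasets)
x = rng.normal(theta_true, 1.0)

# Exact posterior: theta | x ~ N(x / 2, 1 / 2), so the variance is halved
informed_draws = rng.normal(x[:, None] / 2.0, np.sqrt(0.5), size=(num_datasets, num_samples))

# "Posterior" that ignores the data and just reproduces the prior
uninformed_draws = rng.normal(0.0, 1.0, size=(num_datasets, num_samples))


def posterior_contraction(posterior_draws, prior_draws):
    """1 - (average posterior variance) / (prior variance)."""
    posterior_var = posterior_draws.var(axis=1, ddof=1).mean()
    prior_var = prior_draws.var(ddof=1)
    return 1.0 - posterior_var / prior_var


contraction = posterior_contraction(informed_draws, theta_true)
no_contraction = posterior_contraction(uninformed_draws, theta_true)
print(round(contraction, 2), round(no_contraction, 2))
```

Read this way, the values near zero for `mu` and `D` in the table say that the data barely narrowed the prior for these parameters, which agrees with the recovery plot discussed earlier.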
We can also obtain the full set of graphical diagnostics. The method below lets you control nearly all display features (can take a while):
figures = workflow.plot_default_diagnostics(
    test_data=300,
    loss_kwargs={"figsize": (15, 3), "label_fontsize": 12},
    recovery_kwargs={"figsize": (15, 3), "label_fontsize": 12},
    calibration_ecdf_kwargs={"figsize": (15, 3), "legend_fontsize": 8, "label_fontsize": 12},
    coverage_kwargs={"figsize": (15, 3), "legend_fontsize": 8, "label_fontsize": 12},
    z_score_contraction_kwargs={"figsize": (15, 3), "label_fontsize": 12},
)
4.8. Inference Phase #
We can now move on to using real data. This is easy, and since we are using an adapter, the same transformations applied during training will be applied during the inference phase.
# Our real-data loader returns the time series as a 1D array
obs_cases = load_data()
# Note that we transform the 1D array into shape (1, T), indicating one time series
samples = workflow.sample(conditions={"cases": obs_cases[None, :]}, num_samples=num_samples)
# Convert into a nice format 2D data frame
samples = workflow.samples_to_data_frame(samples)
samples.head()
INFO:bayesflow:Sampling completed in 3.46 seconds.
| | lambd | mu | D | I0 | psi |
|---|---|---|---|---|---|
| 0 | 0.335902 | 0.105328 | 8.957308 | 36.724449 | 2.664177 |
| 1 | 0.384752 | 0.108980 | 8.115654 | 19.146898 | 8.231798 |
| 2 | 0.391883 | 0.106271 | 6.658990 | 26.861231 | 5.851600 |
| 3 | 0.353696 | 0.129373 | 9.651686 | 31.665882 | 4.636521 |
| 4 | 0.353486 | 0.099916 | 5.251906 | 55.420692 | 4.937239 |
4.8.1. Posterior Retrodictive Checks #
These are also called posterior predictive checks, but here we want to explicitly highlight the fact that we are not predicting future data but testing the generative performance or re-simulation performance of the model. In other words, we want to test how well the simulator can reproduce the actually observed data given the parameter posterior \(p(\theta \mid t(x_{1:T}))\), where \(t\) is the summary network.
Here, we will create a custom function which plots the observed data and then overlays draws from the posterior predictive.
def plot_ppc(samples, obs_cases, logscale=True, color="#132a70", figsize=(12, 6), font_size=18):
    """
    Helper function to perform some plotting of the posterior predictive.
    """
    # Plot settings
    plt.rcParams["font.size"] = font_size
    f, ax = plt.subplots(1, 1, figsize=figsize)
    T = len(obs_cases)
    # Re-simulations
    sims = []
    for i in range(samples.shape[0]):
        # Note - the columns of samples must be ordered as (lambd, mu, D, I0, psi) to match
        # the simulator signature; the simulator returns a 1D array of shape (T,)
        sim_cases = stationary_SIR(*samples.values[i])
        sims.append(sim_cases["cases"])
    sims = np.array(sims)
    # Compute quantiles for each t = 1,...,T
    qs_50 = np.quantile(sims, q=[0.25, 0.75], axis=0)
    qs_90 = np.quantile(sims, q=[0.05, 0.95], axis=0)
    qs_95 = np.quantile(sims, q=[0.025, 0.975], axis=0)
    # Plot median predictions and observed data
    ax.plot(np.median(sims, axis=0), label="Median predicted cases", color=color)
    ax.plot(obs_cases, marker="o", label="Reported cases", color="black", linestyle="dashed", alpha=0.8)
    # Add compatibility intervals (also called credible intervals)
    ax.fill_between(range(T), qs_50[0], qs_50[1], color=color, alpha=0.5, label="50% CI")
    ax.fill_between(range(T), qs_90[0], qs_90[1], color=color, alpha=0.3, label="90% CI")
    ax.fill_between(range(T), qs_95[0], qs_95[1], color=color, alpha=0.1, label="95% CI")
    # Grid and styling
    ax.grid(color="grey", linestyle="-", linewidth=0.25, alpha=0.5)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlabel("Days since pandemic onset")
    ax.set_ylabel("Number of cases")
    ax.minorticks_off()
    if logscale:
        ax.set_yscale("log")
    ax.legend(fontsize=font_size)
    return f
We can now go on and plot the re-simulations (i.e., perform a posterior “predictive” check on the fitted data):
f = plot_ppc(samples, obs_cases)
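Besides looking at the plot, one can also quantify how many observed points fall inside a given predictive band. The sketch below computes pointwise 90% bands from a matrix of re-simulations and reports the fraction of observations they cover. The Poisson stand-in data and all variable names are assumptions for illustration and are unrelated to the actual posterior draws.

```python
import numpy as np

rng = np.random.default_rng(3)
T = 14

# Stand-in for re-simulated series of shape (num_draws, T) and one observed series
expected = np.linspace(20, 200, T)
sims = rng.poisson(lam=expected, size=(500, T))
observed = rng.poisson(lam=expected)

# Pointwise 90% band across re-simulations, one pair of bounds per day
lower, upper = np.quantile(sims, q=[0.05, 0.95], axis=0)

# Fraction of observed days falling inside the band
inside = (observed >= lower) & (observed <= upper)
coverage = inside.mean()

print(lower.shape, upper.shape, coverage)
```

A coverage far below the nominal level would point to a misspecified simulator or an overconfident posterior, whereas this number alone cannot distinguish between the two.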
That’s it for this tutorial! You now know how to use the basic building blocks of BayesFlow to create amortized neural approximators. :)