7. Rapid Iteration with Point Estimation - Lotka-Volterra Dynamics#

Author: Hans Olischläger

In this notebook, we will infer parameters of a famous ecology differential equation with BayesFlow.

We will follow a typical workflow that emphasizes rapid iteration early on, before building up towards reliable estimates of the full posterior with end-to-end data embedding.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from scipy.integrate import odeint

import keras

import bayesflow as bf
# avoid scientific notation for outputs
np.set_printoptions(suppress=True)

7.1. Ecology simulator#

Say we measured population counts from two species over time. One of them preys on the other, so we might assume that the dynamics are governed by the classic Lotka-Volterra system.

In dimensionless form, with prey population \(x\) and predator population \(y\), the nonlinear differential equation is

\[\begin{split} \begin{aligned}{\frac {dx}{dt}}&=\alpha x-\beta xy,\\{\frac {dy}{dt}}&=-\gamma y+\delta xy.\end{aligned} \end{split}\]

As always, this model entails a number of assumptions that can only be approximate. In brief: on their own, prey counts increase exponentially with rate \(\alpha\), while predator counts decay with rate \(\gamma\). Interesting dynamics are possible when both predators and prey are present: the predator population grows the more prey it can hunt, reducing prey counts proportionally at a rate \(\beta\) and increasing predator counts proportionally at a rate \(\delta\).

We can measure population timeseries, but never the parameters directly, so this is a scientifically relevant inverse problem.

The Lotka-Volterra equations alone are not yet a concrete, testable hypothesis, since they do not on their own predict anything measurable. We must pick parameters, initial conditions, and an observation model that describes how measurements take place. Note: the wide applicability of simulation-based inference is due to the fact that scientific hypotheses typically come in the form of simulators of measurable quantities.
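As a quick sanity check of the single-species behavior described above, here is a minimal sketch (with hypothetical rates, purely for illustration, not part of the workflow below): with no interaction, the equations decouple and have closed-form exponential solutions.

# sketch: with beta = delta = 0 the equations decouple, so
# x(t) = x0 * exp(alpha * t) and y(t) = y0 * exp(-gamma * t)
t_check = np.linspace(0, 5, 50)
x0, y0, alpha_check, gamma_check = 1.0, 1.0, 0.5, 0.5  # hypothetical rates

decoupled = odeint(
    lambda state, t: [alpha_check * state[0], -gamma_check * state[1]],
    [x0, y0],
    t_check,
)
assert np.allclose(decoupled[:, 0], x0 * np.exp(alpha_check * t_check), rtol=1e-4)
assert np.allclose(decoupled[:, 1], y0 * np.exp(-gamma_check * t_check), rtol=1e-4)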

Our simulator will consist of three parts:

  1. First, we choose a prior distribution over parameters that reflects our beliefs before observing any data.

  2. Building on parameters sampled from the prior, we solve the parameterized Lotka-Volterra equation in time starting from some initial conditions.

  3. And finally, we hypothesize that we will make some counting errors when observing the populations, introducing a Gaussian error on the true populations.

A random number generator with a fixed seed will ensure reproducibility of the simulated training and validation data.

rng = np.random.default_rng(seed=1234)
def prior():
    x = rng.normal(size=4)
    theta = 1/(1+np.exp(-x)) * 3.9 + 0.1  # logit-normal distribution scaled to range from 0.1 to 4
    return dict(
        alpha=theta[0],
        beta=theta[1],
        gamma=theta[2],
        delta=theta[3],
    )

def lotka_volterra_equations(state, t, alpha, beta, gamma, delta):
    x, y = state
    dxdt = alpha * x - beta * x * y
    dydt = - gamma * y + delta * x * y
    return [dxdt, dydt]

def ecology_model(alpha, beta, gamma, delta, t_span=[0, 5], t_steps=100, initial_state=[1, 1]):
    t = np.linspace(t_span[0], t_span[1], t_steps)
    state = odeint(lotka_volterra_equations, initial_state, t, args=(alpha, beta, gamma, delta))
    x, y = state.T  # Transpose to get x and y arrays
    
    return dict(
        x=x,  # Prey time series
        y=y,  # Predator time series
        t=t,  # time
    )

def observation_model(x, y, t, observation_subsampling=10, observation_probability=1, observation_noise=0.1):
    t_steps = x.shape[0]
    
    # observation noise
    observed_x = rng.normal(loc=x, scale=observation_noise)
    observed_y = rng.normal(loc=y, scale=observation_noise)
    observed_t = np.copy(t)

    # keep every `observation_subsampling`-th time step; if observation_probability < 1,
    # only a random subset of these observation times is kept (missing counts)
    candidate_indices = np.arange(0, t_steps, observation_subsampling)
    random_indices = rng.choice(
        candidate_indices,
        int(observation_probability * t_steps // observation_subsampling),
        replace=False,
    )
    random_indices = np.sort(random_indices)  # rng.choice scrambles the order of observation indices
    
    return dict(
        observed_x=observed_x[random_indices],  # Prey time series
        observed_y=observed_y[random_indices],  # Predator time series
        observed_t=observed_t[random_indices],
    )

We can combine these three components into a BayesFlow simulator via:

simulator = bf.make_simulator([prior, ecology_model, observation_model])

Let’s sample 1000 trajectories, and see what we get:

num_trajectories = 1000
samples = simulator.sample(num_trajectories)
keras.tree.map_structure(keras.ops.shape, samples)
{'alpha': (1000, 1),
 'beta': (1000, 1),
 'gamma': (1000, 1),
 'delta': (1000, 1),
 'x': (1000, 100),
 'y': (1000, 100),
 't': (1000, 100),
 'observed_x': (1000, 10),
 'observed_y': (1000, 10),
 'observed_t': (1000, 10)}

What kinds of dynamics (and observations) does our Lotka-Volterra simulator predict? Let us write a function to visualize sampled trajectories and take a look!

def trajectory_aggregation(traj, confidence=0.95):
    alpha = 1 - confidence
    quantiles = np.quantile(traj, [alpha/2, 0.5, 1-alpha/2], axis=0).T
    central = quantiles[:,1]
    L = quantiles[:,0]
    U = quantiles[:,2]
    return central, L, U

def plot_trajectories(samples, variable_keys, variable_names, fill_colors=["blue", "darkred"], confidence=0.95, alpha=0.8, observations=None, ax=None):
    t_span = samples["t"][0]
    
    if ax is None:
        fig, ax = plt.subplots(1, figsize=(12,3))
        sns.despine()
    
    for i, key in enumerate(variable_keys):

        if observations is not None:     
            ax.scatter(observations["observed_t"], observations["observed_"+key], color=fill_colors[i], marker="x", label="Observed " + variable_names[i].lower())

        central, L, U = trajectory_aggregation(samples[key], confidence=confidence)
        ax.plot(t_span, central, color=fill_colors[i], label="Median " + variable_names[i].lower())
        ax.fill_between(t_span, L, U, color=fill_colors[i], alpha=0.2, label=rf"{int((confidence) * 100)}$\%$ Confidence Bands")

        # plot 20 trajectory samples
        for j in range(20):
            if j == 0:
                label = f"{variable_names[i]} trajectories"
            else:
                label = None
            ax.plot(t_span, samples[key][j], color=fill_colors[i], alpha=0.2, label=label)
        

    ax.legend()
    ax.set_xlabel("t")
    ax.set_ylabel("population")

plot_trajectories(samples, ["x", "y"], ["Prey", "Predator"])
[Figure: prior predictive trajectories of prey and predator populations, with medians, 95% confidence bands, and example trajectories]

Above we see the prior predictive distribution of the simulator. The shaded area contains 95% of trajectories at each time step; additionally, we see a few example trajectories.

Predator and prey populations generally oscillate in this model, but the frequency, amplitude, relative lag and scale vary greatly across parameters.

The prior predictive distribution should match our expectation of the real-world system of interest before we take into account concrete observed population counts. Here, we see that the prior implies population magnitudes that oscillate (mostly) below 6.

7.2. Rapid inference#

The first goal will be to get a fast but crude approximation of the true posteriors for different observations. Two ingredients will allow us to move fast towards parameter inference:

  1. basic hand-crafted summary statistics

  2. point estimation

This will help us diagnose challenges with the simulator and establish a baseline for the final goal: full posterior inference.

7.2.1. Basic hand-crafted summary statistics#

Ultimately, we want to learn maximally informative summary statistics jointly with an amortized posterior approximation, but hand-crafted summary statistics have the benefit of being interpretable and fast to compute. Oftentimes, there are a few natural and established statistics for a particular modality of raw data. Researchers in the field are likely to have made significant progress in finding closed-form expressions or algorithms for informative summaries.

Compared to theoretically optimal summary statistics, we can expect less posterior contraction.

Still, we can reasonably expect that the oscillation period, mean, (log) variance, autocorrelation at different lags of both trajectories, and the cross-correlation between the two trajectories are highly informative when taken together as summary statistics.

import scipy

def period(observed_x, t_span=[0, 5], t_steps=500):
    """
    Computes the dominant period of observed_x from a periodogram.

    Note: the sampling frequency passed to the periodogram is derived from
    t_span and t_steps, so these should match how observed_x was sampled.
    """
    f, Pxx = scipy.signal.periodogram(observed_x, t_steps/(t_span[1]-t_span[0]))
    freq_dominant = f[np.argmax(Pxx)]
    T = 1 / freq_dominant
    return T


def autocorr(trajectory, lags):
    """
    Computes the autocorrelation for each specified lag in a trajectory.
    
    Parameters
    ----------
    trajectory : np.ndarray
        The time series data, assumed to be a 1D array.
    lags : np.ndarray or list
        The lags at which to compute the autocorrelation.
    
    Returns
    -------
    auto_correlation : np.ndarray
        Autocorrelation values at each specified lag.
    """
    # Calculate the mean and variance of the trajectory for normalization
    mean = np.mean(trajectory)
    var = np.var(trajectory)
    
    # Initialize an array to hold the autocorrelation values
    auto_correlation = np.zeros(len(lags))
    
    # Compute autocorrelation for each lag
    for i, lag in enumerate(lags):
        if lag == 0:
            # Autocorrelation at lag 0 is always 1
            auto_correlation[i] = 1
        elif lag >= len(trajectory):
            # If the lag is equal to or greater than the length of the trajectory, autocorrelation is undefined (set to 0)
            auto_correlation[i] = 0
        else:
            # Compute covariance and then autocorrelation
            cov = np.mean((trajectory[:-lag] - mean) * (trajectory[lag:] - mean))
            auto_correlation[i] = cov / var

    # warn if any autocorrelation is NaN (can happen for (near-)constant trajectories with var == 0)
    if np.any(np.isnan(auto_correlation)):
        print(auto_correlation)
            
    return auto_correlation

def crosscorr(x, y):
    """
    Computes the cross-correlation (Pearson correlation coefficient) between two trajectories at zero lag.

    Parameters
    ----------
    x : np.ndarray
        The first time series data, assumed to be a 1D array of length n.
    y : np.ndarray
        The second time series data, assumed to be a 1D array of length n.

    Returns
    -------
    float
        The cross-correlation coefficient.
    """
    # Compute the mean and standard deviation of both time series
    mean_x = np.mean(x)
    mean_y = np.mean(y)
    std_x = np.std(x)
    std_y = np.std(y)

    # Compute the covariance and the correlation coefficient
    covariance = np.mean((x - mean_x) * (y - mean_y))
    correlation = covariance / (std_x * std_y)

    return correlation

def expert_stats(observed_x, observed_y, lags=[2,5]):
    """Computes fixed size statistics for an observed population trajectory

    Parameters
    ----------
    observed_x : np.ndarray with shape (num_observations, )
    observed_y : np.ndarray with shape (num_observations, )

    Returns
    -------
    dictionary with the following keys and values
    means      : np.ndarray with shape (2,)
    log_vars   : np.ndarray with shape (2,)
    auto_corrs : np.ndarray with shape (2*num_lags,)
        auto-correlation of each time series at the specified lags (in units of observation time steps)
    cross_corr : np.ndarray with shape (1,)
        the cross-correlation between the two time series
    period     : np.ndarray with shape (1,)
    """
    means = np.array([observed_x.mean(), observed_y.mean()])
    log_vars = np.log(np.array([observed_x.var(), observed_y.var()]))
    auto_corrs = np.array([
        autocorr(observed_x,lags),
        autocorr(observed_y,lags),
    ]).flatten()
    cross_corr = crosscorr(observed_x, observed_y)
    T = period(observed_x)
    
    return dict(
        means=means,
        log_vars=log_vars,
        auto_corrs=auto_corrs,
        cross_corr=cross_corr,
        period=T,
    )
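As a quick illustration (not part of the workflow), the period function should recover the true period of a pure sine wave sampled at its default resolution:

# check: a sine with period 2.5, sampled at the default resolution (500 steps over [0, 5])
t_test = np.linspace(0, 5, 500)
x_test = np.sin(2 * np.pi * t_test / 2.5)
print(period(x_test))  # approximately 2.5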

To compute the expert statistics we can append the expert_stats function to the simulator object.

simulator = bf.make_simulator([prior, ecology_model, observation_model, expert_stats])
samples_with_expert_stats = simulator.sample(3)
keras.tree.map_structure(keras.ops.shape, samples_with_expert_stats)
{'alpha': (3, 1),
 'beta': (3, 1),
 'gamma': (3, 1),
 'delta': (3, 1),
 'x': (3, 100),
 'y': (3, 100),
 't': (3, 100),
 'observed_x': (3, 10),
 'observed_y': (3, 10),
 'observed_t': (3, 10),
 'means': (3, 2),
 'log_vars': (3, 2),
 'auto_corrs': (3, 4),
 'cross_corr': (3, 1),
 'period': (3, 1)}

7.2.2. Point estimation#

Ultimately, we want to infer the full posterior distribution, but it can be much faster to infer point estimates of it, which already allows us to diagnose whether inference is (or can be) successful for a particular simulator. Thus, in the spirit of rapid iteration, we will first target the posterior mean and a few quantiles.

BayesFlow provides a convenient interface for point estimation. Here is a brief explanation of the principle:

Each point estimator is obtained by minimizing the Bayes risk for a particular loss function. Depending on the loss function, the resulting estimator will faithfully estimate a different functional of the full posterior distribution.

Typically, we refer to such loss functions as scores or scoring rules for a particular probabilistic forecast, since they score forecasts of a distribution \(p(\theta|x)\) based on samples \(\theta \sim p(\theta|x)\) of that distribution. If the true forecast (uniquely) optimizes the score, such losses are called (strictly) proper scoring rules.

  • Here is a scoring rule whose expectation is minimized exactly when the estimate, \(\hat \theta\), is the true mean of the posterior:

    \[L(\hat \theta, \theta) = | \theta - \hat \theta |^2\]

    It is the well-known squared error loss!

  • Similarly, since the median minimizes the expected absolute distance to \(\theta \sim p(\theta|x)\), we know that the corresponding loss is optimized by the true median of the posterior:

    \[L(\hat \theta, \theta) = | \theta - \hat \theta |\]
  • To estimate quantiles, the following is a strictly proper scoring rule:

    \[L(\hat \theta, \theta; \tau) = (\hat \theta - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau)\]

    Here we write an indicator function \(\mathbf{1}_{\hat \theta - \theta > 0}\) that evaluates to 1 for overestimation (positive \(\hat \theta - \theta\)) and to \(0\) otherwise.

    For \(\tau=\frac 1 2\), over- and underestimating a true posterior sample \(\theta\) are weighted equally. In fact, the quantile loss with \(\tau=\frac 1 2\) is identical to the median loss (up to a scaling of \(\frac 1 2\)). For the same reason, both estimate the median of the posterior.

    More generally, \(\tau \in (0,1)\) is the quantile level, that is, the point at which to evaluate the quantile function. (A short numerical check of this score follows below.)

  • Note that when approximating the full distribution in BayesFlow, we score a probability estimate \(\hat p(\theta|x)\) with the negative log-score,

    \[L(\hat p(\cdot|x), \theta) = -\log \hat p(\theta|x),\]

    which corresponds to the strictly proper logarithmic scoring rule.

  • What if you want to estimate something else? There might just be a loss function that corresponds to the estimator of exactly the quantity you are after.

    The class of functions that leads to faithful estimators is called strictly proper scoring rules. A good reference for the theory and examples is the following paper.

      Gneiting, T., & Raftery, A. E. (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. Journal of the American Statistical Association, 102(477), 359–378. https://doi.org/10.1198/016214506000001437
    

If you can find a proper scoring rule for the quantity you want to estimate, implement it as a negatively oriented loss function, inherit from the abstract ScoringRule class, and you will be able to use it within BayesFlow.
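As a quick numerical check of the quantile score from above (a standalone sketch, not using BayesFlow), we can verify that its expectation over samples from a known distribution is minimized near that distribution's \(\tau\)-quantile:

# standalone check: the expected quantile score is minimized near the tau-quantile
check_rng = np.random.default_rng(0)  # separate generator, to leave `rng` untouched
tau = 0.8
theta_samples = check_rng.normal(size=100_000)  # stand-in for posterior samples

def quantile_score(theta_hat, theta, tau):
    return (theta_hat - theta) * ((theta_hat - theta > 0).astype(float) - tau)

candidates = np.linspace(-3, 3, 601)
expected_scores = [quantile_score(c, theta_samples, tau).mean() for c in candidates]
print(candidates[np.argmin(expected_scores)], scipy.stats.norm.ppf(tau))  # both ~0.84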

adapter = (
    bf.adapters.Adapter()
    
    # convert any non-arrays to numpy arrays
    .to_array()
    
    # convert from numpy's default float64 to deep learning friendly float32
    .convert_dtype("float64", "float32")

    # drop unobserved full trajectories and raw observations
    .drop(["x", "y", "t", "observed_x", "observed_y", "observed_t"])
    
    # standardize all remaining variables to zero mean and unit variance
    .standardize()
    
    # rename the variables to match the required approximator inputs
    .concatenate(["alpha", "beta", "gamma", "delta"], into="inference_variables")
    .concatenate(["means", "log_vars", "auto_corrs", "cross_corr", "period"], into="inference_conditions")

)
adapter
Adapter([0: ToArray -> 1: ConvertDType -> 2: Drop(['x', 'y', 't', 'observed_x', 'observed_y', 'observed_t']) -> 3: Standardize -> 4: Concatenate(['alpha', 'beta', 'gamma', 'delta'] -> 'inference_variables') -> 5: Concatenate(['means', 'log_vars', 'auto_corrs', 'cross_corr', 'period'] -> 'inference_conditions')])
num_training_batches = 512
num_validation_batches = 128
batch_size = 64
epochs = 10
num_training_batches * batch_size
32768
%%time
training_data = simulator.sample(num_training_batches * batch_size,)
validation_data = simulator.sample(num_validation_batches * batch_size,)
CPU times: user 32.1 s, sys: 336 ms, total: 32.4 s
Wall time: 32.3 s

PointInferenceNetworks are defined by the ScoringRules they use to approximate certain point estimates. Passing a dictionary of such ScoringRules constructs a corresponding feed-forward model.

q_levels = np.linspace(0.1,0.9,5)

point_inference_network = bf.networks.PointInferenceNetwork(
    scores=dict(
        mean=bf.scores.MeanScore(),
        quantiles=bf.scores.QuantileScore(q_levels),
    ),
)

point_inference_workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=point_inference_network,
)
%%time
history = point_inference_workflow.fit_offline(
    training_data,
    epochs=epochs, 
    batch_size=batch_size, 
    validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - loss: 0.1912 - loss/inference_loss: 0.1912 - mean/inference_mean: 0.2326 - quantiles/inference_quantiles: 0.1497 - val_loss: 0.1242 - val_loss/inference_loss: 0.1242 - val_mean/inference_mean: 0.1398 - val_quantiles/inference_quantiles: 0.1087
Epoch 2/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1387 - loss/inference_loss: 0.1387 - mean/inference_mean: 0.1659 - quantiles/inference_quantiles: 0.1115 - val_loss: 0.1337 - val_loss/inference_loss: 0.1337 - val_mean/inference_mean: 0.1546 - val_quantiles/inference_quantiles: 0.1127
Epoch 3/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1313 - loss/inference_loss: 0.1313 - mean/inference_mean: 0.1562 - quantiles/inference_quantiles: 0.1065 - val_loss: 0.1371 - val_loss/inference_loss: 0.1371 - val_mean/inference_mean: 0.1680 - val_quantiles/inference_quantiles: 0.1062
Epoch 4/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.1267 - loss/inference_loss: 0.1267 - mean/inference_mean: 0.1501 - quantiles/inference_quantiles: 0.1034 - val_loss: 0.1605 - val_loss/inference_loss: 0.1605 - val_mean/inference_mean: 0.2026 - val_quantiles/inference_quantiles: 0.1184
Epoch 5/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1238 - loss/inference_loss: 0.1238 - mean/inference_mean: 0.1459 - quantiles/inference_quantiles: 0.1016 - val_loss: 0.1139 - val_loss/inference_loss: 0.1139 - val_mean/inference_mean: 0.1323 - val_quantiles/inference_quantiles: 0.0955
Epoch 6/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1212 - loss/inference_loss: 0.1212 - mean/inference_mean: 0.1423 - quantiles/inference_quantiles: 0.1002 - val_loss: 0.1012 - val_loss/inference_loss: 0.1012 - val_mean/inference_mean: 0.1103 - val_quantiles/inference_quantiles: 0.0922
Epoch 7/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1193 - loss/inference_loss: 0.1193 - mean/inference_mean: 0.1396 - quantiles/inference_quantiles: 0.0990 - val_loss: 0.1072 - val_loss/inference_loss: 0.1072 - val_mean/inference_mean: 0.1234 - val_quantiles/inference_quantiles: 0.0909
Epoch 8/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1177 - loss/inference_loss: 0.1177 - mean/inference_mean: 0.1373 - quantiles/inference_quantiles: 0.0981 - val_loss: 0.1230 - val_loss/inference_loss: 0.1230 - val_mean/inference_mean: 0.1430 - val_quantiles/inference_quantiles: 0.1029
Epoch 9/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1167 - loss/inference_loss: 0.1167 - mean/inference_mean: 0.1358 - quantiles/inference_quantiles: 0.0976 - val_loss: 0.1103 - val_loss/inference_loss: 0.1103 - val_mean/inference_mean: 0.1251 - val_quantiles/inference_quantiles: 0.0954
Epoch 10/10
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1161 - loss/inference_loss: 0.1161 - mean/inference_mean: 0.1349 - quantiles/inference_quantiles: 0.0973 - val_loss: 0.1229 - val_loss/inference_loss: 0.1229 - val_mean/inference_mean: 0.1458 - val_quantiles/inference_quantiles: 0.1000
CPU times: user 29.8 s, sys: 2.66 s, total: 32.4 s
Wall time: 13.8 s
f = bf.diagnostics.loss(history)
[Figure: training and validation loss over 10 epochs]

Training is completed after a few seconds!

Just for fun and because we can, let us save the trained point approximator to disk.

checkpoint_path = "model.keras"
keras.saving.save_model(point_inference_workflow.approximator, checkpoint_path)

Now we load the approximator again from disk and use it for inference and diagnosis below.

loaded = keras.saving.load_model(checkpoint_path)
point_inference_workflow.approximator = loaded

7.2.2.1. Inference#

The computational cost we have paid for training up front is amortized by cheap inference on simulated or measured observations. This means we can rapidly evaluate posteriors for different observations not seen in training, which allows for comprehensive diagnosis of posterior quality.

So far, so general. But point estimators in particular give a speed advantage not only in training, but also for diagnostics: since one point estimate already summarizes many posterior samples, we only need a single forward pass with a point inference network, whereas a generative, full posterior approximator would require ~100 passes.

# Simulate validation data
val_sims = simulator.sample(500)

# estimate posteriors for all conditions 
estimates_point = point_inference_workflow.approximator.estimate(conditions=val_sims)

# `approximator.estimate()` returned a nested dictionary of point estimates for each named parameter,
# see the structure and shape below
keras.tree.map_structure(keras.ops.shape, estimates_point)
{'alpha': {'mean': (500, 1), 'quantiles': (500, 5, 1)},
 'beta': {'mean': (500, 1), 'quantiles': (500, 5, 1)},
 'gamma': {'mean': (500, 1), 'quantiles': (500, 5, 1)},
 'delta': {'mean': (500, 1), 'quantiles': (500, 5, 1)}}

7.2.2.2. Recovery and calibration diagnostics for point estimates#

Diagnosing problems with point estimation works similarly to diagnosing full posterior approximation. For example, you can check how point estimates relate to ground-truth values with a recovery plot. The recovery plot can display many different point estimates; just define in a dictionary which point estimate is displayed with what kind of matplotlib marker.

marker_mapping = dict(quantiles="_", mean="*")

Above we defined horizontal bars to indicate quantile estimates and a star to indicate the estimated mean. Point estimates for the same condition are connected with a line.

We can provide pretty parameter names to the plotting functions, so we define them once here:

par_names = [r"$\alpha$", r"$\beta$", r"$\gamma$", r"$\delta$"]
f = bf.diagnostics.plots.recovery_from_estimates(
    estimates=estimates_point,
    targets=val_sims,
    marker_mapping=marker_mapping,
    s=50,  # size of markers as in matplotlib.scatter
    variable_names=par_names,
)
[Figure: recovery plot of point estimates (means and quantiles) against ground-truth values for α, β, γ, δ]

We can and should also perform simulation-based calibration checks on the estimated quantiles.

bf.diagnostics.plots.calibration_ecdf_from_quantiles(
    estimates=estimates_point, 
    targets=val_sims,
    quantile_levels=q_levels,
    difference=True,
    variable_names=par_names,
)
plt.show()
[Figure: ECDF-difference calibration plot computed from the quantile estimates]

Neither the recovery nor the calibration diagnostic indicates any problems with the point inferences. Let us go one step further in validation by checking the posterior predictive distribution.

7.2.2.3. Posterior predictive check from quantile estimates#

To sample from the posterior, we need to assume some concrete probability distribution. We will choose a diagonal multivariate normal distribution that we construct to be consistent with the quantile estimates.

More concretely, we calculate a mean and standard deviation for every parameter based on its outermost quantile estimates, that is, the quantile levels 0.1 and 0.9.

We start by extracting the lower and upper bound from the quantile posterior approximation:

post_bounds_from_quantiles = keras.tree.map_structure(
    lambda v: v[:, [0, -1]],
    {k: v["quantiles"] for k, v in estimates_point.items()},
)

To translate these estimates into a corresponding mean and standard deviation, we first consider the standard normal distribution. The distribution of interest is a translated and scaled version of it, and since this is a linear transformation, we can calculate interpolation weights \(\alpha\) and \(\beta\) on the standard normal distribution and obtain the mean and standard deviation of the normal distribution of interest. (These interpolation weights are not to be confused with the Lotka-Volterra parameters of the same name.)

If \(X\) follows a standard normal distribution with the known cumulative distribution function \(F_X(x)\), the quantile for quantile level \(\tau_i\) is \(\tilde q_i = F_X^{-1}(\tau_{i})\), and we can compute it for both quantile levels corresponding to the bounds above.

# translate two quantile levels (first and last) to quantiles on the standard normal (mean=0, std=1)
stdnormal_q = scipy.stats.norm.ppf(q_levels[[0,-1]])
stdnormal_q
array([-1.28155157,  1.28155157])

In relation to \(\tilde q_1\) and \(\tilde q_2\), where are \(x=0\) and \(x=1\)? These two points correspond to the location (mean) and the location plus scale (mean plus one standard deviation) of the standard normal.

So we solve the equations

\[\begin{split} \begin{aligned} 0 &= \tilde q_1 (1-\alpha) + \tilde q_2 \alpha,\\ 1 &= \tilde q_1 (1-\beta) + \tilde q_2 \beta, \end{aligned} \end{split}\]

for \(\alpha\) and \(\beta\) and obtain

\[\begin{split} \begin{aligned} \alpha &= \frac {\tilde q_1} {\tilde q_1 - \tilde q_2},\\ \beta &= \frac {\tilde q_1 - 1} {\tilde q_1 - \tilde q_2}. \end{aligned} \end{split}\]
# calculate interpolation value for mean and standard deviation
alpha = stdnormal_q[0] / (stdnormal_q[0] - stdnormal_q[1])       # interpolation value for q=0 (mean = 0 for standard normal)
beta = (stdnormal_q[0] - 1) / (stdnormal_q[0] - stdnormal_q[1])  # interpolation value for q=1 (mean+std = 1 for standard normal)

Since the standard normal and the normal distribution of interest are connected by a linear transformation, we can use the interpolation weights \(\alpha\) and \(\beta\) to obtain the mean and standard deviation consistent with the selected quantile estimates as

\[\begin{split} \begin{aligned} \mu &= \hat q_1 (1-\alpha) + \hat q_2 \alpha,\\ \sigma &= \hat q_1 (\alpha-\beta) + \hat q_2 (\beta-\alpha). \end{aligned} \end{split}\]
# interpolate between values to get scaled normal parameters
post_means_from_quantiles = keras.tree.map_structure(lambda v: v[:,0] * (1-alpha) + v[:,1] * alpha, post_bounds_from_quantiles)
post_stds_from_quantiles = keras.tree.map_structure(lambda v: v[:,0] * (alpha-beta) + v[:,1] * (beta-alpha), post_bounds_from_quantiles)
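Because the levels 0.1 and 0.9 are symmetric around 0.5, these formulas reduce to simple closed forms, which gives us a quick sanity check (a sketch with hypothetical quantile estimates):

# sanity check: for symmetric levels, mu = (q1 + q2) / 2 and
# sigma = (q2 - q1) / (2 * z), where z = stdnormal_q[1]
q1_hat, q2_hat = 0.7, 2.3  # hypothetical quantile estimates at levels 0.1 and 0.9

mu_check = q1_hat * (1 - alpha) + q2_hat * alpha
sigma_check = q1_hat * (alpha - beta) + q2_hat * (beta - alpha)

assert np.isclose(mu_check, (q1_hat + q2_hat) / 2)
assert np.isclose(sigma_check, (q2_hat - q1_hat) / (2 * stdnormal_q[1]))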

And finally we can sample from this normal distribution too.

num_samples = 1000

# sample from normal distribution consistent with quantile estimates
post_draws_from_quantiles = keras.tree.map_structure(lambda v: rng.normal(
    loc=v[:,0] * (1-alpha) + v[:,1] * alpha, 
    scale=v[:,0] * (alpha-beta) + v[:,1] * (beta-alpha),
    size=(500, num_samples))[..., None], post_bounds_from_quantiles
)

Let us take a look at a particular posterior. We could use any simulated or observed dataset now. For convenience, the BayesFlow diagnostic plots applicable to a single dataset generally support passing a dataset_id to select one from the simulator output.

dataset_id = 0
g = bf.diagnostics.plots.pairs_posterior(
    estimates=post_draws_from_quantiles,
    targets=val_sims,
    dataset_id=dataset_id,
    variable_names=par_names,
)
def plot_boxes(g, boxes, dataset_id, color="blue"):
    for i,(key, box) in enumerate(boxes.items()):
        for j in range(4):
            g.axes[j,i].axvline(box[dataset_id,0,0], color=color, linestyle=":")
            g.axes[j,i].axvline(box[dataset_id,1,0], color=color, linestyle=":")
            if i != j:
                g.axes[i,j].axhline(box[dataset_id,0,0], color=color, linestyle=":")
                g.axes[i,j].axhline(box[dataset_id,1,0], color=color, linestyle=":")

plot_boxes(g, post_bounds_from_quantiles, dataset_id)
g.fig.suptitle("Posterior diagonal normal approximation", y=1.01);
[Figure: pair plot of the diagonal normal posterior approximation, with dotted quantile bounds]

The dotted lines above are the estimated quantiles for the levels 0.1 and 0.9, and we see that the quantile-based normal distribution generates consistent samples. Next, let us look at what the trajectories corresponding to parameters from this posterior look like.

def offline_posterior_sampler(post_draws, dataset_id, sample_id):
    # extract a single posterior draw (one parameter set) for the selected dataset
    posterior_sample_for_id = {var_key: post_draws[var_key][dataset_id, sample_id, ...].squeeze() for var_key in post_draws.keys()}
    return posterior_sample_for_id

def take_dataset(sims, dataset_id):
    # extract a single dataset from a batch of simulations
    return {var_key: sims[var_key][dataset_id] for var_key in sims.keys()}
list_of_resimulations = []
for sample_id in range(num_samples):
    one_post_sample = offline_posterior_sampler(post_draws_from_quantiles, dataset_id, sample_id)
    list_of_resimulations.append(ecology_model(t_span=[0,20], **one_post_sample))
resimulation_samples = bf.utils.tree_stack(list_of_resimulations, axis=0)

observations = take_dataset(val_sims, dataset_id)

plot_trajectories(resimulation_samples, ["x", "y"], ["Prey", "Predator"], observations=observations)
plt.xlim(0,5)
plt.title("Trajectories from posterior predictive distribution (diagonal normal approximation)");
[Figure: posterior predictive trajectories overlaid on the observed counts (diagonal normal approximation)]
plot_trajectories(resimulation_samples, ["x", "y"], ["Prey", "Predator"], observations=observations)
plt.title("Posterior predictive forecast (diagonal normal approximation)")
plt.axvline(5, color="grey", linestyle=":");
[Figure: posterior predictive forecast beyond the observation window (diagonal normal approximation)]

The trajectories appear to fit the observations well. Compare this to the prior predictive distribution from above: the predictive distribution now only contains trajectories with reasonable period, lag and scale. In this sense, we have already succeeded in updating our knowledge about possible Lotka-Volterra parameters that fit the data.

If any issues were visible in the posterior diagnostics, we could now go back, change the simulator to better match real-world observations, add relevant expert statistics, or try simple learnt statistics. Then we train and diagnose again, and repeat until the point estimates seem trustworthy.

Bear in mind, however, that while these approximations allowed us to iterate fast, they also come at a cost. By neglecting multimodality and correlations, the approximate posterior is likely to be undercontracted (overdispersed). The next sections will remove those approximations step by step. Because we already know what to expect from the model, we can move confidently towards more complicated and powerful posterior approximation methods.

7.3. Full posterior approximation#

Flow matching is a powerful class of generative neural networks. Let's try and see if we can use it as a drop-in replacement for the PointInferenceNetwork we used previously.

flow_matching = bf.networks.FlowMatching()
flow_matching_workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=flow_matching,
)

Yes, we can!

We already know how to fit such a workflow. Flow matching performs well if you train it for a while. This takes a bit of time, but we will be rewarded by a tighter posterior approximation.

epochs = 50
%%time
history = flow_matching_workflow.fit_offline(
    training_data, 
    epochs=epochs, 
    batch_size=batch_size, 
    validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 4ms/step - loss: 1.5124 - loss/inference_loss: 1.5124 - val_loss: 0.6980 - val_loss/inference_loss: 0.6980
Epoch 2/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.7534 - loss/inference_loss: 0.7534 - val_loss: 0.5995 - val_loss/inference_loss: 0.5995
Epoch 3/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.6676 - loss/inference_loss: 0.6676 - val_loss: 0.5080 - val_loss/inference_loss: 0.5080
Epoch 4/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.6240 - loss/inference_loss: 0.6240 - val_loss: 0.5995 - val_loss/inference_loss: 0.5995
Epoch 5/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5956 - loss/inference_loss: 0.5956 - val_loss: 0.5689 - val_loss/inference_loss: 0.5689
Epoch 6/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5793 - loss/inference_loss: 0.5793 - val_loss: 0.5531 - val_loss/inference_loss: 0.5531
Epoch 7/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5607 - loss/inference_loss: 0.5607 - val_loss: 0.4960 - val_loss/inference_loss: 0.4960
Epoch 8/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5426 - loss/inference_loss: 0.5426 - val_loss: 0.4603 - val_loss/inference_loss: 0.4603
Epoch 9/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5250 - loss/inference_loss: 0.5250 - val_loss: 0.3366 - val_loss/inference_loss: 0.3366
Epoch 10/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5262 - loss/inference_loss: 0.5262 - val_loss: 0.5738 - val_loss/inference_loss: 0.5738
Epoch 11/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5136 - loss/inference_loss: 0.5136 - val_loss: 0.4002 - val_loss/inference_loss: 0.4002
Epoch 12/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.5051 - loss/inference_loss: 0.5051 - val_loss: 0.4794 - val_loss/inference_loss: 0.4794
Epoch 13/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4940 - loss/inference_loss: 0.4940 - val_loss: 0.3145 - val_loss/inference_loss: 0.3145
Epoch 14/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4937 - loss/inference_loss: 0.4937 - val_loss: 0.4962 - val_loss/inference_loss: 0.4962
Epoch 15/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4850 - loss/inference_loss: 0.4850 - val_loss: 0.3971 - val_loss/inference_loss: 0.3971
Epoch 16/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4830 - loss/inference_loss: 0.4830 - val_loss: 0.5649 - val_loss/inference_loss: 0.5649
Epoch 17/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4813 - loss/inference_loss: 0.4813 - val_loss: 0.3687 - val_loss/inference_loss: 0.3687
Epoch 18/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4741 - loss/inference_loss: 0.4741 - val_loss: 0.4990 - val_loss/inference_loss: 0.4990
Epoch 19/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4682 - loss/inference_loss: 0.4682 - val_loss: 0.4729 - val_loss/inference_loss: 0.4729
Epoch 20/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4674 - loss/inference_loss: 0.4674 - val_loss: 0.3830 - val_loss/inference_loss: 0.3830
Epoch 21/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4646 - loss/inference_loss: 0.4646 - val_loss: 0.4645 - val_loss/inference_loss: 0.4645
Epoch 22/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4581 - loss/inference_loss: 0.4581 - val_loss: 0.3550 - val_loss/inference_loss: 0.3550
Epoch 23/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4565 - loss/inference_loss: 0.4565 - val_loss: 0.3952 - val_loss/inference_loss: 0.3952
Epoch 24/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4476 - loss/inference_loss: 0.4476 - val_loss: 0.4066 - val_loss/inference_loss: 0.4066
Epoch 25/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4506 - loss/inference_loss: 0.4506 - val_loss: 0.3373 - val_loss/inference_loss: 0.3373
Epoch 26/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4508 - loss/inference_loss: 0.4508 - val_loss: 0.4216 - val_loss/inference_loss: 0.4216
Epoch 27/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4459 - loss/inference_loss: 0.4459 - val_loss: 0.4444 - val_loss/inference_loss: 0.4444
Epoch 28/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4420 - loss/inference_loss: 0.4420 - val_loss: 0.4631 - val_loss/inference_loss: 0.4631
Epoch 29/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4389 - loss/inference_loss: 0.4389 - val_loss: 0.3162 - val_loss/inference_loss: 0.3162
Epoch 30/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4400 - loss/inference_loss: 0.4400 - val_loss: 0.4388 - val_loss/inference_loss: 0.4388
Epoch 31/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4393 - loss/inference_loss: 0.4393 - val_loss: 0.5072 - val_loss/inference_loss: 0.5072
Epoch 32/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4321 - loss/inference_loss: 0.4321 - val_loss: 0.4542 - val_loss/inference_loss: 0.4542
Epoch 33/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4361 - loss/inference_loss: 0.4361 - val_loss: 0.5522 - val_loss/inference_loss: 0.5522
Epoch 34/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4313 - loss/inference_loss: 0.4313 - val_loss: 0.5288 - val_loss/inference_loss: 0.5288
Epoch 35/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4285 - loss/inference_loss: 0.4285 - val_loss: 0.3008 - val_loss/inference_loss: 0.3008
Epoch 36/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4277 - loss/inference_loss: 0.4277 - val_loss: 0.3166 - val_loss/inference_loss: 0.3166
Epoch 37/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4253 - loss/inference_loss: 0.4253 - val_loss: 0.3985 - val_loss/inference_loss: 0.3985
Epoch 38/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4233 - loss/inference_loss: 0.4233 - val_loss: 0.3292 - val_loss/inference_loss: 0.3292
Epoch 39/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4196 - loss/inference_loss: 0.4196 - val_loss: 0.3764 - val_loss/inference_loss: 0.3764
Epoch 40/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4198 - loss/inference_loss: 0.4198 - val_loss: 0.3052 - val_loss/inference_loss: 0.3052
Epoch 41/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4230 - loss/inference_loss: 0.4230 - val_loss: 0.3368 - val_loss/inference_loss: 0.3368
Epoch 42/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4173 - loss/inference_loss: 0.4173 - val_loss: 0.3620 - val_loss/inference_loss: 0.3620
Epoch 43/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4180 - loss/inference_loss: 0.4180 - val_loss: 0.3729 - val_loss/inference_loss: 0.3729
Epoch 44/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4191 - loss/inference_loss: 0.4191 - val_loss: 0.3472 - val_loss/inference_loss: 0.3472
Epoch 45/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4177 - loss/inference_loss: 0.4177 - val_loss: 0.2900 - val_loss/inference_loss: 0.2900
Epoch 46/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4224 - loss/inference_loss: 0.4224 - val_loss: 0.3688 - val_loss/inference_loss: 0.3688
Epoch 47/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4206 - loss/inference_loss: 0.4206 - val_loss: 0.4123 - val_loss/inference_loss: 0.4123
Epoch 48/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4196 - loss/inference_loss: 0.4196 - val_loss: 0.4254 - val_loss/inference_loss: 0.4254
Epoch 49/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4142 - loss/inference_loss: 0.4142 - val_loss: 0.3533 - val_loss/inference_loss: 0.3533
Epoch 50/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.4234 - loss/inference_loss: 0.4234 - val_loss: 0.3729 - val_loss/inference_loss: 0.3729
CPU times: user 5min 11s, sys: 29.3 s, total: 5min 40s
Wall time: 1min 43s
f = bf.diagnostics.loss(history)
[Figure: flow matching training and validation loss over 50 epochs]

Sampling from the flow matching approximator takes much longer than estimating with the point approximator. To save time, we restrict the number of inference conditions:

val_sims = keras.tree.map_structure(lambda v: v[:100], val_sims)
%%time
# Set the number of posterior draws you want to get
num_samples = 100

# Obtain posterior draws with the sample method
post_draws = flow_matching_workflow.sample(conditions=val_sims, num_samples=num_samples)

# post_draws is a dictionary of draws with one element per named parameter
post_draws.keys()
CPU times: user 6min 33s, sys: 38.1 s, total: 7min 11s
Wall time: 55.8 s
dict_keys(['alpha', 'beta', 'gamma', 'delta'])


bf.diagnostics.plots.recovery(
    estimates=post_draws, 
    targets=val_sims,
    variable_names=par_names
)
plt.show()
[Figure: recovery plot for the flow matching posterior draws]
bf.diagnostics.plots.calibration_ecdf(
    estimates=post_draws, 
    targets=val_sims,
    variable_names=par_names,
    difference=True,
    rank_type="distance"
)
plt.show()
[Figure: ECDF-difference calibration plot for the flow matching posterior draws]
g = bf.diagnostics.plots.pairs_posterior(
    estimates=post_draws, 
    targets=val_sims,
    dataset_id=dataset_id,
    variable_names=par_names,
)
plot_boxes(g, post_bounds_from_quantiles, dataset_id)
[Figure: pair plot of the flow matching posterior, with dotted quantile bounds]

Compared to the earlier approximate posterior draws, we have uncovered a strong correlation between parameters. Take a look at the marginals on the diagonal: the dotted quantile estimates still pass a visual consistency check.

list_of_resimulations = []
for sample_id in range(num_samples):
    one_post_sample = offline_posterior_sampler(post_draws, dataset_id, sample_id)
    list_of_resimulations.append(ecology_model(t_span=[0,20], **one_post_sample))
resimulation_samples = bf.utils.tree_stack(list_of_resimulations, axis=0)

observations = take_dataset(val_sims, dataset_id)

plot_trajectories(resimulation_samples, ["x", "y"], ["Prey", "Predator"], observations=observations)
plt.xlim(0,5)
plt.title("Trajectories from posterior predictive distribution");
[Figure: posterior predictive trajectories overlaid on the observed counts (flow matching)]
plot_trajectories(resimulation_samples, ["x", "y"], ["Prey", "Predator"], observations=observations)
plt.title("Posterior predictive forecast")
plt.axvline(5, color="grey", linestyle=":");
[Figure: posterior predictive forecast beyond the observation window (flow matching)]

Estimating the correlation of posterior samples has constrained the posterior predictive forecast uncertainty considerably!

bf.diagnostics.plots.z_score_contraction(
    estimates=post_draws, 
    targets=val_sims,
    variable_names=par_names
)
plt.show()
[Figure: posterior z-score versus posterior contraction for α, β, γ, δ]

7.4. End-to-end learning of summary statistics#
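So far, the inference networks were conditioned on hand-crafted statistics. We now feed the raw observed time series into a summary network instead, so that summary statistics are learnt end-to-end, jointly with the posterior approximation.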

summary_network = bf.networks.LSTNet()  # bf.networks.TimeSeriesTransformer() is slower, with similar performance on this task
learnt_sumstat_adapter = (
    bf.adapters.Adapter()
    
    # convert any non-arrays to numpy arrays
    .to_array()
    
    # convert from numpy's default float64 to deep learning friendly float32
    .convert_dtype("float64", "float32")

    # drop unobserved full trajectories
    .drop(["x", "y", "t"])

    # drop expert statistics
    .drop(["means", "log_vars", "auto_corrs", "cross_corr", "period"])
    
    # standardize target variables to zero mean and unit variance 
    .standardize(exclude=["observed_x", "observed_y", "observed_t"])
    .as_time_series(["observed_x", "observed_y", "observed_t"])
    .standardize(include=["observed_x", "observed_y", "observed_t"], axis=(0,1)) # make sure to standardize whole timeseries
    
    # rename the variables to match the required approximator inputs
    .concatenate(["alpha", "beta", "gamma", "delta"], into="inference_variables")
    .concatenate(["observed_x", "observed_y", "observed_t"], into="summary_variables")
    #.concatenate(["means", "log_vars", "auto_corrs", "cross_corr", "period"], into="inference_conditions")

)
learnt_sumstat_adapter
Adapter([0: ToArray -> 1: ConvertDType -> 2: Drop(['x', 'y', 't']) -> 3: Drop(['means', 'log_vars', 'auto_corrs', 'cross_corr', 'period']) -> 4: Standardize(exclude=['observed_x', 'observed_y', 'observed_t']) -> 5: AsTimeSeries -> 6: Standardize(include=['observed_x', 'observed_y', 'observed_t']) -> 7: Concatenate(['alpha', 'beta', 'gamma', 'delta'] -> 'inference_variables') -> 8: Concatenate(['observed_x', 'observed_y', 'observed_t'] -> 'summary_variables')])
learnt_sumstat_workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=learnt_sumstat_adapter,
    summary_network=summary_network,
    inference_network=bf.networks.FlowMatching(),
)
%%time
history = learnt_sumstat_workflow.fit_offline(
    training_data, 
    epochs=epochs, 
    batch_size=batch_size, 
    validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 13s 13ms/step - loss: 1.1151 - loss/inference_loss: 1.1151 - val_loss: 0.6890 - val_loss/inference_loss: 0.6890
Epoch 2/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.6887 - loss/inference_loss: 0.6887 - val_loss: 0.5208 - val_loss/inference_loss: 0.5208
Epoch 3/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.5839 - loss/inference_loss: 0.5839 - val_loss: 0.4417 - val_loss/inference_loss: 0.4417
Epoch 4/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.5384 - loss/inference_loss: 0.5384 - val_loss: 0.5802 - val_loss/inference_loss: 0.5802
Epoch 5/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.5033 - loss/inference_loss: 0.5033 - val_loss: 0.3872 - val_loss/inference_loss: 0.3872
Epoch 6/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.4834 - loss/inference_loss: 0.4834 - val_loss: 0.4194 - val_loss/inference_loss: 0.4194
Epoch 7/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.4619 - loss/inference_loss: 0.4619 - val_loss: 0.3810 - val_loss/inference_loss: 0.3810
Epoch 8/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4601 - loss/inference_loss: 0.4601 - val_loss: 0.4967 - val_loss/inference_loss: 0.4967
Epoch 9/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4466 - loss/inference_loss: 0.4466 - val_loss: 0.5709 - val_loss/inference_loss: 0.5709
Epoch 10/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4404 - loss/inference_loss: 0.4404 - val_loss: 0.5583 - val_loss/inference_loss: 0.5583
Epoch 11/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4312 - loss/inference_loss: 0.4312 - val_loss: 0.3990 - val_loss/inference_loss: 0.3990
Epoch 12/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4273 - loss/inference_loss: 0.4273 - val_loss: 0.2578 - val_loss/inference_loss: 0.2578
Epoch 13/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.4154 - loss/inference_loss: 0.4154 - val_loss: 0.4038 - val_loss/inference_loss: 0.4038
Epoch 14/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.4135 - loss/inference_loss: 0.4135 - val_loss: 0.4274 - val_loss/inference_loss: 0.4274
Epoch 15/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4022 - loss/inference_loss: 0.4022 - val_loss: 0.3588 - val_loss/inference_loss: 0.3588
Epoch 16/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.4002 - loss/inference_loss: 0.4002 - val_loss: 0.5089 - val_loss/inference_loss: 0.5089
Epoch 17/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 12ms/step - loss: 0.4008 - loss/inference_loss: 0.4008 - val_loss: 0.3254 - val_loss/inference_loss: 0.3254
Epoch 18/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3949 - loss/inference_loss: 0.3949 - val_loss: 0.4541 - val_loss/inference_loss: 0.4541
Epoch 19/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3878 - loss/inference_loss: 0.3878 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627
Epoch 20/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3889 - loss/inference_loss: 0.3889 - val_loss: 0.4500 - val_loss/inference_loss: 0.4500
Epoch 21/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3863 - loss/inference_loss: 0.3863 - val_loss: 0.5011 - val_loss/inference_loss: 0.5011
Epoch 22/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3817 - loss/inference_loss: 0.3817 - val_loss: 0.3433 - val_loss/inference_loss: 0.3433
Epoch 23/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3784 - loss/inference_loss: 0.3784 - val_loss: 0.2991 - val_loss/inference_loss: 0.2991
Epoch 24/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3672 - loss/inference_loss: 0.3672 - val_loss: 0.2965 - val_loss/inference_loss: 0.2965
Epoch 25/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3663 - loss/inference_loss: 0.3663 - val_loss: 0.3939 - val_loss/inference_loss: 0.3939
Epoch 26/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3598 - loss/inference_loss: 0.3598 - val_loss: 0.4299 - val_loss/inference_loss: 0.4299
Epoch 27/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3635 - loss/inference_loss: 0.3635 - val_loss: 0.3062 - val_loss/inference_loss: 0.3062
Epoch 28/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3632 - loss/inference_loss: 0.3632 - val_loss: 0.2916 - val_loss/inference_loss: 0.2916
Epoch 29/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3559 - loss/inference_loss: 0.3559 - val_loss: 0.2309 - val_loss/inference_loss: 0.2309
Epoch 30/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3499 - loss/inference_loss: 0.3499 - val_loss: 0.4021 - val_loss/inference_loss: 0.4021
Epoch 31/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3513 - loss/inference_loss: 0.3513 - val_loss: 0.3448 - val_loss/inference_loss: 0.3448
Epoch 32/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3504 - loss/inference_loss: 0.3504 - val_loss: 0.2802 - val_loss/inference_loss: 0.2802
Epoch 33/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3514 - loss/inference_loss: 0.3514 - val_loss: 0.4074 - val_loss/inference_loss: 0.4074
Epoch 34/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3450 - loss/inference_loss: 0.3450 - val_loss: 0.3862 - val_loss/inference_loss: 0.3862
Epoch 35/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 12ms/step - loss: 0.3423 - loss/inference_loss: 0.3423 - val_loss: 0.3064 - val_loss/inference_loss: 0.3064
Epoch 36/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3395 - loss/inference_loss: 0.3395 - val_loss: 0.3225 - val_loss/inference_loss: 0.3225
Epoch 37/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3423 - loss/inference_loss: 0.3423 - val_loss: 0.2851 - val_loss/inference_loss: 0.2851
Epoch 38/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 12ms/step - loss: 0.3413 - loss/inference_loss: 0.3413 - val_loss: 0.3146 - val_loss/inference_loss: 0.3146
Epoch 39/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3345 - loss/inference_loss: 0.3345 - val_loss: 0.4345 - val_loss/inference_loss: 0.4345
Epoch 40/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3387 - loss/inference_loss: 0.3387 - val_loss: 0.2540 - val_loss/inference_loss: 0.2540
Epoch 41/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3339 - loss/inference_loss: 0.3339 - val_loss: 0.3592 - val_loss/inference_loss: 0.3592
Epoch 42/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3328 - loss/inference_loss: 0.3328 - val_loss: 0.3567 - val_loss/inference_loss: 0.3567
Epoch 43/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3334 - loss/inference_loss: 0.3334 - val_loss: 0.2760 - val_loss/inference_loss: 0.2760
Epoch 44/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3279 - loss/inference_loss: 0.3279 - val_loss: 0.3351 - val_loss/inference_loss: 0.3351
Epoch 45/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3300 - loss/inference_loss: 0.3300 - val_loss: 0.3574 - val_loss/inference_loss: 0.3574
Epoch 46/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.3291 - loss/inference_loss: 0.3291 - val_loss: 0.3628 - val_loss/inference_loss: 0.3628
Epoch 47/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.3267 - loss/inference_loss: 0.3267 - val_loss: 0.4426 - val_loss/inference_loss: 0.4426
Epoch 48/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 5s 11ms/step - loss: 0.3303 - loss/inference_loss: 0.3303 - val_loss: 0.1882 - val_loss/inference_loss: 0.1882
Epoch 49/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3295 - loss/inference_loss: 0.3295 - val_loss: 0.2819 - val_loss/inference_loss: 0.2819
Epoch 50/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step - loss: 0.3313 - loss/inference_loss: 0.3313 - val_loss: 0.3821 - val_loss/inference_loss: 0.3821
CPU times: user 23min 35s, sys: 2min 10s, total: 25min 45s
Wall time: 4min 50s
f = bf.diagnostics.loss(history)
[Figure: training and validation loss with end-to-end summary learning]

Note that the loss is lower, since we are learning summary statistics simultaneously. How does this translate to visual diagnostics? We can check them again by sampling the posteriors of validation simulations not seen in training.

%%time
# Set the number of posterior draws you want to get
num_samples = 100

# Obtain posterior draws with the sample method
post_draws = learnt_sumstat_workflow.sample(conditions=val_sims, num_samples=num_samples)

# post_draws is a dictionary of draws with one element per named parameter
post_draws.keys()
CPU times: user 5min 41s, sys: 47.6 s, total: 6min 28s
Wall time: 53.4 s
dict_keys(['alpha', 'beta', 'gamma', 'delta'])
bf.diagnostics.plots.recovery(
    estimates=post_draws, 
    targets=val_sims,
    variable_names=par_names
)
plt.show()
[Figure: recovery plot with learnt summary statistics]
bf.diagnostics.plots.calibration_ecdf(
    estimates=post_draws, 
    targets=val_sims,
    variable_names=par_names,
    difference=True,
    rank_type="distance"
)
plt.show()
[Figure: ECDF-difference calibration plot with learnt summary statistics]
g = bf.diagnostics.plots.pairs_posterior(
    estimates=post_draws, 
    targets=val_sims,
    dataset_id=dataset_id,
    variable_names=par_names,
)
plot_boxes(g, post_bounds_from_quantiles, dataset_id)
[Figure: pair plot of the posterior with learnt summary statistics, with dotted quantile bounds]

Since the conditions have changed now that we learn summaries of the observations simultaneously with fitting the inference network, it is not surprising that the posteriors appear shifted. You can compare how the new posterior samples relate to the dotted quantile estimates.

Neither expert-crafted nor jointly learnt statistics are guaranteed to be highly informative. However, to reach the global minimum of the training loss, the learnt statistics need to be maximally informative. If architecture, training data and optimizer are well chosen, learnt summary statistics regularly outperform hand-crafted ones.

list_of_resimulations = []
for sample_id in range(num_samples):
    one_post_sample = offline_posterior_sampler(post_draws, dataset_id, sample_id)
    list_of_resimulations.append(ecology_model(t_span=[0,20], **one_post_sample))
resimulation_samples = bf.utils.tree_stack(list_of_resimulations, axis=0)

observations = take_dataset(val_sims, dataset_id)

plot_trajectories(resimulation_samples, ["x", "y"], ["Prey", "Predator"], observations=observations)
plt.xlim(0,5)
plt.title("Trajectories from posterior predictive distribution");
[Figure: posterior predictive trajectories overlaid on the observed counts (learnt statistics)]
plot_trajectories(resimulation_samples, ["x", "y"], ["Prey", "Predator"], observations=observations)
plt.title("Posterior predictive forecast")
plt.axvline(5, color="grey", linestyle=":");
[Figure: posterior predictive forecast beyond the observation window (learnt statistics)]
bf.diagnostics.plots.z_score_contraction(
    estimates=post_draws, 
    targets=val_sims,
    variable_names=par_names
)
plt.show()
[Figure: posterior z-score versus posterior contraction with learnt summary statistics]