2. Two Moons: Tackling Bimodal Posteriors#
Authors: Lars Kühmichel, Marvin Schmitt, Valentin Pratz, Stefan T. Radev
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import bayesflow as bf
2.1. Simulator#
This example will demonstrate amortized estimation of a somewhat strange Bayesian model, whose posterior evaluated at the origin \(x = (0, 0)\) of the “data” will resemble two crescent moons. The forward process is a noisy non-linear transformation on a 2D plane:

\[
\begin{aligned}
x_1 &= -\frac{|\theta_1 + \theta_2|}{\sqrt{2}} + r \cos\alpha + 0.25, \\
x_2 &= \frac{-\theta_1 + \theta_2}{\sqrt{2}} + r \sin\alpha,
\end{aligned}
\]

with \(x = (x_1, x_2)\) playing the role of “observables” (data to be learned from), \(\alpha \sim \text{Uniform}(-\pi/2, \pi/2)\) and \(r \sim \text{Normal}(0.1, 0.01)\) being latent variables that create noise in the data, and \(\theta = (\theta_1, \theta_2)\) being the parameters that we will later seek to infer from new \(x\). We set their priors to

\[
\theta_1, \theta_2 \sim \text{Uniform}(-1, 1).
\]
This model is typically used for benchmarking simulation-based inference (SBI) methods (see https://arxiv.org/pdf/2101.04653), and any method for amortized Bayesian inference should be capable of recovering the two moons posterior without using a gazillion simulations. Note that this is a considerably harder task than modeling the common unconditional two moons data set often used in the context of normalizing flows.
BayesFlow offers many ways to define your data generating process. Here, we use sequential functions to build a simulator object for online training. Within this composite simulator, each function has access to the outputs of the previous functions. This effectively allows you to define any generative graph.
def theta_prior():
    # Draw the two model parameters from their Uniform(-1, 1) priors
    theta = np.random.uniform(-1, 1, 2)
    return dict(theta=theta)

def forward_model(theta):
    # Latent variables that create noise in the data
    alpha = np.random.uniform(-np.pi / 2, np.pi / 2)
    r = np.random.normal(0.1, 0.01)
    # Noisy non-linear transformation of the parameters (see the equations above)
    x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25
    x2 = (-theta[0] + theta[1]) / np.sqrt(2) + r * np.sin(alpha)
    return dict(x=np.array([x1, x2]))
Within the composite simulator, every simulator has access to the outputs of the previous simulators in the list. For example, the last simulator, `forward_model`, has access to the output `theta` of `theta_prior`.
simulator = bf.make_simulator([theta_prior, forward_model])
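To illustrate, a hypothetical third function appended to the list could consume both earlier outputs simply by naming them as arguments (a sketch; `distance` and its output `d` are made up for illustration and not part of the model):

def distance(theta, x):
    # Arguments are matched by name to the outputs of the earlier functions
    return dict(d=np.linalg.norm(theta - x))

simulator_with_distance = bf.make_simulator([theta_prior, forward_model, distance])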
Let’s generate some data to see what the simulator does:
# generate 3 random draws from the joint distribution p(theta, x)
# (the latents alpha and r are drawn internally and not returned)
sample_data = simulator.sample(3)
print("Type of sample_data:\n\t", type(sample_data))
print("Keys of sample_data:\n\t", sample_data.keys())
print("Types of sample_data values:\n\t", {k: type(v) for k, v in sample_data.items()})
print("Shapes of sample_data values:\n\t", {k: v.shape for k, v in sample_data.items()})
Type of sample_data:
<class 'dict'>
Keys of sample_data:
dict_keys(['theta', 'x'])
Types of sample_data values:
{'theta': <class 'numpy.ndarray'>, 'x': <class 'numpy.ndarray'>}
Shapes of sample_data values:
{'theta': (3, 2), 'x': (3, 2)}
BayesFlow also provides this simulator, along with a collection of others, in the `bayesflow.benchmarks` module.
2.2. Adapter#
The next step is to tell BayesFlow how to deal with all the simulated variables. You may also think of this as informing BayesFlow about the data flow, i.e., which variables go into which network and what transformations need to be performed before passing the simulator outputs into the networks. This is done via an adapter layer, which is implemented as a sequence of fixed, pseudo-invertible data transforms.

Below, we define the adapter by specifying the input and output keys and the transformations to be applied. This gives us full control over the data flow.
adapter = (
bf.adapters.Adapter()
# convert any non-arrays to numpy arrays
.to_array()
# convert from numpy's default float64 to deep learning friendly float32
.convert_dtype("float64", "float32")
    # standardize the data variables to zero mean and unit variance (theta is excluded)
.standardize(exclude="theta")
# rename the variables to match the required approximator inputs
.rename("theta", "inference_variables")
.rename("x", "inference_conditions")
)
adapter
Adapter([0: ToArray -> 1: ConvertDType -> 2: Standardize(exclude=['theta']) -> 3: Rename('theta' -> 'inference_variables') -> 4: Rename('x' -> 'inference_conditions')])
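As a quick sanity check, we can pass a freshly simulated batch through the adapter (a sketch; we assume here that the adapter can be applied directly to a data dict, and depending on the BayesFlow version the standardization step may first require training statistics):

processed = adapter(simulator.sample(3))
print("Keys after adaptation:\n\t", processed.keys())
# expected: dict_keys(['inference_variables', 'inference_conditions'])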
2.3. Dataset#
For this example, we will sample our training data ahead of time and use offline training with a very small number of epochs. In actual applications, you usually want to train much longer in order to max out performance.
num_training_batches = 512
num_validation_batches = 128
batch_size = 64
epochs = 50
training_data = simulator.sample(num_training_batches * batch_size)
validation_data = simulator.sample(num_validation_batches * batch_size)
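Offline training data is just a dictionary of arrays, so we can sanity-check the shapes directly; with the settings above, we expect \(512 \cdot 64 = 32768\) training draws:

print({k: v.shape for k, v in training_data.items()})
# expected: {'theta': (32768, 2), 'x': (32768, 2)}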
2.4. Training a neural network to approximate all posteriors#
The next step is to set up the neural network that will approximate the posterior \(p(\theta\,|\,x)\).
We choose Flow Matching [1, 2] as the backbone architecture for this example, as it can deal well with the multimodal nature of the posteriors that some observables imply.
[1] Lipman, Y., Chen, R. T., Ben-Hamu, H., Nickel, M., & Le, M. Flow Matching for Generative Modeling. In The Eleventh International Conference on Learning Representations.
[2] Wildberger, J. B., Dax, M., Buchholz, S., Green, S. R., Macke, J. H., & Schölkopf, B. Flow Matching for Scalable Simulation-Based Inference. In Thirty-seventh Conference on Neural Information Processing Systems.
flow_matching = bf.networks.FlowMatching(
subnet="mlp",
subnet_kwargs={"dropout": 0.0, "widths": (256,)*6} # override default dropout = 0.05 and widths = (256,)*5
)
This inference network is just a general Flow Matching backbone, not yet adapted to the specific inference task at hand (i.e., posterior approximation). To achieve this adaptation, we combine the network with our data adapter, which together form an approximator. In this case, we need a `ContinuousApproximator`, since the target we want to approximate is the posterior of the continuous parameter vector \(\theta\).
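For reference, the workflow below performs this wiring for us. Constructed by hand, the combination might look roughly like this (a sketch; we assume the `ContinuousApproximator` constructor accepts the adapter and inference network as keyword arguments):

approximator = bf.approximators.ContinuousApproximator(
    adapter=adapter,
    inference_network=flow_matching,
)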
2.4.1. Basic Workflow#
We can hide many of the traditional deep learning steps (e.g., specifying a learning rate and an optimizer) within a `Workflow` object. This object just wraps everything together and includes some nice utility functions for training and in silico validation.
flow_matching_workflow = bf.BasicWorkflow(
simulator=simulator,
adapter=adapter,
inference_network=flow_matching,
)
2.4.2. Training#
We are ready to train our deep posterior approximator on the two moons example. We use the utility function `fit_offline`, which wraps the approximator’s super flexible `fit` method.
history = flow_matching_workflow.fit_offline(
training_data,
epochs=epochs,
batch_size=batch_size,
validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - loss: 1.5975 - loss/inference_loss: 1.5975 - val_loss: 0.3615 - val_loss/inference_loss: 0.3615
Epoch 2/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3963 - loss/inference_loss: 0.3963 - val_loss: 0.3259 - val_loss/inference_loss: 0.3259
Epoch 3/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3793 - loss/inference_loss: 0.3793 - val_loss: 0.3774 - val_loss/inference_loss: 0.3774
Epoch 4/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3734 - loss/inference_loss: 0.3734 - val_loss: 0.2816 - val_loss/inference_loss: 0.2816
Epoch 5/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3625 - loss/inference_loss: 0.3625 - val_loss: 0.2620 - val_loss/inference_loss: 0.2620
Epoch 6/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3600 - loss/inference_loss: 0.3600 - val_loss: 0.2465 - val_loss/inference_loss: 0.2465
Epoch 7/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3646 - loss/inference_loss: 0.3646 - val_loss: 0.3418 - val_loss/inference_loss: 0.3418
Epoch 8/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3593 - loss/inference_loss: 0.3593 - val_loss: 0.3830 - val_loss/inference_loss: 0.3830
Epoch 9/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3493 - loss/inference_loss: 0.3493 - val_loss: 0.4498 - val_loss/inference_loss: 0.4498
Epoch 10/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3492 - loss/inference_loss: 0.3492 - val_loss: 0.3947 - val_loss/inference_loss: 0.3947
Epoch 11/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3459 - loss/inference_loss: 0.3459 - val_loss: 0.2970 - val_loss/inference_loss: 0.2970
Epoch 12/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3457 - loss/inference_loss: 0.3457 - val_loss: 0.2243 - val_loss/inference_loss: 0.2243
Epoch 13/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3451 - loss/inference_loss: 0.3451 - val_loss: 0.4153 - val_loss/inference_loss: 0.4153
Epoch 14/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3433 - loss/inference_loss: 0.3433 - val_loss: 0.3919 - val_loss/inference_loss: 0.3919
Epoch 15/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3373 - loss/inference_loss: 0.3373 - val_loss: 0.3400 - val_loss/inference_loss: 0.3400
Epoch 16/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3440 - loss/inference_loss: 0.3440 - val_loss: 0.1885 - val_loss/inference_loss: 0.1885
Epoch 17/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3429 - loss/inference_loss: 0.3429 - val_loss: 0.3297 - val_loss/inference_loss: 0.3297
Epoch 18/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3389 - loss/inference_loss: 0.3389 - val_loss: 0.3932 - val_loss/inference_loss: 0.3932
Epoch 19/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3330 - loss/inference_loss: 0.3330 - val_loss: 0.2640 - val_loss/inference_loss: 0.2640
Epoch 20/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3272 - loss/inference_loss: 0.3272 - val_loss: 0.3262 - val_loss/inference_loss: 0.3262
Epoch 21/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3301 - loss/inference_loss: 0.3301 - val_loss: 0.4301 - val_loss/inference_loss: 0.4301
Epoch 22/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3329 - loss/inference_loss: 0.3329 - val_loss: 0.4407 - val_loss/inference_loss: 0.4407
Epoch 23/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3384 - loss/inference_loss: 0.3384 - val_loss: 0.2786 - val_loss/inference_loss: 0.2786
Epoch 24/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3283 - loss/inference_loss: 0.3283 - val_loss: 0.3840 - val_loss/inference_loss: 0.3840
Epoch 25/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3235 - loss/inference_loss: 0.3235 - val_loss: 0.3168 - val_loss/inference_loss: 0.3168
Epoch 26/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3227 - loss/inference_loss: 0.3227 - val_loss: 0.2289 - val_loss/inference_loss: 0.2289
Epoch 27/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3367 - loss/inference_loss: 0.3367 - val_loss: 0.2283 - val_loss/inference_loss: 0.2283
Epoch 28/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3331 - loss/inference_loss: 0.3331 - val_loss: 0.3331 - val_loss/inference_loss: 0.3331
Epoch 29/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3281 - loss/inference_loss: 0.3281 - val_loss: 0.1447 - val_loss/inference_loss: 0.1447
Epoch 30/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3228 - loss/inference_loss: 0.3228 - val_loss: 0.2868 - val_loss/inference_loss: 0.2868
Epoch 31/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3226 - loss/inference_loss: 0.3226 - val_loss: 0.2819 - val_loss/inference_loss: 0.2819
Epoch 32/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3167 - loss/inference_loss: 0.3167 - val_loss: 0.3676 - val_loss/inference_loss: 0.3676
Epoch 33/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3264 - loss/inference_loss: 0.3264 - val_loss: 0.2303 - val_loss/inference_loss: 0.2303
Epoch 34/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3141 - loss/inference_loss: 0.3141 - val_loss: 0.2125 - val_loss/inference_loss: 0.2125
Epoch 35/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3096 - loss/inference_loss: 0.3096 - val_loss: 0.2754 - val_loss/inference_loss: 0.2754
Epoch 36/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3187 - loss/inference_loss: 0.3187 - val_loss: 0.3006 - val_loss/inference_loss: 0.3006
Epoch 37/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3157 - loss/inference_loss: 0.3157 - val_loss: 0.3113 - val_loss/inference_loss: 0.3113
Epoch 38/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3212 - loss/inference_loss: 0.3212 - val_loss: 0.4190 - val_loss/inference_loss: 0.4190
Epoch 39/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3162 - loss/inference_loss: 0.3162 - val_loss: 0.3351 - val_loss/inference_loss: 0.3351
Epoch 40/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3159 - loss/inference_loss: 0.3159 - val_loss: 0.3813 - val_loss/inference_loss: 0.3813
Epoch 41/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3134 - loss/inference_loss: 0.3134 - val_loss: 0.4158 - val_loss/inference_loss: 0.4158
Epoch 42/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3142 - loss/inference_loss: 0.3142 - val_loss: 0.5772 - val_loss/inference_loss: 0.5772
Epoch 43/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3136 - loss/inference_loss: 0.3136 - val_loss: 0.3419 - val_loss/inference_loss: 0.3419
Epoch 44/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.3104 - loss/inference_loss: 0.3104 - val_loss: 0.5761 - val_loss/inference_loss: 0.5761
Epoch 45/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3166 - loss/inference_loss: 0.3166 - val_loss: 0.2678 - val_loss/inference_loss: 0.2678
Epoch 46/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3197 - loss/inference_loss: 0.3197 - val_loss: 0.3576 - val_loss/inference_loss: 0.3576
Epoch 47/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3165 - loss/inference_loss: 0.3165 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637
Epoch 48/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3145 - loss/inference_loss: 0.3145 - val_loss: 0.3348 - val_loss/inference_loss: 0.3348
Epoch 49/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.3104 - loss/inference_loss: 0.3104 - val_loss: 0.3068 - val_loss/inference_loss: 0.3068
Epoch 50/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.3090 - loss/inference_loss: 0.3090 - val_loss: 0.4407 - val_loss/inference_loss: 0.4407
2.5. Swapping Inference Networks#
Using BayesFlow, it is easy to switch to a different backbone architecture for the inference network. For instance, the code below demonstrates the use of a Consistency Model, which can allow for faster sampling during inference.
2.5.1. Consistency Models: Background#
Consistency Models (CMs; [1]) leverage the nice properties of score-based diffusion to enable few-step sampling. Score-based diffusion initially relied on a stochastic differential equation (SDE) for sampling, but there is also an ordinary (non-stochastic) differential equation (ODE) that has the same marginal distribution at each time step \(t\) [2]. This means that even though the SDE and the ODE produce different paths from the noise distribution to the target distribution, the distributions resulting from many paths at time \(t\) are the same. This ODE is also called the Probability Flow ODE.
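For reference, for a diffusion with drift \(\mu\) and diffusion coefficient \(g\) (we write \(\mu\) here to avoid a clash with the consistency function \(f\) below), the Probability Flow ODE of [2] reads

\[
\frac{\mathrm{d}\theta_t}{\mathrm{d}t} = \mu(\theta_t, t) - \frac{1}{2}\, g(t)^2\, \nabla_{\theta_t} \log p_t(\theta_t),
\]

where \(\nabla_{\theta_t} \log p_t\) is the score of the marginal distribution at time \(t\).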
CMs leverage the fact that there is no randomness in the ODE formulation. That means that if you start at a certain point in the latent space, you will always take the same path and end up at the same point in the target \(\theta\)-space. The same holds for every point on the path: integrating it to time \(t=0\) always yields the same endpoint. In short: for each path, there is exactly one corresponding point in latent space (at \(t=T\)) and one corresponding point in data space (at \(t=0\)).
The goal of CMs is the following: each point at a time \(t\) belongs to exactly one path, and we want to predict where this path ends up at \(t=0\). The function that does this is called the consistency function \(f\). If we had the correct function for all \(t \in (0, T]\), we could simply sample from the latent distribution (at \(t=T\)) and use \(f\) to map directly to the corresponding point at \(t=0\), which lies in the target distribution. So for sampling from the target distribution, we avoid any integration and only need one evaluation of the consistency function. In practice, one-step sampling does not work very well. Instead, we use a multi-step sampling method where we call \(f\) multiple times. Please check out [1] for more background on this sampling procedure.
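A rough numpy sketch of this multi-step procedure, following Algorithm 1 of [1] (the variance-exploding noising and the `f` interface here are illustrative assumptions, not BayesFlow internals):

def multistep_consistency_sample(f, times, dim, eps=1e-3, rng=None):
    # times: decreasing time points, starting at the maximum time T
    rng = rng or np.random.default_rng()
    T = times[0]
    theta = f(T * rng.standard_normal(dim), T)  # one-step sample from t = T
    for t in times[1:]:  # optional refinement steps at decreasing times
        z = rng.standard_normal(dim)
        theta_t = theta + np.sqrt(t**2 - eps**2) * z  # re-noise up to time t
        theta = f(theta_t, t)                         # map back toward t = 0
    return theta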
When reading the above, you might wonder why we learn the mapping to \(t=0\) for all intermediate time steps \(t \in (0, T]\), and not only for \(t=T\). The main answer is that, for efficient training, we do not want to compute the two associated points explicitly. Doing so would require a precise integration at training time, which is often infeasible because it is too computationally costly. Learning all time steps opens up the possibility of a different training approach that avoids this. The details become a bit more involved, and we advise you to take a look at [1] if you are interested in a more thorough mathematical discussion. Below, we give a rough description of the underlying concepts.
**Training** First, we know that at \(t=0\) it holds that \(f(\theta, t=0) = \theta\), since \(\theta\) is part of the path that ends at \(\theta\). This boundary condition serves as an “anchor” for our training; it is the information that the network knows at the start of the training procedure (we encode it with a time-dependent skip-connection, so the network is forced to be the identity function at \(t=0\)). For training, we now have to propagate this information to the rest of the time interval. The basic idea is simple. We take a point \(\theta_1\) closer to the data distribution (smaller time \(t_1\)) and integrate for a small time step \(dt\) to a point \(\theta_2\) on the same path that is closer to the latent distribution (larger time \(t_2 = t_1 + dt\)). As we know that at \(t=0\) our network provides the correct output for our path, we want to propagate the information from smaller times to larger times. Our training goal is therefore to move the output of \(f(\theta_2, t=t_2)\) towards the output of \(f(\theta_1, t=t_1)\). How to choose \(\theta_1\), \(t_1\), and \(dt\) is an empirical question; see [1] for some thoughts on what works well.
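As an illustration of the boundary condition, [1] parameterizes the consistency function with time-dependent skip connections so that it reduces to the identity at the smallest time (a sketch, not the BayesFlow implementation; `sigma_data` is the assumed scale of the target distribution):

sigma_data = 1.0  # assumed standard deviation of the target distribution
eps = 1e-3        # smallest time value

def c_skip(t):
    # -> 1 as t -> eps, so the skip connection dominates at small times
    return sigma_data**2 / ((t - eps)**2 + sigma_data**2)

def c_out(t):
    # -> 0 as t -> eps, silencing the free-form network at small times
    return sigma_data * (t - eps) / np.sqrt(sigma_data**2 + t**2)

def consistency_function(theta, t, network):
    # At t = eps, c_skip = 1 and c_out = 0, so f is exactly the identity
    return c_skip(t) * theta + c_out(t) * network(theta, t)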
**Distillation and consistency training** In the case of distillation, we start with a trained score-based diffusion model and use it to integrate the Probability Flow ODE from \(\theta_1\) to \(\theta_2\). If we do not have such a model, it seems as if we were stuck: we do not know which points lie on the same path, so we do not know which outputs to make similar. Fortunately, it turns out that there is an unbiased approximator of the score that, when averaged over many samples (check out the paper for the exact description), also gives us the correct score. If we use this approximator instead of the score model and take only a single Euler step to move along the path, we get an algorithm similar to the one described for distillation, as sketched below. It is called Consistency Training (CT) and allows us to train a consistency model using only samples from the data distribution. The algorithm was improved considerably in [3], and we have incorporated those improvements into our implementation.
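To make the single-Euler-step idea concrete, here is a tiny sketch of how one consistency-training pair can be formed (assuming the variance-exploding perturbation \(\theta_t = \theta + t\,z\) of [1]; the numbers are arbitrary):

rng = np.random.default_rng(1)
theta0 = rng.uniform(-1, 1, 2)  # a draw from the data (here: prior) distribution
z = rng.standard_normal(2)      # one shared noise sample
t1, t2 = 0.5, 0.6               # adjacent discretization times, t2 = t1 + dt

theta1 = theta0 + t1 * z  # point closer to the data distribution
theta2 = theta0 + t2 * z  # same path after one Euler step toward the noise
# Training then pulls f(theta2, t2) toward a stop-gradient copy of f(theta1, t1)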
**Improving consistency training** We have made several approximations to arrive at a standalone consistency training algorithm. As a consequence, the introduced hyperparameters and their choices unfortunately become somewhat unintuitive, and we have to rely on empirical observations and heuristics to see what works. This was done in [4]; we encourage you to use the values provided there as starting points. If you happen to find hyperparameters that work significantly better, please let us know (e.g., by opening an issue or sending an email). This will help others find the correct region in the hyperparameter space.
[1] Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. arXiv preprint. https://doi.org/10.48550/arXiv.2303.01469
[2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. In International Conference on Learning Representations. https://openreview.net/forum?id=PxTIG12RRHS
[3] Song, Y., & Dhariwal, P. (2023). Improved Techniques for Training Consistency Models. arXiv preprint. https://doi.org/10.48550/arXiv.2310.14189
[4] Schmitt, M., Pratz, V., Köthe, U., Bürkner, P.-C., & Radev, S. T. (2024). Consistency Models for Scalable and Fast Simulation-Based Inference. arXiv preprint. https://doi.org/10.48550/arXiv.2312.05440
2.5.2. Consistency Models: Specification#
We can now go ahead and define our new inference network backbone. Apart from the usual settings like learning rate and batch size, CMs come with a number of additional hyperparameters. Unfortunately, these can interact heavily, which makes them hard to tune. The main hyperparameters are:
- Maximum time `max_time`: this also serves as the standard deviation of the latent distribution. You can experiment with this; values from 10 to 200 seem to work well. In any case, it should be larger than the standard deviation of the target distribution.
- Minimum/maximum number of discretization steps during training `s0`/`s1`: the effect of these is hard to grasp. 10 works well for `s0`. Intuitively, increasing `s1` along with the number of epochs should lead to better results, but in practice we sometimes observe a breakdown for high values of `s1`. This seems to be problem-dependent, so just try it out.
- `sigma2`: modifies the time-dependence of the skip connection. Its effect on training is unclear; we recommend leaving it at 1.0 or setting it to the approximate variance of the target distribution.
- Smallest time value `eps` (\(t=\epsilon\) is used instead of \(t=0\) for numerical reasons): no large effect in our experiments, as long as it is kept small enough. Probably not worth tuning.
You may find that different hyperparameter values work better for your tasks.
A short note on dropout: in our experiments, dropout usually led to worse performance, so we generally recommend setting the dropout rate to \(0.0\). Consistency training takes advantage of a noisy estimator of the score, so the training is probably already sufficiently noisy and extra dropout for regularization is not necessary.
consistency_model = bf.networks.ConsistencyModel(
    subnet="mlp",
    subnet_kwargs={"dropout": 0.0, "widths": (256,)*6},
    total_steps=num_training_batches*epochs,
    max_time=10,  # this probably needs to be tuned for a novel application
    sigma2=1.0,   # default recommended above; roughly the variance of the target distribution
)
# Workflow for consistency model
consistency_model_workflow = bf.BasicWorkflow(
simulator=simulator,
adapter=adapter,
inference_network=consistency_model,
)
2.5.3. Consistency Training#
history = consistency_model_workflow.fit_offline(
training_data,
epochs=epochs,
batch_size=batch_size,
validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 4ms/step - loss: 0.3549 - loss/inference_loss: 0.3549 - val_loss: 0.2793 - val_loss/inference_loss: 0.2793
Epoch 2/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.3020 - loss/inference_loss: 0.3020 - val_loss: 0.3505 - val_loss/inference_loss: 0.3505
Epoch 3/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2831 - loss/inference_loss: 0.2831 - val_loss: 0.2429 - val_loss/inference_loss: 0.2429
Epoch 4/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2813 - loss/inference_loss: 0.2813 - val_loss: 0.3600 - val_loss/inference_loss: 0.3600
Epoch 5/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2732 - loss/inference_loss: 0.2732 - val_loss: 0.2537 - val_loss/inference_loss: 0.2537
Epoch 6/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2722 - loss/inference_loss: 0.2722 - val_loss: 0.2904 - val_loss/inference_loss: 0.2904
Epoch 7/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2618 - loss/inference_loss: 0.2618 - val_loss: 0.1984 - val_loss/inference_loss: 0.1984
Epoch 8/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2663 - loss/inference_loss: 0.2663 - val_loss: 0.1680 - val_loss/inference_loss: 0.1680
Epoch 9/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2599 - loss/inference_loss: 0.2599 - val_loss: 0.2595 - val_loss/inference_loss: 0.2595
Epoch 10/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2567 - loss/inference_loss: 0.2567 - val_loss: 0.2612 - val_loss/inference_loss: 0.2612
Epoch 11/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2530 - loss/inference_loss: 0.2530 - val_loss: 0.2694 - val_loss/inference_loss: 0.2694
Epoch 12/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2596 - loss/inference_loss: 0.2596 - val_loss: 0.3073 - val_loss/inference_loss: 0.3073
Epoch 13/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2566 - loss/inference_loss: 0.2566 - val_loss: 0.1798 - val_loss/inference_loss: 0.1798
Epoch 14/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2589 - loss/inference_loss: 0.2589 - val_loss: 0.2743 - val_loss/inference_loss: 0.2743
Epoch 15/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2493 - loss/inference_loss: 0.2493 - val_loss: 0.2189 - val_loss/inference_loss: 0.2189
Epoch 16/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2649 - loss/inference_loss: 0.2649 - val_loss: 0.2154 - val_loss/inference_loss: 0.2154
Epoch 17/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2609 - loss/inference_loss: 0.2609 - val_loss: 0.2758 - val_loss/inference_loss: 0.2758
Epoch 18/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2526 - loss/inference_loss: 0.2526 - val_loss: 0.1542 - val_loss/inference_loss: 0.1542
Epoch 19/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2510 - loss/inference_loss: 0.2510 - val_loss: 0.1860 - val_loss/inference_loss: 0.1860
Epoch 20/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2559 - loss/inference_loss: 0.2559 - val_loss: 0.2213 - val_loss/inference_loss: 0.2213
Epoch 21/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2524 - loss/inference_loss: 0.2524 - val_loss: 0.2497 - val_loss/inference_loss: 0.2497
Epoch 22/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2477 - loss/inference_loss: 0.2477 - val_loss: 0.2030 - val_loss/inference_loss: 0.2030
Epoch 23/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2447 - loss/inference_loss: 0.2447 - val_loss: 0.2862 - val_loss/inference_loss: 0.2862
Epoch 24/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2451 - loss/inference_loss: 0.2451 - val_loss: 0.3859 - val_loss/inference_loss: 0.3859
Epoch 25/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2429 - loss/inference_loss: 0.2429 - val_loss: 0.2310 - val_loss/inference_loss: 0.2310
Epoch 26/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2437 - loss/inference_loss: 0.2437 - val_loss: 0.2236 - val_loss/inference_loss: 0.2236
Epoch 27/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2432 - loss/inference_loss: 0.2432 - val_loss: 0.3466 - val_loss/inference_loss: 0.3466
Epoch 28/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2353 - loss/inference_loss: 0.2353 - val_loss: 0.2234 - val_loss/inference_loss: 0.2234
Epoch 29/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2351 - loss/inference_loss: 0.2351 - val_loss: 0.1637 - val_loss/inference_loss: 0.1637
Epoch 30/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2364 - loss/inference_loss: 0.2364 - val_loss: 0.2324 - val_loss/inference_loss: 0.2324
Epoch 31/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2310 - loss/inference_loss: 0.2310 - val_loss: 0.1853 - val_loss/inference_loss: 0.1853
Epoch 32/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2287 - loss/inference_loss: 0.2287 - val_loss: 0.1234 - val_loss/inference_loss: 0.1234
Epoch 33/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2159 - loss/inference_loss: 0.2159 - val_loss: 0.1985 - val_loss/inference_loss: 0.1985
Epoch 34/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2228 - loss/inference_loss: 0.2228 - val_loss: 0.4063 - val_loss/inference_loss: 0.4063
Epoch 35/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2155 - loss/inference_loss: 0.2155 - val_loss: 0.2233 - val_loss/inference_loss: 0.2233
Epoch 36/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2257 - loss/inference_loss: 0.2257 - val_loss: 0.1208 - val_loss/inference_loss: 0.1208
Epoch 37/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2231 - loss/inference_loss: 0.2231 - val_loss: 0.0776 - val_loss/inference_loss: 0.0776
Epoch 38/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2193 - loss/inference_loss: 0.2193 - val_loss: 0.2310 - val_loss/inference_loss: 0.2310
Epoch 39/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2204 - loss/inference_loss: 0.2204 - val_loss: 0.1733 - val_loss/inference_loss: 0.1733
Epoch 40/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2219 - loss/inference_loss: 0.2219 - val_loss: 0.1291 - val_loss/inference_loss: 0.1291
Epoch 41/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2127 - loss/inference_loss: 0.2127 - val_loss: 0.1073 - val_loss/inference_loss: 0.1073
Epoch 42/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2257 - loss/inference_loss: 0.2257 - val_loss: 0.2174 - val_loss/inference_loss: 0.2174
Epoch 43/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2178 - loss/inference_loss: 0.2178 - val_loss: 0.2001 - val_loss/inference_loss: 0.2001
Epoch 44/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2150 - loss/inference_loss: 0.2150 - val_loss: 0.2282 - val_loss/inference_loss: 0.2282
Epoch 45/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2130 - loss/inference_loss: 0.2130 - val_loss: 0.1956 - val_loss/inference_loss: 0.1956
Epoch 46/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2046 - loss/inference_loss: 0.2046 - val_loss: 0.1937 - val_loss/inference_loss: 0.1937
Epoch 47/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2082 - loss/inference_loss: 0.2082 - val_loss: 0.2303 - val_loss/inference_loss: 0.2303
Epoch 48/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2067 - loss/inference_loss: 0.2067 - val_loss: 0.1300 - val_loss/inference_loss: 0.1300
Epoch 49/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.2201 - loss/inference_loss: 0.2201 - val_loss: 0.1432 - val_loss/inference_loss: 0.1432
Epoch 50/50
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.2097 - loss/inference_loss: 0.2097 - val_loss: 0.1791 - val_loss/inference_loss: 0.1791
2.6. Good Ol’ Coupling Flows#
Of course, BayesFlow also supports established coupling flow models in a variety of configurations, including the timeless affine and spline flows.
affine_flow = bf.networks.CouplingFlow(subnet="mlp")
spline_flow = bf.networks.CouplingFlow(subnet="mlp", transform="spline", depth=4)
epochs = 30  # coupling flows need fewer epochs than free-form methods
affine_flow_workflow = bf.BasicWorkflow(
simulator=simulator,
adapter=adapter,
inference_network=affine_flow,
)
spline_flow_workflow = bf.BasicWorkflow(
simulator=simulator,
adapter=adapter,
inference_network=spline_flow,
)
2.6.1. Coupling Flow Training#
history = affine_flow_workflow.fit_offline(
training_data,
epochs=epochs,
batch_size=batch_size,
validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 11s 6ms/step - loss: -1.2724 - loss/inference_loss: -1.2724 - val_loss: -2.2869 - val_loss/inference_loss: -2.2869
Epoch 2/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -2.4998 - loss/inference_loss: -2.4998 - val_loss: -2.1505 - val_loss/inference_loss: -2.1505
Epoch 3/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -2.6645 - loss/inference_loss: -2.6645 - val_loss: -2.6893 - val_loss/inference_loss: -2.6893
Epoch 4/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: -2.7494 - loss/inference_loss: -2.7494 - val_loss: -2.4783 - val_loss/inference_loss: -2.4783
Epoch 5/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -2.7852 - loss/inference_loss: -2.7852 - val_loss: -2.7409 - val_loss/inference_loss: -2.7409
Epoch 6/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: -2.8575 - loss/inference_loss: -2.8575 - val_loss: -2.6789 - val_loss/inference_loss: -2.6789
Epoch 7/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: -2.9127 - loss/inference_loss: -2.9127 - val_loss: -3.0435 - val_loss/inference_loss: -3.0435
Epoch 8/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -2.9346 - loss/inference_loss: -2.9346 - val_loss: -2.9628 - val_loss/inference_loss: -2.9628
Epoch 9/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -2.9371 - loss/inference_loss: -2.9371 - val_loss: -2.3899 - val_loss/inference_loss: -2.3899
Epoch 10/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: -2.9999 - loss/inference_loss: -2.9999 - val_loss: -2.9758 - val_loss/inference_loss: -2.9758
Epoch 11/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.0083 - loss/inference_loss: -3.0083 - val_loss: -2.5975 - val_loss/inference_loss: -2.5975
Epoch 12/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.0434 - loss/inference_loss: -3.0434 - val_loss: -3.0387 - val_loss/inference_loss: -3.0387
Epoch 13/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.0920 - loss/inference_loss: -3.0920 - val_loss: -2.3253 - val_loss/inference_loss: -2.3253
Epoch 14/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.1255 - loss/inference_loss: -3.1255 - val_loss: -3.1556 - val_loss/inference_loss: -3.1556
Epoch 15/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.1654 - loss/inference_loss: -3.1654 - val_loss: -2.8726 - val_loss/inference_loss: -2.8726
Epoch 16/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.1779 - loss/inference_loss: -3.1779 - val_loss: -3.1917 - val_loss/inference_loss: -3.1917
Epoch 17/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.2124 - loss/inference_loss: -3.2124 - val_loss: -2.1857 - val_loss/inference_loss: -2.1857
Epoch 18/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - loss: -3.2409 - loss/inference_loss: -3.2409 - val_loss: -2.9640 - val_loss/inference_loss: -2.9640
Epoch 19/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.2689 - loss/inference_loss: -3.2689 - val_loss: -2.7462 - val_loss/inference_loss: -2.7462
Epoch 20/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.2759 - loss/inference_loss: -3.2759 - val_loss: -2.9262 - val_loss/inference_loss: -2.9262
Epoch 21/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - loss: -3.3185 - loss/inference_loss: -3.3185 - val_loss: -2.1308 - val_loss/inference_loss: -2.1308
Epoch 22/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.3447 - loss/inference_loss: -3.3447 - val_loss: -2.6838 - val_loss/inference_loss: -2.6838
Epoch 23/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.3607 - loss/inference_loss: -3.3607 - val_loss: -3.4517 - val_loss/inference_loss: -3.4517
Epoch 24/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - loss: -3.3540 - loss/inference_loss: -3.3540 - val_loss: -3.0357 - val_loss/inference_loss: -3.0357
Epoch 25/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.3831 - loss/inference_loss: -3.3831 - val_loss: -3.3768 - val_loss/inference_loss: -3.3768
Epoch 26/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - loss: -3.4183 - loss/inference_loss: -3.4183 - val_loss: -3.1671 - val_loss/inference_loss: -3.1671
Epoch 27/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - loss: -3.4194 - loss/inference_loss: -3.4194 - val_loss: 9.2951 - val_loss/inference_loss: 9.2951
Epoch 28/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.3851 - loss/inference_loss: -3.3851 - val_loss: -2.9896 - val_loss/inference_loss: -2.9896
Epoch 29/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: -3.4251 - loss/inference_loss: -3.4251 - val_loss: -3.3175 - val_loss/inference_loss: -3.3175
Epoch 30/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - loss: -3.4083 - loss/inference_loss: -3.4083 - val_loss: -3.2599 - val_loss/inference_loss: -3.2599
history = spline_flow_workflow.fit_offline(
training_data,
epochs=epochs,
batch_size=batch_size,
validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 21s 11ms/step - loss: -1.0211 - loss/inference_loss: -1.0211 - val_loss: -1.5653 - val_loss/inference_loss: -1.5653
Epoch 2/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -2.2318 - loss/inference_loss: -2.2318 - val_loss: -2.4503 - val_loss/inference_loss: -2.4503
Epoch 3/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -2.5659 - loss/inference_loss: -2.5659 - val_loss: -2.6406 - val_loss/inference_loss: -2.6406
Epoch 4/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -2.7533 - loss/inference_loss: -2.7533 - val_loss: -2.6080 - val_loss/inference_loss: -2.6080
Epoch 5/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -2.8607 - loss/inference_loss: -2.8607 - val_loss: -2.8657 - val_loss/inference_loss: -2.8657
Epoch 6/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -2.9422 - loss/inference_loss: -2.9422 - val_loss: -2.3686 - val_loss/inference_loss: -2.3686
Epoch 7/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -2.9989 - loss/inference_loss: -2.9989 - val_loss: -2.9271 - val_loss/inference_loss: -2.9271
Epoch 8/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.0638 - loss/inference_loss: -3.0638 - val_loss: -3.0360 - val_loss/inference_loss: -3.0360
Epoch 9/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.0553 - loss/inference_loss: -3.0553 - val_loss: -3.2254 - val_loss/inference_loss: -3.2254
Epoch 10/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.1095 - loss/inference_loss: -3.1095 - val_loss: -3.0538 - val_loss/inference_loss: -3.0538
Epoch 11/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.2038 - loss/inference_loss: -3.2038 - val_loss: -3.1451 - val_loss/inference_loss: -3.1451
Epoch 12/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.2396 - loss/inference_loss: -3.2396 - val_loss: -3.2923 - val_loss/inference_loss: -3.2923
Epoch 13/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.2695 - loss/inference_loss: -3.2695 - val_loss: -2.7734 - val_loss/inference_loss: -2.7734
Epoch 14/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.2778 - loss/inference_loss: -3.2778 - val_loss: -3.3034 - val_loss/inference_loss: -3.3034
Epoch 15/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.3281 - loss/inference_loss: -3.3281 - val_loss: -2.5565 - val_loss/inference_loss: -2.5565
Epoch 16/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.3416 - loss/inference_loss: -3.3416 - val_loss: -3.1074 - val_loss/inference_loss: -3.1074
Epoch 17/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.3612 - loss/inference_loss: -3.3612 - val_loss: -3.4145 - val_loss/inference_loss: -3.4145
Epoch 18/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.3806 - loss/inference_loss: -3.3806 - val_loss: -3.3677 - val_loss/inference_loss: -3.3677
Epoch 19/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.4344 - loss/inference_loss: -3.4344 - val_loss: -2.7881 - val_loss/inference_loss: -2.7881
Epoch 20/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.4364 - loss/inference_loss: -3.4364 - val_loss: -3.2963 - val_loss/inference_loss: -3.2963
Epoch 21/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.4700 - loss/inference_loss: -3.4700 - val_loss: -2.8206 - val_loss/inference_loss: -2.8206
Epoch 22/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.4920 - loss/inference_loss: -3.4920 - val_loss: -3.0479 - val_loss/inference_loss: -3.0479
Epoch 23/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5153 - loss/inference_loss: -3.5153 - val_loss: -3.0690 - val_loss/inference_loss: -3.0690
Epoch 24/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5309 - loss/inference_loss: -3.5309 - val_loss: -2.9115 - val_loss/inference_loss: -2.9115
Epoch 25/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5516 - loss/inference_loss: -3.5516 - val_loss: -3.0176 - val_loss/inference_loss: -3.0176
Epoch 26/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5659 - loss/inference_loss: -3.5659 - val_loss: -3.0676 - val_loss/inference_loss: -3.0676
Epoch 27/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5792 - loss/inference_loss: -3.5792 - val_loss: -3.3448 - val_loss/inference_loss: -3.3448
Epoch 28/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5791 - loss/inference_loss: -3.5791 - val_loss: -3.1989 - val_loss/inference_loss: -3.1989
Epoch 29/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.5712 - loss/inference_loss: -3.5712 - val_loss: -3.0113 - val_loss/inference_loss: -3.0113
Epoch 30/30
512/512 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - loss: -3.4733 - loss/inference_loss: -3.4733 - val_loss: -3.4334 - val_loss/inference_loss: -3.4334
2.7. Validation#
2.7.1. Two Moons Posterior#
The two moons posterior at point \(x = (0, 0)\) should resemble two crescent shapes. Below, we plot the corresponding posterior samples and posterior density.
These results suggest that these generative networks can approximate the true posterior well. You can achieve an even better fit if you use online training, more epochs, or better optimizer hyperparameters.
# Set the number of posterior draws you want to get
num_samples = 3000
# Define the conditioning data: a single "observation" x = (0, 0)
conditions = {"x": np.array([[0.0, 0.0]]).astype("float32")}
# Prepare figure
f, axes = plt.subplots(1, 4, figsize=(15, 6))
# Obtain samples from the approximators (can also use the workflows' methods)
nets = [
flow_matching_workflow.approximator,
consistency_model_workflow.approximator,
affine_flow_workflow.approximator,
spline_flow_workflow.approximator
]
names = ["Flow Matching", "Consistency Model", "Affine Coupling Flow", "Spline Coupling Flow"]
colors = ["#153c7a", "#7a1515", "#157a2d", "#7a6f15"]
for ax, net, name, color in zip(axes, nets, names, colors):
# Obtain samples
samples = net.sample(conditions=conditions, num_datasets=1, num_samples=num_samples)["theta"]
# Plot samples
ax.scatter(samples[0, :, 0], samples[0, :, 1], color=color, alpha=0.75, s=0.5)
sns.despine(ax=ax)
ax.set_title(f"{name}", fontsize=16)
ax.grid(alpha=0.3)
ax.set_aspect("equal", adjustable="box")
ax.set_xlim([-0.5, 0.5])
ax.set_ylim([-0.5, 0.5])
ax.set_xlabel(r"$\theta_1$", fontsize=15)
ax.set_ylabel(r"$\theta_2$", fontsize=15)
f.tight_layout()
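To visualize an approximate posterior density as well, we can smooth the draws with a kernel density estimate; note that this is a KDE of the samples, not the network’s exact density (shown here for the spline flow as an example):

samples = spline_flow_workflow.approximator.sample(
    conditions=conditions, num_datasets=1, num_samples=num_samples
)["theta"]

ax = sns.kdeplot(x=samples[0, :, 0], y=samples[0, :, 1], fill=True, cmap="mako")
ax.set_aspect("equal", adjustable="box")
ax.set_xlabel(r"$\theta_1$")
ax.set_ylabel(r"$\theta_2$")
plt.show()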

The posterior looks as we expected in this case. However, in general, we do not know what the posterior is supposed to look like for any specific dataset. As such, we need diagnostics that validate the correctness of the inferred posterior. One such diagnostic is simulation-based calibration (SBC), which we get essentially for free due to amortization. For more details on SBC and diagnostic plots, see:
- Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. arXiv preprint.
- Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and its applications in goodness-of-fit evaluation and multiple sample comparison. Statistics and Computing.
- The practical SBC interpretation guide by Martin Modrák: https://hyunjimoon.github.io/SBC/articles/rank_visualizations.html
Check out the next tutorial for a detailed walkthrough of the workflow’s functionality.