12. Neural Ratio Estimation with BayesFlow and PyMC#

This notebook demonstrates Neural Ratio Estimation (NRE) — a simulation-based inference method that replaces the intractable likelihood with a learned log-density ratio.

12.1. The Core Idea#

Our implementation closely follows the paper Contrastive Neural Ratio Estimation for Simulation-Based Inference. The ratio estimator learns a single-observation log-ratio

\[\log r(\theta, x) = \log \frac{p(x \mid \theta)}{p(x)}\]

Because observations are i.i.d. given \(\theta\), the joint posterior factorizes:

\[\log p(\theta \mid x_{1:n}) = \underbrace{\sum_{i=1}^{n} \log r(\theta, x_i)}_{\text{sum of single-trial ratios}} + \log p(\theta) + \text{const}\]

This means we can:

  1. Train on individual simulated observations (cheap, parallelizable)

  2. Aggregate at inference by summing \(n\) log-ratios — no retraining needed (see the sketch below)
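
To make step 2 concrete, here is a toy sketch for the Normal model of Part 1. The function log_r below stands in for the trained network; for this model the exact single-trial log-ratio is the Gaussian log-density minus the \(\theta\)-independent term \(\log p(x_i)\), which we drop as a constant:

import numpy as np
from scipy.stats import norm

def log_r(mu, sigma, x_i):
    # Toy stand-in for the learned single-trial log-ratio:
    # log p(x_i | mu, sigma) up to the theta-independent constant log p(x_i).
    return norm.logpdf(x_i, loc=mu, scale=sigma)

rng = np.random.default_rng(0)
x = rng.normal(2.0, 1.5, size=50)  # n = 50 i.i.d. observations

def log_unnormalized_posterior(mu, sigma):
    # Sum the single-trial log-ratios over all observations and add the
    # log-prior (a flat prior contributes only a constant inside its support).
    return sum(log_r(mu, sigma, x_i) for x_i in x)

print(log_unnormalized_posterior(2.0, 1.5))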

Currently, integration with PyMC is only supported with the JAX backend.

12.2. Structure#

| Part | Topic |
|------|-------|
| 1 | Normal model — train, investigate, and run NUTS on a simple 2-parameter Gaussian |
| 2 | Regression — per-trial mean via a linear predictor |
| 3 | Drift-Diffusion Model (DDM) — realistic cognitive model with (RT, choice) data |

import os
os.environ["KERAS_BACKEND"] = "jax"

import bayesflow as bf
import keras

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from scipy.stats import norm as sp_norm

import pymc as pm
import arviz as az

from bayesflow.wrappers.pymc import NeuralDistribution
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:bayesflow:Using backend 'jax'
INFO:arviz:Found 'auto' as default backend, checking available backends
INFO:arviz:Matplotlib is available, defining as default backend
INFO:arviz.preview:arviz_base available, exposing its functions as part of arviz.preview
INFO:arviz.preview:arviz_stats available, exposing its functions as part of arviz.preview
INFO:arviz.preview:arviz_plots available, exposing its functions as part of arviz.preview

12.3. Part 1: Normal Model#

12.3.1. Generative Model#

We use a 2-parameter Gaussian with one observation per simulation:

\[\begin{split} \begin{align} \mu &\sim \text{Uniform}(-5, 5)\\ \sigma &\sim \text{Uniform}(0.1, 5.0)\\ x \mid \mu, \sigma &\sim \mathcal{N}(\mu, \sigma^2) \end{align} \end{split}\]

The likelihood function produces exactly one draw \(x_i\) per call. At inference time the wrapper sums the per-trial log-ratios automatically.

def prior():
    mu = np.random.uniform(-5, 5)
    sigma = np.random.uniform(0.1, 5.0)
    return {"mu": mu, "sigma": sigma}

def likelihood(mu, sigma):
    return {"x": mu + sigma * np.random.standard_normal()}

simulator = bf.make_simulator([prior, likelihood])

# Sanity check: each simulation returns one x
batch = simulator.sample(4)
{k: v.shape for k, v in batch.items()}
{'mu': (4, 1), 'sigma': (4, 1), 'x': (4, 1)}

12.3.2. Build Adapter and Train#

inference_variables → the parameters \(\theta\); inference_conditions → the observation \(x\).

adapter = (
    bf.Adapter()
    .convert_dtype("float64", "float32")
    .concatenate(["mu", "sigma"], into="inference_variables")
    .rename("x", "inference_conditions")
)

ratio_approximator = bf.RatioApproximator(
    adapter=adapter,
    inference_network=bf.networks.MLP(),
    K=64,
    standardize=None,
)

learning_rate = keras.optimizers.schedules.CosineDecay(5e-4, decay_steps=5000)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
ratio_approximator.compile(optimizer=optimizer)

Training on a fast simulator is cheap on both GPU and CPU; expect roughly 30 seconds.

history = ratio_approximator.fit(
    simulator=simulator,
    epochs=100,
    num_batches=50,
    batch_size=128,
    verbose=2
)
f = bf.diagnostics.plots.loss(history)
../_images/3b193259094b026f5260cf986c686385663ce6b17fea0563156c03288c816697.png

12.4. The NeuralDistribution Wrapper#

NeuralDistribution is a thin bridge that turns a trained RatioApproximator or ContinuousApproximator into a PyMC custom distribution, usable with any gradient-based sampler (NUTS, HMC) or gradient-free sampler (e.g., Slice).

| Argument | Role |
|----------|------|
| approximator | The trained estimator |
| param_names | Names of the parameters (must match inference_variables) |
| exchangeable=True | Assume i.i.d. observations → jax.vmap the network over observations |

ratio_dist = NeuralDistribution(
    approximator=ratio_approximator,
    param_names=["mu", "sigma"],
    exchangeable=True
)

12.4.1. Log-Ratio Landscape#

Fix \(x_{obs} = 2.0\) and sweep \((\mu, \sigma)\). The learned log-ratio should match the analytic log-likelihood up to a constant offset — the log model evidence \(\log p(x_{obs})\), which is independent of \(\theta\).

x_obs_val = 2.0
mu_grid = np.linspace(-4.5, 4.5, 80)
sigma_grid = np.linspace(0.6, 2.9, 60)
MU, SIGMA = np.meshgrid(mu_grid, sigma_grid)

mu_flat   = jnp.array(MU.flatten())
sigma_flat = jnp.array(SIGMA.flatten())
x_flat    = jnp.full_like(mu_flat, x_obs_val)

vmap_fn_all = jax.vmap(ratio_dist.backend.log_prob, in_axes=(0, 0, 0))
LR = np.asarray(vmap_fn_all(x_flat, mu_flat, sigma_flat)).reshape(MU.shape)
LL = sp_norm.logpdf(x_obs_val, loc=MU, scale=SIGMA)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))
for ax, Z, title, cmap in zip(
    axes,
    [LR, LL, LL - LR],
    ["Log Ratio (learned)", "Analytic Log Likelihood", r"LL $-$ LR $\approx$ const"],
    ["viridis", "viridis", "RdBu_r"],
):
    im = ax.pcolormesh(MU, SIGMA, Z, shading="auto", cmap=cmap)
    ax.set_xlabel(r"$\mu$"); ax.set_ylabel(r"$\sigma$")
    ax.set_title(f"{title}\n($x_{{obs}}={x_obs_val}$)")
    plt.colorbar(im, ax=ax)
    
plt.tight_layout()
../_images/f74363d22d868928be6aee43a702ae87641c94bdf2b7f781a3a09b82b5c2bdd9.png
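
As an optional numerical check (reusing the LL and LR arrays computed above), the spread of the offset gives a quick sense of how well the ratio is learned:

offset = LL - LR
# If the ratio is well learned, the offset is (nearly) constant across the grid
# and approximates the log evidence log p(x_obs) under the training prior.
print(f"offset mean: {offset.mean():.3f}, offset std: {offset.std():.3f}")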

12.4.2. Scalar Parameter Inference with NUTS#

Simulate \(n = 50\) observations from the true model and recover \((\mu, \sigma)\). Both parameters are scalars shared across all trials.

SEED = 42
RNG = np.random.default_rng(SEED)

mu_true, sigma_true, n_obs = 2.0, 1.5, 50
x_observed = RNG.normal(mu_true, sigma_true, n_obs).astype(np.float32)
with pm.Model() as nre_model:

    mu = pm.TruncatedNormal("mu", mu=0, sigma=5, lower=-5, upper=5)
    sigma = pm.TruncatedNormal("sigma", mu=1.5, sigma=1.5, lower=0.5, upper=3.0)
    obs = ratio_dist("obs", mu=mu, sigma=sigma, observed=x_observed)

    trace_nre = pm.sample(
        draws=1000, 
        tune=1000,
        nuts_sampler="pymc",
        chains=4, 
        cores=1,   # need to use spawn context on Linux to avoid fork()
        random_seed=SEED,
        initvals={"mu": 0.0, "sigma": 1.5},
    )
INFO:pymc.sampling.mcmc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc.sampling.mcmc:Sequential sampling (4 chains in 1 job)
INFO:pymc.sampling.mcmc:NUTS: [mu, sigma]

INFO:pymc.sampling.mcmc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 16 seconds.
ax = az.plot_trace(trace_nre, var_names=["mu", "sigma"])
plt.tight_layout()
../_images/0a0d64570cd37f89ef1d7c8d93e2e1bad3f6b45ffdd2083eb0283b802ecb165f.png
with pm.Model() as true_model:

    mu = pm.TruncatedNormal("mu", mu=0, sigma=5, lower=-5, upper=5)
    sigma = pm.TruncatedNormal("sigma", mu=1.5, sigma=1.5, lower=0.5, upper=3.0)
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=x_observed)

    trace_true = pm.sample(
        draws=1000, 
        tune=1000,
        nuts_sampler="pymc",
        chains=4, 
        cores=1,   # need to use spawn context on Linux to avoid fork()
        random_seed=42,
        initvals={"mu": 0.0, "sigma": 1.5},
    )
INFO:pymc.sampling.mcmc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc.sampling.mcmc:Sequential sampling (4 chains in 1 job)
INFO:pymc.sampling.mcmc:NUTS: [mu, sigma]

INFO:pymc.sampling.mcmc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
fig, axes = plt.subplots(1, 2, figsize=(11, 4))
for ax, name, true_val in zip(axes, ["mu", "sigma"], [mu_true, sigma_true]):

    samples_nre = trace_nre.posterior[name].values.flatten()
    samples_true = trace_true.posterior[name].values.flatten()

    ax.hist(samples_nre, bins=50, density=True, alpha=0.6, label="NUTS Posterior (NRE)", color='darkblue')
    ax.hist(samples_true, bins=50, density=True, alpha=0.9, label="NUTS Posterior (True)", color='gray')

    ax.axvline(true_val, color="black", ls="--", label="True")
    ax.set_xlabel(rf"${name}$"); ax.set_title(rf"Posterior for ${name}$")

    ax.legend(fontsize=9)

fig.tight_layout()
../_images/044f6f1fc3babfcd7103e580cc72d2298f5a3662bc75d84c59c8b3d1d0c3fbf0.png

12.5. Part 2: Regression - Per-Trial Parameters#

The same trained ratio estimator can be reused without any retraining when a parameter is regressed on per-trial covariates. For example, when \(\mu_i = \beta_0 + \beta_1 \cdot c_i\) varies per trial, we can pass a per-trial vector mu_reg to ratio_dist. The wrapper vmaps over trials, so each trial's log-ratio uses the matching \(\mu_i\).

12.5.1. Data-Generating Process#

\[\begin{split} \begin{align} \beta_0 &\sim \text{Uniform}(-4, 4)\\ \beta_1 &\sim \text{Uniform}(-4, 4)\\ \sigma &\sim \text{Uniform}(0.5, 3.0)\\ \mu_i &= \beta_0 + \beta_1 \cdot c_i, \quad c_i \in \{0, 1\}\\ x_i \mid \mu_i, \sigma &\sim \mathcal{N}(\mu_i, \sigma^2) \end{align} \end{split}\]

beta_0_true, beta_1_true, sigma_reg_true = 1.0, 2.0, 1.0
n_obs_reg = 100

condition_data = RNG.choice([0, 1], size=n_obs_reg).astype(np.float32)
mu_per_trial   = beta_0_true + beta_1_true * condition_data
x_reg_observed = RNG.normal(mu_per_trial, sigma_reg_true).astype(np.float32)
with pm.Model() as regression_model:
    beta_0 = pm.TruncatedNormal("beta_0", mu=0, sigma=3, lower=-4, upper=4)
    beta_1 = pm.TruncatedNormal("beta_1", mu=0, sigma=3, lower=-4, upper=4)
    sigma  = pm.TruncatedNormal("sigma",  mu=1.5, sigma=1.5, lower=0.5, upper=3.0)

    # Per-trial mu — no model changes needed, just pass the vector
    mu_reg = beta_0 + beta_1 * condition_data

    obs = ratio_dist("obs", mu=mu_reg, sigma=sigma, observed=x_reg_observed)

    trace_reg = pm.sample(
        draws=1000, 
        tune=1000,
        nuts_sampler="pymc",
        chains=4, 
        cores=1,
        random_seed=SEED,
        initvals={"beta_0": 0.0, "beta_1": 0.0, "sigma": 1.5},
    )
INFO:pymc.sampling.mcmc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc.sampling.mcmc:Sequential sampling (4 chains in 1 job)
INFO:pymc.sampling.mcmc:NUTS: [beta_0, beta_1, sigma]

INFO:pymc.sampling.mcmc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 24 seconds.
ax = az.plot_trace(trace_reg, var_names=["beta_0", "beta_1", "sigma"])
plt.tight_layout()
../_images/2487195ef68893f8ef4cbf1d7bfeb754c41e367a590e47186b4ec54a70428f82.png

Now, let’s fit the same model with the true Gaussian likelihood.

with pm.Model() as regression_model_true:

    beta_0 = pm.TruncatedNormal("beta_0", mu=0, sigma=3, lower=-4, upper=4)
    beta_1 = pm.TruncatedNormal("beta_1", mu=0, sigma=3, lower=-4, upper=4)
    sigma  = pm.TruncatedNormal("sigma",  mu=1.5, sigma=1.5, lower=0.5, upper=3.0)

    mu_reg = beta_0 + beta_1 * condition_data

    obs = pm.Normal("obs", mu=mu_reg, sigma=sigma, observed=x_reg_observed)

    trace_reg_true = pm.sample(
        draws=1000, 
        tune=1000,
        nuts_sampler="pymc",
        chains=4, 
        cores=1,   # need to use spawn context on Linux to avoid fork()
        random_seed=SEED,
        initvals={"beta_0": 0.0, "beta_1": 0.0, "sigma": 1.5},
    )
INFO:pymc.sampling.mcmc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc.sampling.mcmc:Sequential sampling (4 chains in 1 job)
INFO:pymc.sampling.mcmc:NUTS: [beta_0, beta_1, sigma]

INFO:pymc.sampling.mcmc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

for ax, name, true_val in zip(
    axes,
    ["beta_0", "beta_1", "sigma"],
    [beta_0_true, beta_1_true, sigma_reg_true]
):
    samples = trace_reg.posterior[name].values.flatten()
    samples_true = trace_reg_true.posterior[name].values.flatten()

    ax.hist(samples, bins=50, density=True, alpha=0.6, label="NUTS Posterior (NRE)", color="darkblue")
    ax.hist(samples_true, bins=50, density=True, alpha=0.9, label="NUTS Posterior (True)", color="gray")

    ax.axvline(true_val, color="black", ls="--", label=f"True")
    ax.set_title(f"Posterior for {name}")
    ax.set_xlabel(name)
    ax.legend()

fig.tight_layout()
../_images/52013cabe1e84f779add4e18496e5c02187f336ee64c8f565a1f620884b6abbc.png

12.6. Part 3: Drift-Diffusion Model (DDM)#

12.6.1. The Model#

The Drift-Diffusion Model (Ratcliff, 1978) is the canonical model of two-alternative forced-choice (2AFC) decision-making. Evidence accumulates as a Wiener diffusion process with drift \(v\) until it hits one of two absorbing boundaries separated by \(a\). Each trial produces an observable pair:

\[(RT_i,\ \text{choice}_i), \quad \text{choice}_i \in \{-1, +1\}\]

| Parameter | Symbol | Range | Role |
|-----------|--------|-------|------|
| Drift rate | \(v\) | \([-3, 3]\) | Speed and direction of evidence accumulation |
| Boundary separation | \(a\) | \([0.3, 2.5]\) | Speed-accuracy trade-off |
| Starting point | \(z\) | \([0.1, 0.9]\) | Prior bias toward a boundary |
| Non-decision time | \(t\) | \([0, 2]\) | Sensory + motor latency |

The analytic Wiener first-passage time (WFPT) likelihood exists, although it can only be evaluated via numerical approximation; this lets us compare factorized NRE against a gold standard. For more complex model variants no tractable likelihood may exist, and NRE is then among the few viable choices (along with amortized NPE).

12.6.2. Why NRE Works Here#

DDM trials are conditionally i.i.d. given \(\theta = (v, a, z, t)\):

\[ \log p(\theta \mid \text{RT}_{1:n},\,\text{choice}_{1:n}) = \sum_{i=1}^{n} \log r\!\left(\theta,\; RT_i,\, \text{choice}_i\right) + \log p(\theta) + \text{const} \]

We train on single trials and aggregate at inference — identical workflow to Parts 1 & 2, just with a 2-dimensional observation space \((RT, \text{choice})\) instead of a scalar \(x\).

The workflow below assumes that you have installed the following libraries (a typical install command is sketched after the list):

  1. Sequential Sampling Model Simulators (SSMS) - a library for fast simulation of diffusion-like models.

  2. Hierarchical Sequential Sampling Models (HSSM) - a library for Bayesian (hierarchical) modeling with diffusion-like models.
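
Both are available from PyPI; a typical installation (assuming the package names ssm-simulators and hssm — check each project's documentation for current instructions) is:

# Install the single-trial simulator and the HSSM likelihood utilities
# (PyPI package names assumed here; see each project's docs for details)
%pip install ssm-simulators hssm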

from hssm.distribution_utils import make_distribution_for_supported_model
from ssms.basic_simulators.simulator import simulator as ssm_simulator
from ssms.config import model_config as ssms_model_config
ddm_cfg = ssms_model_config["ddm"]
param_names_ddm = ddm_cfg["params"]
param_lower = np.array(ddm_cfg["param_bounds"][0])
param_upper = np.array(ddm_cfg["param_bounds"][1])

print("DDM parameter bounds (from ssm-simulators config):")
for name, lo, hi in zip(param_names_ddm, param_lower, param_upper):
    print(f"  {name}:  [{lo:.2f}, {hi:.2f}]")
DDM parameter bounds (from ssm-simulators config):
  v:  [-3.00, 3.00]
  a:  [0.30, 2.50]
  z:  [0.10, 0.90]
  t:  [0.00, 2.00]

12.6.3. Simulator: One Trial per Call#

The likelihood function returns exactly one \((RT, \text{choice})\) pair. This is the unit the ratio estimator learns on; the wrapper converts it into a BayesFlow-friendly generator.

def ddm_prior():
    return {
        name: np.random.uniform(lo, hi)
        for name, lo, hi in zip(param_names_ddm, param_lower, param_upper)
    }

def ddm_likelihood(v, a, z, t):
    """Simulate ONE (RT, choice) trial — the NRE building block."""
    result = ssm_simulator(
        theta={"v": v, "a": a, "z": z, "t": t},
        model="ddm",
        n_samples=1,
        delta_t=0.001,
    )
    obs = np.array([result["rts"][0, 0], result["choices"][0, 0]], dtype=np.float32)
    return {"obs": obs}

ddm_simulator = bf.make_simulator([ddm_prior, ddm_likelihood])

# Sanity check: one batch of 3
batch = ddm_simulator.sample(3)
print("Keys:", list(batch.keys()))
print("obs shape:", batch["obs"].shape)
Keys: ['v', 'a', 'z', 't', 'obs']
obs shape: (3, 2)

12.6.4. Train the DDM Ratio Estimator#

Each observation is a 2-vector \((RT_i, \text{choice}_i)\), so the observation part of the network input is two-dimensional rather than the scalar \(x\) of the normal model (and the parameter vector now has four components). Everything else is identical.

adapter_ddm = bf.RatioApproximator.build_adapter(
    inference_variables=["v", "a", "z", "t"],
    inference_conditions=["obs"]
)

ratio_approximator_ddm = bf.RatioApproximator(
    adapter=adapter_ddm,
    inference_network=bf.networks.MLP(widths=(256,)*3),
    K=32,
    standardize=None
)

learning_rate = keras.optimizers.schedules.CosineDecay(5e-4, decay_steps=45_000)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
ratio_approximator_ddm.compile(optimizer=optimizer)

Training will take around 15 minutes on a GPU.

history_ddm = ratio_approximator_ddm.fit(
    simulator=ddm_simulator,
    epochs=300,
    num_batches=150,
    workers=4,
    verbose=2,
    batch_size=64
)
f = bf.diagnostics.plots.loss(history_ddm)
../_images/3419ec2df5a32713b956e5f66a9af2396da4fa90eefb4c6655f35d585d030791.png

12.6.5. NUTS Sampling for the DDM#

We generate \(n = 200\) (RT, choice) trials from known ground-truth parameters and try to recover them with NUTS.

v_true, a_true, z_true, t_true = 0.5, 1.5, 0.5, 0.3
n_obs_ddm = 200

result = ssm_simulator(
    theta={"v": v_true, "a": a_true, "z": z_true, "t": t_true},
    model="ddm", 
    n_samples=n_obs_ddm, 
    random_state=SEED, 
    delta_t=0.001
)

x_observed_ddm = np.c_[result["rts"], result["choices"]].squeeze().astype(np.float32)
ratio_dist_ddm = NeuralDistribution(
    approximator=ratio_approximator_ddm,
    param_names=["v", "a", "z", "t"],
    exchangeable=True
)
with pm.Model() as ddm_model:

    v = pm.TruncatedNormal("v", mu=0,   sigma=1.5, lower=-3.0, upper=3.0)
    a = pm.TruncatedNormal("a", mu=1.0, sigma=0.5, lower=0.3,  upper=2.5)
    z = pm.TruncatedNormal("z", mu=0.5, sigma=0.2, lower=0.1,  upper=0.9)
    t = pm.TruncatedNormal("t", mu=0.3, sigma=0.3, lower=0.0,  upper=2.0)

    obs = ratio_dist_ddm("obs", v=v, a=a, z=z, t=t, observed=x_observed_ddm)

    trace_ddm = pm.sample(
        draws=1000, 
        tune=1000,
        nuts_sampler="pymc",
        chains=4, 
        cores=1,  # need a spawn context on Linux to use cores=4 (the default fork() will die)
        random_seed=SEED,
        initvals={"v": 0.0, "a": 1.0, "z": 0.5, "t": 0.3}
    )
INFO:pymc.sampling.mcmc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc.sampling.mcmc:Sequential sampling (4 chains in 1 job)
INFO:pymc.sampling.mcmc:NUTS: [v, a, z, t]

INFO:pymc.sampling.mcmc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 42 seconds.
az.plot_trace(trace_ddm, var_names=["v", "a", "z", "t"])
plt.tight_layout()
../_images/ae9dc0b707555785673dc92262819c0dd845b170aaeeac8b91ea602fc0f6663e.png

12.6.6. Comparison with the Analytic WFPT Likelihood#

HSSM provides the exact Wiener first-passage time density as a reference. We fit the same observed data with identical priors to see how closely the NRE posterior matches the gold standard.

AnalyticalDDM = make_distribution_for_supported_model(
    "ddm", loglik_kind="analytical", backend="pytensor",
)

with pm.Model() as analytical_model:
    v_a = pm.TruncatedNormal("v", mu=0,   sigma=1.5, lower=-3.0, upper=3.0)
    a_a = pm.TruncatedNormal("a", mu=1.0, sigma=0.5, lower=0.3,  upper=2.5)
    z_a = pm.TruncatedNormal("z", mu=0.5, sigma=0.2, lower=0.1,  upper=0.9)
    t_a = pm.TruncatedNormal("t", mu=0.3, sigma=0.3, lower=0.0,  upper=2.0)

    obs_a = AnalyticalDDM("obs", v=v_a, a=a_a, z=z_a, t=t_a, observed=x_observed_ddm)

    trace_analytical = pm.sample(
        draws=1000, 
        tune=1000,
        nuts_sampler="pymc",
        chains=4, 
        cores=1,
        random_seed=SEED,
        initvals={"v": 0.0, "a": 1.0, "z": 0.5, "t": 0.3},
    )
INFO:pymc.sampling.mcmc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc.sampling.mcmc:Sequential sampling (4 chains in 1 job)
INFO:pymc.sampling.mcmc:NUTS: [v, a, z, t]

INFO:pymc.sampling.mcmc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 13 seconds.
ERROR:pymc.stats.convergence:There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
true_vals = {"v": v_true, "a": a_true, "z": z_true, "t": t_true}

fig, axes = plt.subplots(1, 4, figsize=(18, 4))

for ax, name in zip(axes, ["v", "a", "z", "t"]):

    bf_samples = trace_ddm.posterior[name].values.flatten()
    an_samples = trace_analytical.posterior[name].values.flatten()

    ax.hist(bf_samples, bins=50, density=True, alpha=0.6, color="darkblue", label="NUTS (NRE)", edgecolor="none")
    ax.hist(an_samples, bins=50, density=True, alpha=0.9, color="gray",     label="NUTS (WFPT)",  edgecolor="none")

    ax.axvline(true_vals[name], color="black", ls="--", lw=2, label=f"True = {true_vals[name]}")
    ax.set_xlabel(f"${name}$")
    ax.set_title(f"Posterior of ${name}$")
    ax.legend(fontsize=7)

fig.suptitle("NRE vs. WFPT Posteriors", fontsize=13, y=1.02)
fig.tight_layout()
../_images/72bc54e113eb5edd8c55ac2b2d0e810c4ffad452eca754ae2d39931471a5ce3a.png
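
For a quantitative complement to the histograms, a minimal sketch reusing trace_ddm and trace_analytical compares the posterior summaries:

# Side-by-side posterior summaries: NRE vs. analytic WFPT likelihood
summary_nre = az.summary(trace_ddm, var_names=["v", "a", "z", "t"])
summary_wfpt = az.summary(trace_analytical, var_names=["v", "a", "z", "t"])
print(summary_nre[["mean", "sd", "hdi_3%", "hdi_97%"]])
print(summary_wfpt[["mean", "sd", "hdi_3%", "hdi_97%"]])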