8. Hierarchical Model Comparison for Cognitive Models

Part 2: Hierarchical Model Comparison

by Lasse Elsemüller

from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from scipy import stats

import bayesflow as bf

8.1. Introduction

This is the second part of the tutorial series on amortized model comparison with BayesFlow! Part 1 introduced the general workflow, the scenario, and the cognitive models, and we assume you are familiar with them. Here, we focus on what changes when we compare hierarchical models.

In Part 1, we only conducted model comparison for a single participant at a time. Let us now consider all participants and their nested observations simultaneously in our model comparison!

8.2. Generative Model Definition

To extend our MPT models to hierarchical ones, we need to introduce a superordinate level that encodes our assumptions about the relationships between individuals. We adopt a popular hierarchical MPT framework, the latent-trait approach by Klauer (2010). Here, we replace our non-hierarchical Beta priors with a multivariate normal distribution on the probit scale, which allows us to model correlations between our parameters. We then use the cumulative distribution function of the standard normal distribution, \(\Phi\), to transform the parameters from the real line to probabilities. Let's write out our new model components explicitly, with \(m = 1, \dots, M\) indexing the participants:

\[\begin{split} \begin{align} \left[ \begin{array}{l} d_m' \\ g_m' \end{array} \right] &\sim \mathcal{N} \left( \left[\begin{array}{l} \mu_{d} \\ \mu_{g} \end{array} \right], \Sigma \right) \text{ for } m=1,\dots,M\\ d_m &= \Phi(d_m') \text{ for } m=1,\dots,M\\ g_m &= \Phi(g_m') \text{ for } m=1,\dots,M\\ \end{align} \end{split}\]

8.2.1. Hyperpriors and Priors

We now have to define hyperpriors for the parameters of the multivariate normal prior distribution. For the covariance matrix \(\Sigma\), the latent-trait approach employs a scaled inverse Wishart distribution. The matrix \(Q\) determines the correlation between the probit-scale parameters \(d'\) and \(g'\). The variances are determined jointly by \(Q\) and the scaling parameters \(\lambda_p\).

\[\begin{split} \begin{align} \mu_{d} &\sim \mathcal{N}(0, 0.25) \\ \mu_{g} &\sim \mathcal{N}(0, 0.25) \\ Q &\sim InvWishart(\mathbb{I}, 10)\\ \lambda_p &\sim \textrm{Uniform}(0, 3) \text{ for } p= d', g'\\ \Sigma &= \textrm{Diag}(\lambda_p) Q \textrm{Diag}(\lambda_p)\\ \end{align} \end{split}\]
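Here is a quick numerical check of the claim that \(Q\) alone sets the correlation. Since \(\Sigma_{ij} = \lambda_i \lambda_j Q_{ij}\) with \(\lambda_p > 0\), the scaling factors cancel when \(\Sigma\) is converted to a correlation matrix, and they only rescale the variances. This sketch uses a fixed, hand-picked \(Q\) for illustration, whereas the model draws \(Q\) from the inverse Wishart. The helper covariance_to_correlation is hypothetical.

```python
import numpy as np


def covariance_to_correlation(cov):
    """Hypothetical helper: convert a covariance matrix into a correlation matrix."""
    sd = np.sqrt(np.diag(cov))
    return cov / np.outer(sd, sd)


rng = np.random.default_rng(1)

# Hand-picked symmetric positive-definite Q (stands in for an inverse Wishart draw)
Q = np.array([[1.3, 0.4], [0.4, 0.8]])
lambdas = rng.uniform(0, 3, size=2)

# Scaled covariance: Sigma = Diag(lambda) Q Diag(lambda)
sigma = np.diag(lambdas) @ Q @ np.diag(lambdas)

same_corr = np.allclose(covariance_to_correlation(Q), covariance_to_correlation(sigma))
scaled_vars = np.allclose(np.diag(sigma), lambdas**2 * np.diag(Q))
print(same_corr, scaled_vars)  # True True
```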

Here, we choose our priors to reflect our belief that the hierarchical models should generate data patterns similar to those of their non-hierarchical counterparts. Remember that Bayesian model comparison penalizes predictive flexibility and expects you to encode your theoretical assumptions in all parts of your models. The marginal likelihood averages the likelihood over the entire prior, so a model whose prior spreads its mass over many implausible parameter values is penalized for it. Therefore, flat or very weakly informative priors can be harmless in parameter estimation but can strongly distort the results of a model comparison.
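A toy example shows why diffuse priors are penalized. It is not part of the MPT models: it uses a single binomial probability with a Beta prior, for which the marginal likelihood has the closed form \(\binom{n}{k} B(k + a, n - k + b) / B(a, b)\). This sketch compares a flat Beta(1, 1) prior with a Beta(10, 10) prior concentrated around 0.5. The helper names are hypothetical.

```python
from math import exp, lgamma


def log_beta(a, b):
    """Log of the Beta function B(a, b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def log_marginal_likelihood(k, n, a, b):
    """Log marginal likelihood of k successes in n trials under a Beta(a, b) prior."""
    log_binom = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
    return log_binom + log_beta(k + a, n - k + b) - log_beta(a, b)


n = 100
results = {}
for k in (50, 90):  # data matching vs. contradicting the informative prior
    flat = exp(log_marginal_likelihood(k, n, 1, 1))
    informative = exp(log_marginal_likelihood(k, n, 10, 10))
    results[k] = (flat, informative)
    print(f"k={k}: flat prior {flat:.4f}, Beta(10, 10) prior {informative:.4f}")
```

Under the flat prior, every outcome receives the same marginal likelihood \(1/(n+1)\), because that prior spreads its predictions evenly. The informative prior gains evidence when the data match its predictions (k = 50) and loses evidence when they do not (k = 90).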

Things can get a little confusing for these hierarchical model formulations, so let’s have a look at the role of our prior choices:

  • \(\mu_d\) and \(\mu_g\): A zero-centered normal distribution on the probit scale (with standard deviation 0.25, matching the code below) translates to participant mean values centered around 0.5 on the probability scale.

  • \(Q\): An inverse Wishart distribution with an identity scale matrix centers the expected correlations between \(d\) and \(g\) at 0, while the 10 degrees of freedom encode our belief that high correlations are rather unlikely (see below for a visualization).

  • \(\lambda\): Restricting this auxiliary scaling parameter to the interval (0, 3) limits the amount of variability that we introduce into our models.
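A short calculation checks the first bullet. We push the central 95% interval of \(\mathcal{N}(0, 0.25)\) (standard deviation 0.25, as in the code below) through \(\Phi\). The result shows which group-level mean probabilities the hyperprior considers plausible. This sketch implements \(\Phi\) with math.erf to stay dependency-free. It ignores the additional participant-level spread introduced by \(\Sigma\).

```python
import math


def phi(x):
    """Standard normal CDF, written with the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


mu_sd = 0.25      # standard deviation of the hyperprior on mu_d and mu_g
z_975 = 1.959964  # 97.5% quantile of the standard normal

lower, upper = phi(-z_975 * mu_sd), phi(z_975 * mu_sd)
print(round(phi(0.0), 3), round(lower, 3), round(upper, 3))  # 0.5 0.312 0.688
```

So about 95% of the prior mass for \(\Phi(\mu_d)\) and \(\Phi(\mu_g)\) lies between roughly 0.31 and 0.69.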

Q = stats.invwishart.rvs(df=10, scale=np.identity(2), size=5000)
corrs = Q[:, 0, 1] / (np.sqrt(Q[:, 0, 0] * Q[:, 1, 1]))
f, ax = plt.subplots(1, 1, figsize=(6, 4))
sns.histplot(corrs, kde=True, color="#8f2727", alpha=0.9, ax=ax)
ax.set_title("Inverse Wishart Correlation Prior")
sns.despine(ax=ax)
[Figure: Inverse Wishart Correlation Prior]

We now follow the same steps as in Part 1 to translate our prior into code:

PARAM_NAMES = [r"$\mu_d$", r"$\mu_g$", r"$\Sigma_{00}$", r"$\Sigma_{01}$", r"$\Sigma_{10}$", r"$\Sigma_{11}$"]
RNG = np.random.default_rng(2023)
def hierarchical_prior_fun(rng=None):
    """Samples a random parameter configuration from the hierarchical prior distribution."""

    if rng is None:
        rng = np.random.default_rng()

    # Group-level means on the probit scale (0.25 is the standard deviation)
    mu_d = rng.normal(0, 0.25)
    mu_g = rng.normal(0, 0.25)

    # Scaled inverse Wishart: Q sets the correlation, lambdas scale the variances
    Q = stats.invwishart.rvs(df=10, scale=np.identity(2), random_state=rng)
    lambdas = rng.uniform(0, 3, size=2)
    sigma = np.diag(lambdas) @ Q @ np.diag(lambdas)

    # Flat vector: [mu_d, mu_g, Sigma_00, Sigma_01, Sigma_10, Sigma_11]
    return np.concatenate([np.r_[mu_d, mu_g], sigma.flatten()])
prior = bf.simulation.Prior(prior_fun=hierarchical_prior_fun, param_names=PARAM_NAMES)
prior(batch_size=1)
{'prior_draws': array([[-0.27184395,  0.04654362,  0.63897366,  0.08509082,  0.08509082,
          0.29657829]]),
 'batchable_context': None,
 'non_batchable_context': None}

8.2.2. Creating the Simulators

At this point, it is important to stress again our new definition of a data set. In Part 1, we analyzed participants separately, so each data set contained a single participant. With hierarchical models, we can now take all of our experimental data into consideration simultaneously. Each data set therefore contains several participants, each with their own nested observations. Other hierarchical structures are possible, such as students nested within classes or employees nested within organizations, so we refer to the higher-order units by the general term groups. For this tutorial, we consider a scenario with 50 participants performing 100 trials each.

We continue with the familiar workflow. We first specify simulator functions, then create our generative models with the GenerativeModel wrapper, and finally combine them with the MultiGenerativeModel wrapper.

N_GROUPS = 50
N_OBS = 100
def hierarchical_mpt_simulator(theta, model, num_groups, num_obs, rng=None, *args):
    """Simulates data from a hierarchical 1HT or 2HT MPT model, assuming equal proportions of old and new stimuli.

    Parameters
    ----------
    theta      : np.ndarray of shape (num_parameters, )
        Contains draws from the prior distribution for each parameter.
    model      : str, either "1HT" or "2HT"
        Decides the model to generate data from.
    num_groups : int
        The number of groups (participants).
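    num_obs    : int
        The number of observations (trials) per group. Old and new stimuli are split
        equally, so an odd num_obs yields num_obs + 1 trials.
    rng        : np.random.Generator or None, optional (default: None)
        The random number generator. A new default generator is created if None.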

    Returns
    -------
    X     : np.ndarray of shape (num_groups, num_obs, 2)
        The generated data set. Contains two columns:
            1. Stimulus type (0="new", 1="old")
            2. Response (0="new", 1="old")
    """

    if rng is None:
        rng = np.random.default_rng()

    obs_per_condition = int(np.ceil(num_obs / 2))

    mu_d, mu_g = theta[:2]
    sigma = np.reshape(theta[2:], (2, 2))

    # Draw vectors containing individual parameters and transform to probabilities
    params = rng.multivariate_normal([mu_d, mu_g], sigma, size=num_groups)
    d = stats.norm.cdf(params[:, 0])
    g = stats.norm.cdf(params[:, 1])

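    # Compute category probabilities per model (p_sr: stimulus s, response r; 1="old", 0="new")
    if model == "1HT":
        p_11 = d + (1 - d) * g
        p_10 = (1 - d) * (1 - g)
        p_01 = g
        p_00 = 1 - g
    elif model == "2HT":
        p_11 = d + (1 - d) * g
        p_10 = (1 - d) * (1 - g)
        p_01 = (1 - d) * g
        p_00 = d + (1 - d) * (1 - g)
    else:
        raise ValueError(f"Unknown model {model!r}; expected '1HT' or '2HT'.")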

    # Assert that category probabilities sum to 1
    assert np.all(np.isclose((p_11 + p_10, p_01 + p_00), 1)), "Category probabilities do not sum to 1!"

    # Create vectors of stimulus types
    stims_single = np.repeat([[1, 0]], repeats=obs_per_condition, axis=1)  # For 1 participant
    stims_data_set = np.repeat(stims_single, repeats=num_groups, axis=0)  # For all participants

    # Simulate responses
    resp_1 = rng.binomial(n=1, p=p_11, size=(obs_per_condition, num_groups)).T
    resp_0 = rng.binomial(n=1, p=p_01, size=(obs_per_condition, num_groups)).T
    resp = np.concatenate((resp_1, resp_0), axis=1)

    # Create final data set
    data = np.stack((stims_data_set, resp), axis=2)

    return data
model_1ht = bf.simulation.GenerativeModel(
    prior=prior,
    simulator=partial(hierarchical_mpt_simulator, model="1HT", num_groups=N_GROUPS, num_obs=N_OBS),
    name="1HT",
    simulator_is_batched=False,
)

model_2ht = bf.simulation.GenerativeModel(
    prior=prior,
    simulator=partial(hierarchical_mpt_simulator, model="2HT", num_groups=N_GROUPS, num_obs=N_OBS),
    name="2HT",
    simulator_is_batched=False,
)
INFO:root:Performing 2 pilot runs with the 1HT model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 6)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 50, 100, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:Performing 2 pilot runs with the 2HT model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 6)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 50, 100, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.

We added the group dimension to our data sets, so we now have our data in a 4-dimensional format with the shape (number of data sets, number of groups/participants, number of observations, number of variables):

model_output = model_1ht(batch_size=5)
print("Shape of data batch:", model_output["sim_data"].shape)
print("First 3 rows of first 2 participants in first data set:")
print(model_output["sim_data"][0, :2, :3, :])
Shape of data batch: (5, 50, 100, 2)
First 3 rows of first 2 participants in first data set:
[[[1 1]
  [1 1]
  [1 1]]

 [[1 1]
  [1 1]
  [1 1]]]
meta_model = bf.simulation.MultiGenerativeModel([model_1ht, model_2ht])

8.2.3. Prior Predictive Checks

The interplay between parameters on several levels makes the behavior of hierarchical models more complex. Prior predictive (or pushforward) checks therefore become even more crucial for inspecting whether the chosen parametrization matches one's expectations. Our simulated data sets now contain N_GROUPS participants with N_OBS observations each. To simulate 1000 participants as in Part 1, we now need only 20 simulations from each model, since each data set contains 50 participants. Afterwards, we calculate the hit rates and false alarm rates for each simulated participant as before and plot them.

# 1. Data simulation from each model
sim_pfcheck_1ht = model_1ht(batch_size=20)
sim_pfcheck_2ht = model_2ht(batch_size=20)
# 2. Summary statistics
def get_rates(sim_data):
    """Get the hit rate and false alarm rate per participant for each data set in a batch
    of hierarchical data sets simulating binary decision (recognition) tasks.
    Assumes first half of data to cover old items and second half to cover new items."""

    obs_per_condition = int(np.ceil(sim_data.shape[-2] / 2))
    hit_rates = np.mean(sim_data[..., :obs_per_condition, 1], axis=2)
    fa_rates = np.mean(sim_data[..., obs_per_condition:, 1], axis=2)

    return hit_rates, fa_rates


rates_1htm = get_rates(sim_pfcheck_1ht["sim_data"])
rates_2htm = get_rates(sim_pfcheck_2ht["sim_data"])
rates = [rates_1htm, rates_2htm]
# 3a. Plot rates across all data sets
fig = plt.figure(constrained_layout=True, figsize=(8, 6))
subfigs = fig.subfigures(nrows=2, ncols=1)
model_names = ["Hierarchical 1HT MPT Model", "Hierarchical 2HT MPT Model"]
num_bins = 20
bins = np.linspace(0.0, 1.0, num_bins + 1)

for row, subfig in enumerate(subfigs):
    subfig.suptitle(model_names[row], fontsize=18)
    axs = subfig.subplots(nrows=1, ncols=2)
    sns.histplot(rates[row][0].flatten(), bins=bins, kde=True, color="#8f2727", alpha=0.9, ax=axs[0]).set(
        title="Hit Rates"
    )
    sns.histplot(rates[row][1].flatten(), bins=bins, kde=True, color="#8f2727", alpha=0.9, ax=axs[1]).set(
        title="False Alarm Rates"
    )
sns.despine()
[Figure: Hit Rates and False Alarm Rates for the Hierarchical 1HT MPT Model (top) and the Hierarchical 2HT MPT Model (bottom)]

If we plot our 1000 participants across all data sets, we see patterns similar to those in Part 1.

8.3. Defining, Training & Validating the Neural Approximator

We adapt the neural network architecture to the new symmetry of hierarchical data by changing only a single part, the summary network. Under this symmetry, participants are exchangeable, and so are the trials within each participant. We now use a HierarchicalNetwork, to which we pass one summary network for each level. The majority of hierarchical models in cognitive modeling assume IID data on all levels, so we simply use one DeepSet network for each level. If we had, for instance, temporal dependencies within each participant, we would replace the first network with one that is specialized for processing time series data, such as a TimeSeriesTransformer.

The first summary network summarizes the information within each participant separately and passes the resulting embeddings to the second summary network. The second network then compresses all participant embeddings into a single vector of learned summary statistics. We therefore equip it with a sufficient number of summary dimensions (summary_dim=64) to avoid a bottleneck.
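A toy, untrained version of hierarchical mean pooling in NumPy illustrates this two-level summary. It is not BayesFlow's HierarchicalNetwork or DeepSet implementation, which use trainable and more expressive layers. The random weights and helper names are hypothetical. The sketch only shows the data flow: trials are pooled within participants, and participants are then pooled into one vector. It also shows that the result does not depend on the order of participants or trials.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random (untrained) weights standing in for learned layers
W_inner, b_inner = rng.normal(size=(2, 16)), rng.normal(size=16)
W_outer, b_outer = rng.normal(size=(16, 8)), rng.normal(size=8)


def dense(x, w, b):
    """A single tanh layer applied along the last axis."""
    return np.tanh(x @ w + b)


def hierarchical_summary(x):
    """Toy two-level summary: (batch, groups, obs, vars) -> (batch, 8)."""
    # Level 1: embed every trial, then mean-pool over trials within each participant
    participant_emb = dense(x, W_inner, b_inner).mean(axis=2)  # (batch, groups, 16)
    # Level 2: embed every participant, then mean-pool over participants
    return dense(participant_emb, W_outer, b_outer).mean(axis=1)  # (batch, 8)


# Random binary data with the simulator's layout: 3 data sets of shape 50 x 100 x 2
x = rng.integers(0, 2, size=(3, 50, 100, 2)).astype(float)
s = hierarchical_summary(x)

# Shuffle participants and trials; the summary should not change
x_perm = x[:, rng.permutation(50)][:, :, rng.permutation(100)]
print(s.shape, np.allclose(s, hierarchical_summary(x_perm)))  # (3, 8) True
```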

After these adjustments, all subsequent elements of the training and validation process stay the same as in Part 1.

summary_net = bf.summary_networks.HierarchicalNetwork([
    bf.networks.DeepSet(), 
    bf.networks.DeepSet(summary_dim=64)
])
inference_net = bf.inference_networks.PMPNetwork(num_models=2)
amortizer = bf.amortizers.AmortizedModelComparison(inference_net, summary_net)
trainer = bf.trainers.Trainer(amortizer=amortizer, generative_model=meta_model)
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.

Note: Online learning will be slower than in Part 1 because data generation is slower, since each data set now contains 50 participants instead of one. It is still worth it thanks to amortization. :-) In practice, you would train for longer than just \(5\) epochs and may want to try offline learning as well.
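The arithmetic behind the train_online call below gives a sense of the simulation cost. It uses epochs=5, iterations_per_epoch=100, and batch_size=64, together with N_GROUPS=50 and N_OBS=100 from above. In online learning, each iteration simulates a fresh batch of complete data sets.

```python
epochs, iterations_per_epoch, batch_size = 5, 100, 64
n_groups, n_obs = 50, 100

data_sets = epochs * iterations_per_epoch * batch_size  # simulated data sets in total
participants = data_sets * n_groups                     # simulated participants
trials = participants * n_obs                           # simulated trials

print(data_sets, participants, trials)  # 32000 1600000 160000000
```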

losses = trainer.train_online(epochs=5, iterations_per_epoch=100, batch_size=64)
diag_plot = bf.diagnostics.plot_losses(train_losses=losses, moving_average=True, ma_window_fraction=0.05)
[Figure: training losses with moving average]
# Generate some validation data in a list to avoid memory troubles during evaluation
sim_data = [trainer.configurator(meta_model(50)) for _ in range(20)]

# Get true indices and predicted PMPs from the trained network
sim_indices = np.concatenate([s["model_indices"] for s in sim_data])

# Estimate model probs in a loop
model_probs = np.concatenate([amortizer.posterior_probs(s) for s in sim_data])
cal_curves = bf.diagnostics.plot_calibration_curves(sim_indices, model_probs)
[Figure: calibration curves]
fig = bf.diagnostics.plot_confusion_matrix(sim_indices, model_probs)
[Figure: confusion matrix]

Our neural network quickly learned to discriminate between the two hierarchical models and shows excellent performance when validated on simulated data. The calibration curves look a bit shaky, but the marginal bin histograms reveal why. Most predicted probabilities lie close to 0 or 1, which leaves the middle ('uncertain') bins sparsely populated. The curve estimates in those bins therefore rest on very few data sets.
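A small sketch of how a binned calibration curve is computed reproduces the effect of sparse middle bins. The helper calibration_bins and the simulated predictions are hypothetical, and this is not BayesFlow's plot_calibration_curves implementation. Predictions drawn from a U-shaped Beta(0.2, 0.2) distribution mimic a confident network, and the labels are drawn so that the predictions are calibrated by construction.

```python
import numpy as np


def calibration_bins(is_model, probs, num_bins=10):
    """Bin predicted probabilities for one model and return bin centers,
    observed frequencies of that model, and the number of data sets per bin."""
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    # Bin index per prediction; probabilities of exactly 1.0 go into the last bin
    idx = np.clip(np.digitize(probs, edges) - 1, 0, num_bins - 1)
    counts = np.bincount(idx, minlength=num_bins)
    hits = np.bincount(idx, weights=is_model, minlength=num_bins)
    with np.errstate(invalid="ignore", divide="ignore"):
        freqs = hits / counts  # NaN for empty bins
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, freqs, counts


rng = np.random.default_rng(42)
# Hypothetical confident but calibrated predictions: most probabilities near 0 or 1
probs = rng.beta(0.2, 0.2, size=2000)
is_model = rng.binomial(1, probs).astype(float)

centers, freqs, counts = calibration_bins(is_model, probs)
print(counts)
```

These predictions are calibrated by construction, yet each middle bin holds only a small fraction of the data sets. The observed frequency in those bins is therefore a noisy estimate, which makes the curve look shaky there.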

8.4. Network Application

As in Part 1, we apply our trained networks to a synthetic data set from the 2HT model. We again redefine our simulator with fixed random seeds to obtain reproducible results:

prior_fixed = bf.simulation.Prior(
    prior_fun=partial(hierarchical_prior_fun, rng=np.random.default_rng(2023)), param_names=PARAM_NAMES
)
fake_data_generator = bf.simulation.GenerativeModel(
    prior=prior_fixed,
    simulator=partial(
        hierarchical_mpt_simulator, model="2HT", num_groups=N_GROUPS, num_obs=N_OBS, rng=np.random.default_rng(2023)
    ),
    skip_test=True,
    simulator_is_batched=False,
)

fake_data = fake_data_generator(batch_size=1)["sim_data"]
print(fake_data.shape)
(1, 50, 100, 2)

We can inspect our simulated data set by looking at hit and false alarm rates for each of our 50 participants. This is best done visually:

rates = get_rates(fake_data)
f, ax = plt.subplots(1, 2, figsize=(8, 3))
sns.histplot(rates[0].flatten(), bins=bins, kde=True, color="#8f2727", alpha=0.9, ax=ax[0]).set(title="Hit Rates")
sns.histplot(rates[1].flatten(), bins=bins, kde=True, color="#8f2727", alpha=0.9, ax=ax[1]).set(
    title="False Alarm Rates"
)
sns.despine()
[Figure: Hit Rates and False Alarm Rates of the 50 participants in the synthetic data set]

As in Part 1, our data set contains many participants with low false alarm rates, which are unlikely under the 1HT model. Let's see how much evidence our simulated data contain:
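Why are low false alarm rates more plausible under the 2HT model? Both simulators share the same hit probability \(d + (1 - d) g\). The false alarm probability, however, is \(g\) under the 1HT model and \((1 - d) g\) under the 2HT model. For the same \(d\) and \(g\), the 2HT model therefore never predicts more false alarms. This sketch makes the point concrete with a hypothetical helper, expected_rates. As a simplification of the hierarchical prior, it draws \(d\) and \(g\) from independent probit-normal distributions with standard deviation 1.

```python
import math

import numpy as np


def expected_rates(d, g, model):
    """Expected hit and false alarm rates of the 1HT or 2HT model."""
    hit = d + (1 - d) * g
    if model == "1HT":
        fa = g
    elif model == "2HT":
        fa = (1 - d) * g
    else:
        raise ValueError(f"Unknown model {model!r}")
    return hit, fa


# Standard normal CDF applied elementwise
phi = np.vectorize(lambda x: 0.5 * (1.0 + math.erf(x / math.sqrt(2.0))))

rng = np.random.default_rng(2023)
d = phi(rng.normal(0.0, 1.0, size=10_000))
g = phi(rng.normal(0.0, 1.0, size=10_000))

hit_1ht, fa_1ht = expected_rates(d, g, "1HT")
hit_2ht, fa_2ht = expected_rates(d, g, "2HT")

print("Share of FA rates below 0.1:", np.mean(fa_1ht < 0.1), np.mean(fa_2ht < 0.1))
```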

embeddings = summary_net(fake_data)
preds = inference_net.posterior_probs(embeddings)[0]
preds
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.00586275, 0.9941373 ], dtype=float32)>
# With equal prior model probabilities, the posterior odds equal the Bayes factor
bayes_factor12 = preds[0] / preds[1]
bayes_factor12
<tf.Tensor: shape=(), dtype=float32, numpy=0.005897322>
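The ratio of posterior model probabilities equals the Bayes factor only when both models have equal prior probabilities. More generally, the Bayes factor is the posterior odds divided by the prior odds, \(BF_{12} = \frac{p(M_1 \mid x)}{p(M_2 \mid x)} \Big/ \frac{p(M_1)}{p(M_2)}\). This minimal sketch works with plain floats and plugs in the probabilities printed above. The helper name bayes_factor is hypothetical.

```python
def bayes_factor(post_1, post_2, prior_1=0.5, prior_2=0.5):
    """Bayes factor BF_12 = posterior odds / prior odds."""
    return (post_1 / post_2) / (prior_1 / prior_2)


# Posterior model probabilities from the output above (1HT, 2HT)
pmp_1ht, pmp_2ht = 0.00586275, 0.9941373

bf_12 = bayes_factor(pmp_1ht, pmp_2ht)
bf_21 = 1.0 / bf_12  # evidence in favor of the 2HT model
print(f"BF_12 = {bf_12:.6f}, BF_21 = {bf_21:.1f}")
```

If the models had unequal prior probabilities during training, the posterior odds would need to be corrected in this way before being read as a Bayes factor.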

Our Bayesian model comparison reveals clear evidence for the 2HT model, much more decisive than in Part 1: the reciprocal of bayes_factor12 gives a Bayes factor of roughly 170 in its favor. The unambiguousness of the results depends, of course, on the model specification and the randomly simulated data. Still, our model comparison now draws on \(50\) times more data by considering all participants simultaneously!