# Source code for bayesflow.diagnostics.plots.z_score_contraction

from collections.abc import Callable, Sequence, Mapping

import matplotlib.pyplot as plt
import numpy as np

from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
from bayesflow.utils.dict_utils import compute_test_quantities


def z_score_contraction(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] | None = None,
    variable_names: Sequence[str] | None = None,
    test_quantities: dict[str, Callable] | None = None,
    figsize: Sequence[int] | None = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    num_col: int | None = None,
    num_row: int | None = None,
    markersize: float | None = None,
) -> plt.Figure:
    """
    Implements a graphical check for global model sensitivity by plotting the
    posterior z-score over the posterior contraction for each set of posterior
    samples in ``estimates`` according to [1].

    - The definition of the posterior z-score is:

    post_z_score = (posterior_mean - true_parameters) / posterior_std

    And the score is adequate if it centers around zero and spreads roughly
    in the interval [-3, 3]

    - The definition of posterior contraction is:

    post_contraction = 1 - (posterior_variance / prior_variance)

    In other words, the posterior contraction is a proxy for the reduction in
    uncertainty gained by replacing the prior with the posterior.
    The ideal posterior contraction tends to 1.
    Contraction near zero indicates that the posterior variance is almost
    identical to the prior variance for the particular marginal parameter
    distribution.

    Note: Means and variances will be estimated via their sample-based estimators.

    [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021).
    Toward a principled Bayesian workflow in cognitive science.
    Psychological methods, 26(1), 103.

    Paper also available at https://arxiv.org/abs/1904.12765

    Parameters
    ----------
    estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
        The posterior draws obtained from num_datasets
    targets : np.ndarray of shape (num_datasets, num_params)
        The prior draws (true parameters) used for generating the num_datasets
    variable_keys : list or None, optional, default: None
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    test_quantities : dict or None, optional, default: None
        A dict that maps plot titles to functions that compute test quantities
        based on estimate/target draws.
        The dict keys are automatically added to ``variable_keys`` and
        ``variable_names``.
        Test quantity functions are expected to accept a dict of draws with
        shape ``(batch_size, ...)`` as the first (typically only) positional
        argument and return a NumPy array of shape ``(batch_size,)``.
        The functions do not have to deal with an additional sample dimension,
        as appropriate reshaping is done internally.
    figsize : tuple or None, optional, default : None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize : int, optional, default: 16
        The font size of the y-label text
    title_fontsize : int, optional, default: 18
        The font size of the title text
    tick_fontsize : int, optional, default: 12
        The font size of the axis ticklabels
    color : str, optional, default: '#132a70'
        The color for the true vs. estimated scatter points and error bars
    num_col : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    markersize : float, optional, default: None
        The marker size in points**2 of the scatter plot.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``estimates``
        and ``targets``.
    """
    # Optionally, compute and prepend test quantities from draws
    if test_quantities is not None:
        updated_data = compute_test_quantities(
            targets=targets,
            estimates=estimates,
            variable_keys=variable_keys,
            variable_names=variable_names,
            test_quantities=test_quantities,
        )
        variable_names = updated_data["variable_names"]
        variable_keys = updated_data["variable_keys"]
        estimates = updated_data["estimates"]
        targets = updated_data["targets"]

    # Gather plot data and metadata into a dictionary
    plot_data = prepare_plot_data(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
    )

    estimates = plot_data.pop("estimates")
    targets = plot_data.pop("targets")

    # Estimate posterior means and stds (sample-based, unbiased via ddof=1)
    post_means = estimates.mean(axis=1)
    post_vars = estimates.var(axis=1, ddof=1)
    post_stds = np.sqrt(post_vars)

    # Estimate prior variance across datasets; keepdims so it broadcasts
    # against the (num_datasets, num_params) posterior variances
    prior_vars = targets.var(axis=0, keepdims=True, ddof=1)

    # Compute contraction and z-score; contraction is clipped to [0, 1]
    # so sampling noise cannot push it outside the interpretable range
    contraction = np.clip(1 - (post_vars / prior_vars), 0, 1)
    z_score = (post_means - targets) / post_stds

    # Loop and plot one scatter panel per variable
    for i, ax in enumerate(plot_data["axes"].flat):
        if i >= plot_data["num_variables"]:
            break
        ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5, s=markersize)
        # Small margin around the theoretical contraction range [0, 1]
        ax.set_xlim([-0.05, 1.05])

    prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

    # Add labels, titles, and set font sizes
    add_titles_and_labels(
        axes=plot_data["axes"],
        num_row=plot_data["num_row"],
        num_col=plot_data["num_col"],
        title=plot_data["variable_names"],
        xlabel="Posterior contraction",
        ylabel="Posterior z-score",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    plot_data["fig"].tight_layout()
    return plot_data["fig"]