Source code for bayesflow.diagnostics.plots.calibration_histogram

from collections.abc import Sequence, Mapping

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from scipy.stats import binom

from bayesflow.utils import logging
from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots


def calibration_histogram(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    figsize: Sequence[float] = None,
    num_bins: int | None = 10,
    binomial_interval: float = 0.99,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    num_col: int = None,
    num_row: int = None,
) -> plt.Figure:
    """Creates and plots publication-ready histograms of rank statistics for
    simulation-based calibration (SBC) checks according to [1].

    Any deviation from uniformity indicates miscalibration and thus poor
    convergence of the networks or poor combination between generative
    model / networks.

    [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018).
    Validating Bayesian inference algorithms with simulation-based calibration.
    arXiv preprint arXiv:1804.06788.

    Parameters
    ----------
    estimates         : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    targets           : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    variable_keys     : list or None, optional, default: None
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names    : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    figsize           : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None
    num_bins          : int or None, optional, default: 10
        The number of bins to use for each marginal histogram. If None, it is
        derived from the ratio of simulations to posterior draws.
    binomial_interval : float in (0, 1), optional, default: 0.99
        The width of the confidence interval for the binomial distribution
    label_fontsize    : int, optional, default: 16
        The font size of the y-label text
    title_fontsize    : int, optional, default: 18
        The font size of the title text
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    color             : str, optional, default: '#132a70'
        The color to use for the histogram body
    num_col           : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row           : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `estimates` and `targets`.
    """

    plot_data = prepare_plot_data(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
    )

    estimates = plot_data.pop("estimates")
    targets = plot_data.pop("targets")

    # Determine the ratio of simulations to posterior draws
    num_sims = estimates.shape[0]
    num_draws = estimates.shape[1]
    ratio = int(num_sims / num_draws)

    # Log a warning if N/B ratio recommended by Talts et al. (2018) < 20
    if ratio < 20:
        logging.warning(
            "The ratio of simulations / posterior draws should be > 20 "
            f"for reliable variance reduction, but your ratio is {ratio}. "
            "Confidence intervals might be unreliable!"
        )

    # Set num_bins automatically, if nothing provided
    if num_bins is None:
        num_bins = int(ratio / 2)
        # With fewer than four simulations per draw the automatic choice
        # degenerates to 0 or 1 bins; 0 would divide by zero below and 1 gives
        # a meaningless plot, so fall back to a small usable number of bins.
        if num_bins < 2:
            num_bins = 4

    # Compute ranks (using broadcasting): for every data set and parameter,
    # count the posterior draws that fall below the corresponding target
    ranks = np.sum(estimates < targets[:, np.newaxis, :], axis=1)

    # Compute confidence interval and mean
    num_trials = int(targets.shape[0])
    # uniform distribution expected -> for all bins: equal probability
    # p = 1 / num_bins that a rank lands in that bin
    endpoints = binom.interval(binomial_interval, num_trials, 1 / num_bins)
    mean = num_trials / num_bins  # corresponds to binom.mean(N, 1 / num_bins)

    num_variables = ranks.shape[1]
    for j, ax in enumerate(plot_data["axes"].flat):
        # The subplot grid may contain more cells than there are variables;
        # indexing ranks beyond its last column would raise an IndexError.
        if j >= num_variables:
            break

        ax.axhspan(endpoints[0], endpoints[1], facecolor="gray", alpha=0.3)
        ax.axhline(mean, color="gray", zorder=0, alpha=0.9)
        sns.histplot(ranks[:, j], kde=False, ax=ax, color=color, bins=num_bins, alpha=0.95)
        ax.get_yaxis().set_ticks([])

    prettify_subplots(plot_data["axes"], tick_fontsize)

    add_titles_and_labels(
        axes=plot_data["axes"],
        num_row=plot_data["num_row"],
        num_col=plot_data["num_col"],
        title=plot_data["variable_names"],
        xlabel="Rank statistic",
        ylabel="",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    plot_data["fig"].tight_layout()
    return plot_data["fig"]