# Source code for bayesflow.diagnostics

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from scipy.stats import binom, median_abs_deviation
from sklearn.metrics import confusion_matrix, r2_score

logging.basicConfig()

from bayesflow.computational_utilities import expected_calibration_error, simultaneous_ecdf_bands
from bayesflow.helper_functions import check_posterior_prior_shapes


def plot_recovery(
    post_samples,
    prior_samples,
    point_agg=np.median,
    uncertainty_agg=median_abs_deviation,
    param_names=None,
    fig_size=None,
    label_fontsize=16,
    title_fontsize=18,
    metric_fontsize=16,
    tick_fontsize=12,
    add_corr=True,
    add_r2=True,
    color="#8f2727",
    n_col=None,
    n_row=None,
    xlabel="Ground truth",
    ylabel="Estimated",
    **kwargs,
):
    """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.

    The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty
    estimate can be controlled with the ``uncertainty_agg`` argument.

    Important: Posterior aggregates play no special role in Bayesian inference and should only
    be used heuristically. For instance, in the case of multi-modal posteriors, common point
    estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.

    Parameters
    ----------
    post_samples    : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples   : np.ndarray of shape (n_data_sets, n_params)
        The prior draws (true parameters) obtained for generating the n_data_sets
    point_agg       : callable, optional, default: ``np.median``
        Function applied to the posterior draws to obtain a point estimate per marginal.
    uncertainty_agg : callable or None, optional, default: scipy.stats.median_abs_deviation
        Function applied to the posterior draws to obtain an uncertainty estimate.
        If ``None`` provided, a simple scatter using only ``point_agg`` will be plotted.
    param_names     : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None.
    fig_size        : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize  : int, optional, default: 16
        The font size of the y-label text
    title_fontsize  : int, optional, default: 18
        The font size of the title text
    metric_fontsize : int, optional, default: 16
        The font size of the goodness-of-fit metric (if provided)
    tick_fontsize   : int, optional, default: 12
        The font size of the axis tick labels
    add_corr        : bool, optional, default: True
        A flag for adding correlation between true and estimates to the plot
    add_r2          : bool, optional, default: True
        A flag for adding R^2 between true and estimates to the plot
    color           : str, optional, default: '#8f2727'
        The color for the true vs. estimated scatter points and error bars
    n_row           : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col           : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    xlabel          : str, optional, default: 'Ground truth'
        The label on the x-axis of the plot
    ylabel          : str, optional, default: 'Estimated'
        The label on the y-axis of the plot
    **kwargs        : optional
        Additional keyword arguments passed to ax.errorbar or ax.scatter.
        Example: `rasterized=True` to reduce PDF file size with many dots

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``post_samples``
        and ``prior_samples``.
    """
    # Sanity check
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Compute point estimates and uncertainties
    est = point_agg(post_samples, axis=1)
    if uncertainty_agg is not None:
        u = uncertainty_agg(post_samples, axis=1)

    # Determine n params and param names if None given
    n_params = prior_samples.shape[-1]
    if param_names is None:
        param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]

    # Determine number of rows and columns for subplots based on inputs
    if n_row is None and n_col is None:
        n_row = int(np.ceil(n_params / 6))
        n_col = int(np.ceil(n_params / n_row))
    elif n_row is None:
        n_row = int(np.ceil(n_params / n_col))
    elif n_col is None:
        n_col = int(np.ceil(n_params / n_row))

    # Initialize figure
    if fig_size is None:
        fig_size = (int(4 * n_col), int(4 * n_row))
    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # Flatten axes to 1D regardless of grid shape.
    # FIX: the previous version indexed ``axarr[n_row - 1, :]`` and ``axarr[:, 0]``,
    # which raises for a single-column grid (n_col == 1, n_row > 1) because
    # ``plt.subplots`` then returns a 1D array; flattening handles every layout.
    axarr_flat = np.atleast_1d(axarr).flatten()

    for i, ax in enumerate(axarr_flat):
        if i >= n_params:
            break

        # Add scatter and error bars
        if uncertainty_agg is not None:
            _ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs)
        else:
            _ = ax.scatter(prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs)

        # Make plots quadratic to avoid visual illusions
        lower = min(prior_samples[:, i].min(), est[:, i].min())
        upper = max(prior_samples[:, i].max(), est[:, i].max())
        eps = (upper - lower) * 0.1
        ax.set_xlim([lower - eps, upper + eps])
        ax.set_ylim([lower - eps, upper + eps])
        ax.plot(
            [ax.get_xlim()[0], ax.get_xlim()[1]],
            [ax.get_ylim()[0], ax.get_ylim()[1]],
            color="black",
            alpha=0.9,
            linestyle="dashed",
        )

        # Add optional metrics and title
        if add_r2:
            r2 = r2_score(prior_samples[:, i], est[:, i])
            ax.text(
                0.1,
                0.9,
                "$R^2$ = {:.3f}".format(r2),
                horizontalalignment="left",
                verticalalignment="center",
                transform=ax.transAxes,
                size=metric_fontsize,
            )
        if add_corr:
            corr = np.corrcoef(prior_samples[:, i], est[:, i])[0, 1]
            ax.text(
                0.1,
                0.8,
                "$r$ = {:.3f}".format(corr),
                horizontalalignment="left",
                verticalalignment="center",
                transform=ax.transAxes,
                size=metric_fontsize,
            )
        ax.set_title(param_names[i], fontsize=title_fontsize)

        # Prettify
        sns.despine(ax=ax)
        ax.grid(alpha=0.5)
        ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    # Only add x-labels to the bottom row
    for _ax in axarr_flat[(n_row - 1) * n_col :]:
        _ax.set_xlabel(xlabel, fontsize=label_fontsize)

    # Only add y-labels to the left-most column
    for _ax in axarr_flat[::n_col]:
        _ax.set_ylabel(ylabel, fontsize=label_fontsize)

    # Remove unused axes entirely
    for _ax in axarr_flat[n_params:]:
        _ax.remove()

    f.tight_layout()
    return f
def plot_z_score_contraction(
    post_samples,
    prior_samples,
    param_names=None,
    fig_size=None,
    label_fontsize=16,
    title_fontsize=18,
    tick_fontsize=12,
    color="#8f2727",
    n_col=None,
    n_row=None,
):
    """Implements a graphical check for global model sensitivity by plotting the
    posterior z-score over the posterior contraction for each set of posterior
    samples in ``post_samples`` according to [1].

    - The definition of the posterior z-score is:

      post_z_score = (posterior_mean - true_parameters) / posterior_std

      The score is adequate if it centers around zero and spreads roughly
      in the interval [-3, 3].

    - The definition of posterior contraction is:

      post_contraction = 1 - (posterior_variance / prior_variance)

      The posterior contraction is a proxy for the reduction in uncertainty gained
      by replacing the prior with the posterior. The ideal posterior contraction
      tends to 1; contraction near zero indicates that the posterior variance is
      almost identical to the prior variance for that marginal.

    Note: Means and variances will be estimated via their sample-based estimators.

    [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021). Toward a principled
    Bayesian workflow in cognitive science. Psychological methods, 26(1), 103.
    Paper also available at https://arxiv.org/abs/1904.12765

    Parameters
    ----------
    post_samples   : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples  : np.ndarray of shape (n_data_sets, n_params)
        The prior draws (true parameters) obtained for generating the n_data_sets
    param_names    : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    fig_size       : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize : int, optional, default: 16
        The font size of the y-label text
    title_fontsize : int, optional, default: 18
        The font size of the title text
    tick_fontsize  : int, optional, default: 12
        The font size of the axis ticklabels
    color          : str, optional, default: '#8f2727'
        The color of the scatter points
    n_row          : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col          : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``post_samples``
        and ``prior_samples``.
    """
    # Sanity check for shape integrity
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Estimate posterior means, stds, and variances
    post_means = post_samples.mean(axis=1)
    post_stds = post_samples.std(axis=1, ddof=1)
    post_vars = post_samples.var(axis=1, ddof=1)

    # Estimate prior variance per marginal
    prior_vars = prior_samples.var(axis=0, keepdims=True, ddof=1)

    # Compute contraction and posterior z-score
    post_cont = 1 - (post_vars / prior_vars)
    z_score = (post_means - prior_samples) / post_stds

    # Determine number of params and param names if None given
    n_params = prior_samples.shape[-1]
    if param_names is None:
        param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]

    # Determine number of rows and columns for subplots based on inputs
    if n_row is None and n_col is None:
        n_row = int(np.ceil(n_params / 6))
        n_col = int(np.ceil(n_params / n_row))
    elif n_row is None:
        n_row = int(np.ceil(n_params / n_col))
    elif n_col is None:
        n_col = int(np.ceil(n_params / n_row))

    # Initialize figure
    if fig_size is None:
        fig_size = (int(4 * n_col), int(4 * n_row))
    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # Flatten axes to 1D regardless of grid shape.
    # FIX: the previous version indexed ``axarr[n_row - 1, :]`` / ``axarr[:, 0]``,
    # which raises for a single-column grid (n_col == 1, n_row > 1) since
    # ``plt.subplots`` then returns a 1D array.
    axarr_flat = np.atleast_1d(axarr).flatten()

    # Loop and plot
    for i, ax in enumerate(axarr_flat):
        if i >= n_params:
            break
        ax.scatter(post_cont[:, i], z_score[:, i], color=color, alpha=0.5)
        ax.set_title(param_names[i], fontsize=title_fontsize)
        sns.despine(ax=ax)
        ax.grid(alpha=0.5)
        ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)
        ax.set_xlim([-0.05, 1.05])

    # Only add x-labels to the bottom row
    for _ax in axarr_flat[(n_row - 1) * n_col :]:
        _ax.set_xlabel("Posterior contraction", fontsize=label_fontsize)

    # Only add y-labels to the left-most column
    for _ax in axarr_flat[::n_col]:
        _ax.set_ylabel("Posterior z-score", fontsize=label_fontsize)

    # Remove unused axes entirely
    for _ax in axarr_flat[n_params:]:
        _ax.remove()

    f.tight_layout()
    return f
def plot_sbc_ecdf(
    post_samples,
    prior_samples,
    difference=False,
    stacked=False,
    fig_size=None,
    param_names=None,
    label_fontsize=16,
    legend_fontsize=14,
    title_fontsize=18,
    tick_fontsize=12,
    rank_ecdf_color="#a34f4f",
    fill_color="grey",
    n_row=None,
    n_col=None,
    **kwargs,
):
    """Creates the empirical CDFs for each marginal rank distribution and plots it
    against a uniform ECDF. ECDF simultaneous bands are drawn using simulations from
    the uniform, as proposed by [1].

    For models with many parameters, use `stacked=True` to obtain an idea of the
    overall calibration of a posterior approximator.

    [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for
    discrete uniformity and its applications in goodness-of-fit evaluation and
    multiple sample comparison. Statistics and Computing, 32(2), 1-21.
    https://arxiv.org/abs/2103.10522

    Parameters
    ----------
    post_samples    : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples   : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    difference      : bool, optional, default: False
        If `True`, plots the ECDF difference. Enables a more dynamic visualization range.
    stacked         : bool, optional, default: False
        If `True`, all ECDFs will be plotted on the same plot. If `False`, each ECDF
        will have its own subplot, similar to the behavior of `plot_sbc_histograms`.
    param_names     : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None.
        Only relevant if `stacked=False`.
    fig_size        : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize  : int, optional, default: 16
        The font size of the x and y-label texts
    legend_fontsize : int, optional, default: 14
        The font size of the legend text
    title_fontsize  : int, optional, default: 18
        The font size of the title text. Only relevant if `stacked=False`
    tick_fontsize   : int, optional, default: 12
        The font size of the axis ticklabels
    rank_ecdf_color : str, optional, default: '#a34f4f'
        The color to use for the rank ECDFs
    fill_color      : str, optional, default: 'grey'
        The color of the fill arguments.
    n_row           : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col           : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    **kwargs        : dict, optional, default: {}
        Keyword arguments controlling the ECDF simultaneous band computation
        through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands`.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation form the expected shapes of `post_samples`
        and `prior_samples`.
    """
    # Sanity checks
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Store reference to number of parameters
    n_params = post_samples.shape[-1]

    # Compute fractional ranks (using broadcasting)
    ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) / post_samples.shape[1]

    # Prepare figure
    if stacked:
        f, stacked_ax = plt.subplots(1, 1, figsize=fig_size)
        axes = [stacked_ax]
    else:
        # Determine number of rows and columns for subplots based on inputs
        if n_row is None and n_col is None:
            n_row = int(np.ceil(n_params / 6))
            n_col = int(np.ceil(n_params / n_row))
        elif n_row is None:
            n_row = int(np.ceil(n_params / n_col))
        elif n_col is None:
            n_col = int(np.ceil(n_params / n_row))

        # Determine fig_size dynamically, if None
        if fig_size is None:
            fig_size = (int(5 * n_col), int(5 * n_row))

        # Initialize figure and flatten axes to 1D.
        # FIX: the previous version indexed ``ax[-1, :]`` / ``ax[:, 0]``, which raises
        # for a single-column grid (n_col == 1, n_row > 1) since ``plt.subplots``
        # then returns a 1D array.
        f, ax = plt.subplots(n_row, n_col, figsize=fig_size)
        axes = np.atleast_1d(ax).flatten()

    # Plot individual ecdf of parameters
    for j in range(ranks.shape[-1]):
        xx = np.sort(ranks[:, j])
        yy = np.arange(1, xx.shape[-1] + 1) / float(xx.shape[-1])

        # Difference, if specified
        if difference:
            yy -= xx

        if stacked:
            # Label only the first curve so the legend holds a single entry
            if j == 0:
                axes[0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
            else:
                axes[0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95)
        else:
            axes[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")

    # Compute uniform ECDF and bands
    alpha, z, L, H = simultaneous_ecdf_bands(post_samples.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))

    # Difference, if specified
    if difference:
        L -= z
        H -= z
        ylab = "ECDF difference"
    else:
        ylab = "ECDF"

    # Titles: one per subplot, or a single None when stacked
    if stacked:
        titles = [None]
    elif param_names is None:
        titles = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
    else:
        titles = param_names

    # Add simultaneous bounds and prettify each populated subplot
    for _ax, title in zip(axes, titles):
        _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
        sns.despine(ax=_ax)
        _ax.grid(alpha=0.35)
        _ax.legend(fontsize=legend_fontsize)
        _ax.set_title(title, fontsize=title_fontsize)
        _ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        _ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    # Only add x-labels to the bottom row
    if stacked:
        bottom_row = axes
    else:
        bottom_row = axes[(n_row - 1) * n_col :]
    for _ax in bottom_row:
        _ax.set_xlabel("Fractional rank statistic", fontsize=label_fontsize)

    # Only add y-labels to the left-most column
    if stacked:
        axes[0].set_ylabel(ylab, fontsize=label_fontsize)
    else:
        for _ax in axes[::n_col]:
            _ax.set_ylabel(ylab, fontsize=label_fontsize)

    # Remove unused axes entirely
    if not stacked:
        for _ax in axes[n_params:]:
            _ax.remove()

    f.tight_layout()
    return f
def plot_sbc_histograms(
    post_samples,
    prior_samples,
    param_names=None,
    fig_size=None,
    num_bins=None,
    binomial_interval=0.99,
    label_fontsize=16,
    title_fontsize=18,
    tick_fontsize=12,
    hist_color="#a34f4f",
    n_row=None,
    n_col=None,
):
    """Creates and plots publication-ready histograms of rank statistics for
    simulation-based calibration (SBC) checks according to [1].

    Any deviation from uniformity indicates miscalibration and thus poor convergence
    of the networks or poor combination between generative model / networks.

    [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018).
    Validating Bayesian inference algorithms with simulation-based calibration.
    arXiv preprint arXiv:1804.06788.

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    fig_size          : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None
    num_bins          : int, optional, default: 10
        The number of bins to use for each marginal histogram
    binomial_interval : float in (0, 1), optional, default: 0.99
        The width of the confidence interval for the binomial distribution
    label_fontsize    : int, optional, default: 16
        The font size of the y-label text
    title_fontsize    : int, optional, default: 18
        The font size of the title text
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    hist_color        : str, optional, default '#a34f4f'
        The color to use for the histogram body
    n_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation form the expected shapes of `post_samples`
        and `prior_samples`.
    """
    # Sanity check
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Determine the ratio of simulations to prior draws
    n_sim, n_draws, n_params = post_samples.shape
    ratio = int(n_sim / n_draws)

    # Log a warning if N/B ratio recommended by Talts et al. (2018) < 20
    if ratio < 20:
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        # FIX: the previous message used a backslash continuation inside the
        # f-string, which injected a run of raw indentation spaces into the text.
        logger.info(
            f"The ratio of simulations / posterior draws should be > 20 "
            f"for reliable variance reduction, but your ratio is {ratio}. "
            f"Confidence intervals might be unreliable!"
        )

    # Set num_bins automatically, if nothing provided
    if num_bins is None:
        num_bins = int(ratio / 2)
        # Attempt a fix if too few bins are determined so the plot still makes sense.
        # FIX: guard <= 1 instead of == 1 — for ratio < 2 the bin count became 0,
        # causing a division by zero in binom.interval below.
        if num_bins <= 1:
            num_bins = 5

    # Determine n params and param names if None given
    if param_names is None:
        param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]

    # Determine number of rows and columns for subplots based on inputs
    if n_row is None and n_col is None:
        n_row = int(np.ceil(n_params / 6))
        n_col = int(np.ceil(n_params / n_row))
    elif n_row is None:
        n_row = int(np.ceil(n_params / n_col))
    elif n_col is None:
        n_col = int(np.ceil(n_params / n_row))

    # Initialize figure
    if fig_size is None:
        fig_size = (int(5 * n_col), int(5 * n_row))
    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # Flatten axes to 1D regardless of grid shape.
    # FIX: the previous version sliced ``axarr[n_params:]`` on a possibly 2D array,
    # which slices rows rather than individual axes, so unused subplots were never
    # removed; the bottom-row selection also broke for single-column grids.
    axarr_flat = np.atleast_1d(axarr).flatten()

    # Compute ranks (using broadcasting)
    ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1)

    # Compute confidence interval and mean.
    # Uniform distribution expected -> for all bins: equal probability
    # p = 1 / num_bins that a rank lands in that bin.
    N = int(prior_samples.shape[0])
    endpoints = binom.interval(binomial_interval, N, 1 / num_bins)
    mean = N / num_bins  # corresponds to binom.mean(N, 1 / num_bins)

    # Plot marginal histograms in a loop
    for j in range(len(param_names)):
        ax_j = axarr_flat[j]
        ax_j.axhspan(endpoints[0], endpoints[1], facecolor="gray", alpha=0.3)
        ax_j.axhline(mean, color="gray", zorder=0, alpha=0.9)
        sns.histplot(ranks[:, j], kde=False, ax=ax_j, color=hist_color, bins=num_bins, alpha=0.95)
        ax_j.set_title(param_names[j], fontsize=title_fontsize)
        ax_j.spines["right"].set_visible(False)
        ax_j.spines["top"].set_visible(False)
        ax_j.get_yaxis().set_ticks([])
        ax_j.set_ylabel("")
        ax_j.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        ax_j.tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    # Only add x-labels to the bottom row
    for _ax in axarr_flat[(n_row - 1) * n_col :]:
        _ax.set_xlabel("Rank statistic", fontsize=label_fontsize)

    # Remove unused axes entirely
    for _ax in axarr_flat[n_params:]:
        _ax.remove()

    f.tight_layout()
    return f
def plot_posterior_2d(
    posterior_draws,
    prior=None,
    prior_draws=None,
    param_names=None,
    height=3,
    label_fontsize=14,
    legend_fontsize=16,
    tick_fontsize=12,
    post_color="#8f2727",
    prior_color="gray",
    post_alpha=0.9,
    prior_alpha=0.7,
):
    """Generates a bivariate pairplot given posterior draws and optional prior or prior draws.

    Parameters
    ----------
    posterior_draws : np.ndarray of shape (n_post_draws, n_params)
        The posterior draws obtained for a SINGLE observed data set.
    prior           : bayesflow.forward_inference.Prior instance or None, optional, default: None
        The optional prior object having an input-output signature as given by
        bayesflow.forward_inference.Prior
    prior_draws     : np.ndarray of shape (n_prior_draws, n_params) or None, optional, default: None
        The optional prior draws obtained from the prior. If both prior and
        prior_draws are provided, prior_draws will be used.
    param_names     : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    height          : float, optional, default: 3
        The height of the pairplot
    label_fontsize  : int, optional, default: 14
        The font size of the x and y-label texts (parameter names)
    legend_fontsize : int, optional, default: 16
        The font size of the legend text
    tick_fontsize   : int, optional, default: 12
        The font size of the axis ticklabels
    post_color      : str, optional, default: '#8f2727'
        The color for the posterior histograms and KDEs
    prior_color     : str, optional, default: gray
        The color for the optional prior histograms and KDEs
    post_alpha      : float in [0, 1], optional, default: 0.9
        The opacity of the posterior plots
    prior_alpha     : float in [0, 1], optional, default: 0.7
        The opacity of the prior plots

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    AssertionError
        If the shape of posterior_draws is not 2-dimensional.
    """
    # A single data set is expected, so the draws must form a 2D array
    assert len(posterior_draws.shape) == 2, "Shape of `posterior_samples` for a single data set should be 2 dimensional!"

    n_draws, n_params = posterior_draws.shape

    # Sample from the prior object when no draws were supplied directly
    if prior_draws is None and prior is not None:
        sampled = prior(n_draws)
        if type(sampled) is dict:
            prior_draws = sampled["prior_draws"]
        else:
            prior_draws = sampled

    # Attempt to determine parameter names, falling back to generic theta labels
    if param_names is None:
        param_names = getattr(prior, "param_names", None)
        if param_names is None:
            param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]

    # Layer the posterior onto a pair grid
    posterior_frame = pd.DataFrame(posterior_draws, columns=param_names)
    grid = sns.PairGrid(posterior_frame, height=height)
    grid.map_diag(sns.histplot, fill=True, color=post_color, alpha=post_alpha, kde=True)
    grid.map_lower(sns.kdeplot, fill=True, color=post_color, alpha=post_alpha)

    # Layer the prior underneath (negative zorder), if available
    if prior_draws is not None:
        grid.data = pd.DataFrame(prior_draws, columns=param_names)
        grid.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1)
        grid.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1)

    # A legend only makes sense once prior information was given in some form
    if prior_draws is not None or prior is not None:
        legend_handles = [
            Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha),
            Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha),
        ]
        grid.fig.legend(legend_handles, ["Posterior", "Prior"], fontsize=legend_fontsize, loc="center right")

    # Hide the redundant upper triangle
    for row, col in zip(*np.triu_indices_from(grid.axes, 1)):
        grid.axes[row, col].axis("off")

    # Shrink tick labels on the visible panels
    for row, col in zip(*np.tril_indices_from(grid.axes, 1)):
        grid.axes[row, col].tick_params(axis="both", which="major", labelsize=tick_fontsize)
        grid.axes[row, col].tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    # Parameter names on the outer edges
    for idx, name in enumerate(param_names):
        grid.axes[idx, 0].set_ylabel(name, fontsize=label_fontsize)
        grid.axes[len(param_names) - 1, idx].set_xlabel(name, fontsize=label_fontsize)

    # Background grids everywhere
    for row in range(n_params):
        for col in range(n_params):
            grid.axes[row, col].grid(alpha=0.5)

    grid.tight_layout()
    return grid.fig
def plot_losses(
    train_losses,
    val_losses=None,
    moving_average=False,
    ma_window_fraction=0.01,
    fig_size=None,
    train_color="#8f2727",
    val_color="black",
    lw_train=2,
    lw_val=3,
    grid_alpha=0.5,
    legend_fontsize=14,
    label_fontsize=14,
    title_fontsize=16,
):
    """A generic helper function to plot the losses of a series of training epochs and runs.

    Parameters
    ----------
    train_losses       : pd.DataFrame
        The (plottable) history as returned by a train_[...] method of a ``Trainer``
        instance. Alternatively, you can just pass a data frame of validation losses
        instead of train losses, if you only want to plot the validation loss.
    val_losses         : pd.DataFrame or None, optional, default: None
        The (plottable) validation history as returned by a train_[...] method of a
        ``Trainer`` instance. If left ``None``, only train losses are plotted. Should
        have the same number of columns as ``train_losses``.
    moving_average     : bool, optional, default: False
        A flag for adding a moving average line of the train_losses.
    ma_window_fraction : float, optional, default: 0.01
        Window size for the moving average as a fraction of total training steps.
    fig_size           : tuple or None, optional, default: None
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
    train_color        : str, optional, default: '#8f2727'
        The color for the train loss trajectory
    val_color          : str, optional, default: black
        The color for the optional validation loss trajectory
    lw_train           : int, optional, default: 2
        The linewidth for the training loss curve
    lw_val             : int, optional, default: 3
        The linewidth for the validation loss curve
    grid_alpha         : float, optional, default: 0.5
        The opacity factor for the background gridlines
    legend_fontsize    : int, optional, default: 14
        The font size of the legend text
    label_fontsize     : int, optional, default: 14
        The font size of the y-label text
    title_fontsize     : int, optional, default: 16
        The font size of the title text

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """
    # Determine the number of rows for plot
    n_row = len(train_losses.columns)

    # Initialize figure
    if fig_size is None:
        fig_size = (16, int(4 * n_row))
    f, axarr = plt.subplots(n_row, 1, figsize=fig_size)

    # Get the number of steps as an array
    train_step_index = np.arange(1, len(train_losses) + 1)
    if val_losses is not None:
        # FIX: clamp the step to at least 1 — a validation history longer than the
        # train history previously produced a step of 0 and a slice ValueError.
        val_step = max(int(np.floor(len(train_losses) / len(val_losses))), 1)
        val_step_index = train_step_index[(val_step - 1) :: val_step]

        # If unequal length due to some reason, attempt a fix
        if val_step_index.shape[0] > val_losses.shape[0]:
            val_step_index = val_step_index[: val_losses.shape[0]]

    # Loop through loss entries and populate plot
    looper = [axarr] if n_row == 1 else axarr.flat
    for i, ax in enumerate(looper):
        # Plot train curve
        ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
        if moving_average and train_losses.columns[i] == "Loss":
            # FIX: ensure a window of at least 1 — short histories previously yielded
            # a window of 0, which ``rolling`` rejects.
            moving_average_window = max(int(train_losses.shape[0] * ma_window_fraction), 1)
            smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
            ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")

        # Plot optional val curve
        if val_losses is not None and i < val_losses.shape[1]:
            ax.plot(
                val_step_index,
                val_losses.iloc[:, i],
                linestyle="--",
                marker="o",
                color=val_color,
                lw=lw_val,
                label="Validation",
            )

        # Schmuck
        ax.set_xlabel("Training step #", fontsize=label_fontsize)
        ax.set_ylabel("Value", fontsize=label_fontsize)
        sns.despine(ax=ax)
        ax.grid(alpha=grid_alpha)
        ax.set_title(train_losses.columns[i], fontsize=title_fontsize)

        # Only add legend if there is a validation curve or a moving average
        if val_losses is not None or moving_average:
            ax.legend(fontsize=legend_fontsize)

    f.tight_layout()
    return f
def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f2727", **kwargs):
    """Creates pair-plots for a given joint prior.

    Parameters
    ----------
    prior       : callable
        The prior object which takes a single integer argument and generates random draws.
    param_names : list of str or None, optional, default: None
        An optional list of strings naming the prior parameters
    n_samples   : int, optional, default: 2000
        The number of random draws from the joint prior
    height      : float, optional, default: 2.5
        The height of the pair plot
    color       : str, optional, default: '#8f2727'
        The color of the plot
    **kwargs    : dict, optional
        Additional keyword arguments passed to the sns.PairGrid constructor

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """
    # Draw from the joint prior, unwrapping dict-style outputs
    draws = prior(n_samples)
    if type(draws) is dict:
        draws = draws["prior_draws"]

    # Latent dimensionality drives titles and the grid layout below
    dim = draws.shape[-1]
    if param_names is None:
        column_titles = [f"Prior Param. {i}" for i in range(1, dim + 1)]
    else:
        column_titles = [f"Prior {p}" for p in param_names]
    frame = pd.DataFrame(draws, columns=column_titles)

    # Build the pair grid: histograms on the diagonal, KDEs below, scatter above
    g = sns.PairGrid(frame, height=height, **kwargs)
    g.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)

    # Kernel density estimation may fail (e.g. with parameters whose correlation
    # is close to 1 or -1); substitute a scatter plot in that case.
    try:
        g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
    except Exception as e:
        logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
        g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
    g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)

    # Background grids on every panel
    for row in range(dim):
        for col in range(dim):
            g.axes[row, col].grid(alpha=0.5)

    g.tight_layout()
    return g.fig
def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
    """Pair-plots the latent space learned by the inference network, enabling a
    visual check of whether the latent structure matches the one enforced by
    the optimization criterion.

    Parameters
    ----------
    z_samples   : np.ndarray or tf.Tensor of shape (n_sim, n_params)
        The latent samples computed through a forward pass of the inference network.
    height      : float, optional, default: 2.5
        The height of the pair plot.
    color       : str, optional, default : '#8f2727'
        The color of the plot
    **kwargs    : dict, optional
        Additional keyword arguments passed to the sns.PairGrid constructor

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """

    # Convert a (potential) tf.Tensor into a numpy array before plotting
    if type(z_samples) is not np.ndarray:
        z_samples = z_samples.numpy()

    # Wrap the samples into a labeled data frame, one column per latent dimension
    num_dims = z_samples.shape[-1]
    column_names = [f"Latent Dim. {d}" for d in range(1, num_dims + 1)]
    frame = pd.DataFrame(z_samples, columns=column_names)

    # Diagonal: marginal histograms with KDE; lower triangle: joint KDEs;
    # upper triangle: scatter plots
    grid = sns.PairGrid(frame, height=height, **kwargs)
    grid.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)
    grid.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
    grid.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)

    # Add light grid lines to every subplot
    for axis in grid.axes.flat:
        axis.grid(alpha=0.5)

    grid.tight_layout()
    return grid.fig
def plot_calibration_curves(
    true_models,
    pred_models,
    model_names=None,
    num_bins=10,
    label_fontsize=16,
    legend_fontsize=14,
    title_fontsize=18,
    tick_fontsize=12,
    epsilon=0.02,
    fig_size=None,
    color="#8f2727",
    n_row=None,
    n_col=None,
):
    """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior
    model probabilities for a model comparison problem. The marginal histograms inform about
    the fraction of predictions in each bin. Depends on the ``expected_calibration_error`` function
    for computing the ECE.

    Parameters
    ----------
    true_models       : np.ndarray of shape (num_data_sets, num_models)
        The one-hot-encoded true model indices per data set.
    pred_models       : np.ndarray of shape (num_data_sets, num_models)
        The predicted posterior model probabilities (PMPs) per data set.
    model_names       : list or None, optional, default: None
        The model names for nice plot titles. Inferred if None.
    num_bins          : int, optional, default: 10
        The number of bins to use for the calibration curves (and marginal histograms).
    label_fontsize    : int, optional, default: 16
        The font size of the y-label and y-label texts
    legend_fontsize   : int, optional, default: 14
        The font size of the legend text (ECE value)
    title_fontsize    : int, optional, default: 18
        The font size of the title text. Only relevant if `stacked=False`
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    epsilon           : float, optional, default: 0.02
        A small amount to pad the [0, 1]-bounded axes from both side.
    fig_size          : tuple or None, optional, default: None
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
    color             : str, optional, default: '#8f2727'
        The color of the calibration curves
    n_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.

    Returns
    -------
    fig : plt.Figure - the figure instance for optional saving
    """

    num_models = true_models.shape[-1]
    if model_names is None:
        model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]

    # Determine number of rows and columns for subplots based on inputs
    if n_row is None and n_col is None:
        n_row = int(np.ceil(num_models / 6))
        n_col = int(np.ceil(num_models / n_row))
    elif n_row is None and n_col is not None:
        n_row = int(np.ceil(num_models / n_col))
    elif n_row is not None and n_col is None:
        n_col = int(np.ceil(num_models / n_row))

    # Compute calibration
    cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)

    # Initialize figure
    if fig_size is None:
        fig_size = (int(5 * n_col), int(5 * n_row))
    fig, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # Work with a flat 1D view of the axes so indexing is uniform whether
    # plt.subplots returned a single Axes (1x1), a 1D array, or a 2D array
    axes = np.atleast_1d(axarr).flatten()

    # Plot marginal calibration curves in a loop
    for j in range(num_models):
        # Plot calibration curve
        axes[j].plot(probs_pred[j], probs_true[j], "o-", color=color)

        # Plot PMP distribution over bins
        uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
        norm_weights = np.ones_like(pred_models) / len(pred_models)
        axes[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

        # Plot AB line (perfect calibration)
        axes[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

        # Tweak plot
        axes[j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
        axes[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
        axes[j].set_title(model_names[j], fontsize=title_fontsize)
        axes[j].spines["right"].set_visible(False)
        axes[j].spines["top"].set_visible(False)
        axes[j].set_xlim([0 - epsilon, 1 + epsilon])
        axes[j].set_ylim([0 - epsilon, 1 + epsilon])
        axes[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
        axes[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
        axes[j].grid(alpha=0.5)

        # Add ECE label
        axes[j].text(
            0.1,
            0.9,
            r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}".format(cal_errs[j]),
            horizontalalignment="left",
            verticalalignment="center",
            transform=axes[j].transAxes,
            size=legend_fontsize,
        )

    # Only add x-labels to the bottom row and y-labels to the left-most column,
    # determined by position in the flattened grid
    for idx, axis in enumerate(axes):
        if idx >= n_col * (n_row - 1):
            axis.set_xlabel("Predicted probability", fontsize=label_fontsize)
        if idx % n_col == 0:
            axis.set_ylabel("True probability", fontsize=label_fontsize)

    fig.tight_layout()
    return fig
def plot_confusion_matrix(
    true_models,
    pred_models,
    model_names=None,
    fig_size=(5, 5),
    label_fontsize=16,
    title_fontsize=18,
    value_fontsize=10,
    tick_fontsize=12,
    xtick_rotation=None,
    ytick_rotation=None,
    normalize=True,
    cmap=None,
    title=True,
):
    """Plots a confusion matrix for validating a neural network trained for Bayesian model comparison.

    Parameters
    ----------
    true_models       : np.ndarray of shape (num_data_sets, num_models)
        The one-hot-encoded true model indices per data set.
    pred_models       : np.ndarray of shape (num_data_sets, num_models)
        The predicted posterior model probabilities (PMPs) per data set.
    model_names       : list or None, optional, default: None
        The model names for nice plot titles. Inferred if None.
    fig_size          : tuple or None, optional, default: (5, 5)
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
    label_fontsize    : int, optional, default: 16
        The font size of the y-label and y-label texts
    title_fontsize    : int, optional, default: 18
        The font size of the title text.
    value_fontsize    : int, optional, default: 10
        The font size of the text annotations and the colorbar tick labels.
    tick_fontsize     : int, optional, default: 12
        The font size of the axis label and model name texts.
    xtick_rotation    : int, optional, default: None
        Rotation of x-axis tick labels (helps with long model names).
    ytick_rotation    : int, optional, default: None
        Rotation of y-axis tick labels (helps with long model names).
    normalize         : bool, optional, default: True
        A flag for normalization of the confusion matrix.
        If True, each row of the confusion matrix is normalized to sum to 1.
    cmap              : matplotlib.colors.Colormap or str, optional, default: None
        Colormap to be used for the cells. If a str, it should be the name of a registered
        colormap, e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging
        from white to red.
    title             : bool, optional, default True
        A flag for adding 'Confusion Matrix' above the matrix.

    Returns
    -------
    fig : plt.Figure - the figure instance for optional saving
    """

    num_models = true_models.shape[-1]
    if model_names is None:
        model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]
    if cmap is None:
        cmap = LinearSegmentedColormap.from_list("", ["white", "#8f2727"])

    # Flatten one-hot/probability inputs to integer class labels
    true_models = np.argmax(true_models, axis=1)
    pred_models = np.argmax(pred_models, axis=1)

    # Compute confusion matrix; pass labels explicitly so the matrix always has
    # shape (num_models, num_models), even if some model never occurs in the data
    cm = confusion_matrix(true_models, pred_models, labels=np.arange(num_models))
    if normalize:
        # Guard against rows with zero observations (a model never among the
        # true labels), which would otherwise produce NaNs via 0/0
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = cm.astype("float") / np.where(row_sums == 0, 1, row_sums)

    # Initialize figure
    fig, ax = plt.subplots(1, 1, figsize=fig_size)
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)
    cbar.ax.tick_params(labelsize=value_fontsize)
    ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
    ax.set_xticklabels(model_names, fontsize=tick_fontsize)
    if xtick_rotation:
        plt.xticks(rotation=xtick_rotation, ha="right")
    ax.set_yticklabels(model_names, fontsize=tick_fontsize)
    if ytick_rotation:
        plt.yticks(rotation=ytick_rotation)
    ax.set_xlabel("Predicted model", fontsize=label_fontsize)
    ax.set_ylabel("True model", fontsize=label_fontsize)

    # Loop over data dimensions and create text annotations
    fmt = ".2f" if normalize else "d"
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(
                j,
                i,
                format(cm[i, j], fmt),
                fontsize=value_fontsize,
                ha="center",
                va="center",
                # White text on dark cells, black text on light cells
                color="white" if cm[i, j] > thresh else "black",
            )
    if title:
        ax.set_title("Confusion Matrix", fontsize=title_fontsize)
    return fig
def plot_mmd_hypothesis_test(
    mmd_null,
    mmd_observed=None,
    alpha_level=0.05,
    null_color=(0.16407, 0.020171, 0.577478),
    observed_color="red",
    alpha_color="orange",
    truncate_vlines_at_kde=False,
    xmin=None,
    xmax=None,
    bw_factor=1.5,
):
    """Plots the sampling distribution of the MMD under the null hypothesis
    together with the observed MMD value and the rejection region implied by
    ``alpha_level``.

    Parameters
    ----------
    mmd_null       : np.ndarray
        The samples from the MMD sampling distribution under the null hypothesis
        "the model is well-specified"
    mmd_observed   : float
        The observed MMD value
    alpha_level    : float, optional, default: 0.05
        The rejection probability (type I error)
    null_color     : str or tuple, optional, default: (0.16407, 0.020171, 0.577478)
        The color of the H0 sampling distribution
    observed_color : str or tuple, optional, default: "red"
        The color of the observed MMD
    alpha_color    : str or tuple, optional, default: "orange"
        The color of the rejection area
    truncate_vlines_at_kde : bool, optional, default: False
        true: cut off the vlines at the kde
        false: continue kde lines across the plot
    xmin           : float, optional, default: None
        The lower x-axis limit
    xmax           : float, optional, default: None
        The upper x-axis limit
    bw_factor      : float, optional, default: 1.5
        bandwidth (aka. smoothing parameter) of the kernel density estimate

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """

    def _vline_to_kde(location, kde_object, color, label=None, **kwargs):
        # Draw a vertical line from zero up to the height of the KDE curve at `location`
        curve_x, curve_y = kde_object.lines[0].get_data()
        nearest = np.argmin(np.abs(curve_x - location))
        plt.vlines(x=location, ymin=0, ymax=curve_y[nearest], color=color, linewidth=3, label=label, **kwargs)

    def _shade_under_kde(kde_object, x_start, x_end=None, **kwargs):
        # Fill the area under the KDE curve from x_start to x_end (or to the right edge)
        curve_x, curve_y = kde_object.lines[0].get_data()
        if x_end is None:
            region = curve_x >= x_start
        else:
            region = (curve_x >= x_start) & (curve_x <= x_end)
        plt.fill_between(curve_x, curve_y, where=region, interpolate=True, **kwargs)

    f = plt.figure(figsize=(8, 4))

    # Invisible KDE pass: only used to obtain the estimated density curve,
    # followed by a translucent filled version of the same density
    kde = sns.kdeplot(mmd_null, fill=False, linewidth=0, bw_adjust=bw_factor)
    sns.kdeplot(mmd_null, fill=True, alpha=0.12, color=null_color, bw_adjust=bw_factor)

    # Mark the observed MMD value
    if truncate_vlines_at_kde:
        _vline_to_kde(mmd_observed, kde, observed_color, label=r"Observed data")
    else:
        plt.vlines(
            x=mmd_observed,
            ymin=0,
            ymax=plt.gca().get_ylim()[1],
            color=observed_color,
            linewidth=3,
            label=r"Observed data",
        )

    # Shade the rejection region beyond the critical MMD quantile
    mmd_critical = np.quantile(mmd_null, 1 - alpha_level)
    _shade_under_kde(
        kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area"
    )

    # Mark the critical value itself
    if truncate_vlines_at_kde:
        _vline_to_kde(mmd_critical, kde, alpha_color)
    else:
        plt.vlines(x=mmd_critical, color=alpha_color, linewidth=3, ymin=0, ymax=plt.gca().get_ylim()[1])

    # Final visible pass drawing the H0 density outline (goes into the legend)
    sns.kdeplot(mmd_null, fill=False, linewidth=3, color=null_color, label=r"$H_0$", bw_adjust=bw_factor)

    # Cosmetics
    plt.xlabel(r"MMD", fontsize=20)
    plt.ylabel("")
    plt.yticks([])
    plt.xlim(xmin, xmax)
    plt.tick_params(axis="both", which="major", labelsize=16)
    plt.legend(fontsize=20)
    sns.despine()
    return f