Source code for bayesflow.diagnostics.plots.pairs_posterior

from collections.abc import Sequence, Mapping

import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import seaborn as sns

from bayesflow.utils.dict_utils import dicts_to_arrays

from .pairs_samples import _pairs_samples


def pairs_posterior(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray | None = None,
    priors: Mapping[str, np.ndarray] | np.ndarray | None = None,
    dataset_id: int | None = None,
    variable_keys: Sequence[str] | None = None,
    variable_names: Sequence[str] | None = None,
    height: int = 3,
    post_color: str | tuple = "#132a70",
    prior_color: str | tuple = "gray",
    alpha: float = 0.9,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    legend_fontsize: int = 14,
    **kwargs,
) -> sns.PairGrid:
    """Generates a bivariate pair plot given posterior draws and optional prior draws.

    Parameters
    ----------
    estimates : dict[str, np.ndarray] or np.ndarray of shape (n_post_draws, n_params)
        The posterior draws obtained for a SINGLE observed data set.
    targets : dict[str, np.ndarray] or np.ndarray of shape (n_params,) or None, optional, default: None
        Optional true parameter values that have generated the observed dataset.
        When given, each one is marked by a dashed vertical line on the diagonal.
    priors : dict[str, np.ndarray] or np.ndarray of shape (n_prior_draws, n_params) or None, optional, default: None
        Optional prior samples obtained from the prior.
    dataset_id : int or None, optional, default: None
        Optional ID of the dataset for whose posterior the pairs plot shall be
        generated. Should only be specified if estimates contains posterior draws
        from multiple datasets.
    variable_keys : list or None, optional, default: None
        Select keys from the dictionary provided in estimates.
        By default, select all keys.
    variable_names : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    height : int, optional, default: 3
        The height of the pairplot
    post_color : str or tuple, optional, default: '#132a70'
        The color for the posterior histograms and KDEs
    prior_color : str or tuple, optional, default: 'gray'
        The color for the optional prior histograms and KDEs
    alpha : float in [0, 1], optional, default: 0.9
        The opacity of the posterior plots
    label_fontsize : int, optional, default: 14
        The font size of the x and y-label texts (parameter names)
    tick_fontsize : int, optional, default: 12
        The font size of the axis ticklabels
    legend_fontsize : int, optional, default: 14
        The font size of the legend text
    **kwargs : dict, optional, default: {}
        Further optional keyword arguments propagated to `_pairs_samples`

    Returns
    -------
    g : sns.PairGrid
        The grid returned by `_pairs_samples`, with the true parameter values
        added on the diagonal when `targets` is given.

    Notes
    -----
    This function performs no shape validation of its own; errors for
    malformed inputs come from `dicts_to_arrays` or `_pairs_samples`.
    """
    plot_data = dicts_to_arrays(
        estimates=estimates,
        targets=targets,
        priors=priors,
        dataset_ids=dataset_id,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    # dicts_to_arrays will keep dataset axis even if it is of length 1
    # however, pairs plotting requires the dataset axis to be removed
    estimates_shape = plot_data["estimates"].shape
    if len(estimates_shape) == 3 and estimates_shape[0] == 1:
        plot_data["estimates"] = np.squeeze(plot_data["estimates"], axis=0)

    # plot posterior first
    g = _pairs_samples(
        plot_data=plot_data,
        height=height,
        color=post_color,
        color2=prior_color,
        alpha=alpha,
        label_fontsize=label_fontsize,
        tick_fontsize=tick_fontsize,
        legend_fontsize=legend_fontsize,
        **kwargs,
    )

    targets = plot_data.get("targets")
    if targets is not None:
        # Ensure targets is at least 2D so that it forms a single-row table
        if targets.ndim == 1:
            targets = np.atleast_2d(targets)

        # Create DataFrame with variable names as columns. The names must come
        # from plot_data: a NumPy array has no `variable_names` attribute, so
        # reading it from `targets` fails with AttributeError.
        # NOTE(review): assumes dicts_to_arrays stores the resolved names under
        # the "variable_names" key - confirm against its implementation.
        g.data = pd.DataFrame(targets, columns=plot_data["variable_names"])
        g.data["_source"] = "True Parameter"
        g.map_diag(plot_true_params)

    return g


def plot_true_params(x, hue=None, **kwargs):
    """Mark the true parameter value with a dashed black vertical line.

    Used as the callback handed to ``map_diag`` of the pair grid, so it is
    meant for the diagonal panels only.

    Parameters
    ----------
    x : indexable via ``.iloc`` (presumably a pandas Series - verify against caller)
        Column holding the true value; only its first entry is used.
    hue : optional, default: None
        Accepted but ignored; present so the call does not fail when a hue
        argument is passed (the case of plotting both posterior and prior).
    **kwargs : dict, optional
        Accepted but ignored.
    """
    # A single true value is expected per parameter, hence the first entry.
    true_value = x.iloc[0]

    # Draws on matplotlib's current axes.
    plt.axvline(true_value, linestyle="--", color="black")