Source code for bayesflow.diagnostics.plots.pairs_samples

from collections.abc import Sequence, Mapping

import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt

from bayesflow.utils import logging
from bayesflow.utils.dict_utils import dicts_to_arrays


def pairs_samples(
    samples: Mapping[str, np.ndarray] | np.ndarray = None,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    height: float = 2.5,
    color: str | tuple = "#132a70",
    alpha: float = 0.9,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    **kwargs,
) -> sns.PairGrid:
    """
    A more flexible pair plot function for multiple distributions based upon
    collected samples.

    Parameters
    ----------
    samples         : dict[str, np.ndarray] or np.ndarray, default: None
        Sample draws from any dataset
    variable_keys   : list or None, optional, default: None
        Select keys from the dictionary provided in samples.
        By default, select all keys.
    variable_names  : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    height          : float, optional, default: 2.5
        The height of the pair plot (forwarded as ``height`` to sns.PairGrid)
    color           : str or tuple, optional, default: '#132a70'
        The color of the plot
    alpha           : float in [0, 1], optional, default: 0.9
        The opacity of the plot
    label_fontsize  : int, optional, default: 14
        The font size of the x and y-label texts (parameter names)
    tick_fontsize   : int, optional, default: 12
        The font size of the axis ticklabels
    **kwargs        : dict, optional
        Additional keyword arguments passed to the sns.PairGrid constructor

    Returns
    -------
    g : sns.PairGrid
        The grid object produced by the internal ``_pairs_samples`` helper.
    """
    plot_data = dicts_to_arrays(
        estimates=samples,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    # Forward **kwargs so they actually reach sns.PairGrid as documented;
    # previously they were accepted here but silently discarded.
    return _pairs_samples(
        plot_data=plot_data,
        height=height,
        color=color,
        alpha=alpha,
        label_fontsize=label_fontsize,
        tick_fontsize=tick_fontsize,
        **kwargs,
    )
def _pairs_samples(
    plot_data: dict,
    height: float = 2.5,
    color: str | tuple = "#132a70",
    color2: str | tuple = "gray",
    alpha: float = 0.9,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    legend_fontsize: int = 14,
    **kwargs,
) -> sns.PairGrid:
    """
    Internal version of pairs_samples that builds the seaborn plot.

    Parameters
    ----------
    plot_data       : dict
        Output of bayesflow.utils.dict_utils.dicts_to_arrays. The entries
        ``"estimates"`` (must be 2-dimensional and expose ``variable_names``)
        and ``"priors"`` (may be None) are read here.
    color2          : str or tuple, optional, default: 'gray'
        Palette entry used for the "Prior" draws when priors are present.
    legend_fontsize : int, optional, default: 14
        Font size of the legend that is added when priors are present.
    **kwargs        : dict, optional
        Passed on to the sns.PairGrid constructor.

    The remaining arguments are documented in pairs_samples.

    Raises
    ------
    ValueError
        If ``plot_data["estimates"]`` is not 2-dimensional.
    """
    estimates = plot_data["estimates"]
    shape = estimates.shape
    if len(shape) != 2:
        raise ValueError(
            f"Samples for a single distribution should be a matrix, but "
            f"your samples array has a shape of {shape}."
        )

    variable_names = estimates.variable_names
    priors = plot_data["priors"]
    has_priors = priors is not None

    if has_priors:
        # Stack prior draws on top of posterior draws and remember, per row,
        # which of the two sources it came from.
        frame = pd.DataFrame(np.vstack((priors, estimates)), columns=variable_names)
        origin = np.concatenate(
            (
                np.repeat("Prior", priors.shape[0]),
                np.repeat("Posterior", estimates.shape[0]),
            )
        )
        frame["_source"] = pd.Categorical(origin, categories=["Prior", "Posterior"])

        g = sns.PairGrid(
            frame,
            height=height,
            hue="_source",
            palette=[color2, color],
            **kwargs,
        )
        # With a hue palette in place, an explicit color must not override it.
        draw_color = None
    else:
        # Only a single set of draws to show.
        frame = pd.DataFrame(estimates, columns=variable_names)
        g = sns.PairGrid(frame, height=height, **kwargs)
        draw_color = color

    scatter_style = dict(alpha=0.6, s=40, edgecolor="k", lw=0)

    # Diagonal: histograms with KDE overlay.
    g.map_diag(
        histplot_twinx,
        fill=True,
        kde=True,
        color=draw_color,
        alpha=alpha,
        stat="density",
        common_norm=False,
    )

    # Upper triangle: scatter plots.
    g.map_upper(sns.scatterplot, color=draw_color, **scatter_style)

    # Lower triangle: KDEs, falling back to scatter plots if the KDE raises.
    try:
        g.map_lower(sns.kdeplot, fill=True, color=draw_color, alpha=alpha)
    except Exception as e:
        logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
        g.map_lower(sns.scatterplot, color=draw_color, **scatter_style)

    # The legend is added only after plotting so that the colors are recognized.
    if has_priors:
        g.add_legend(fontsize=legend_fontsize, loc="center right")
        g._legend.set_title(None)

    # Light grid, drawn beneath the data, on every panel.
    for panel in g.axes.flat:
        panel.grid(alpha=0.5)
        panel.set_axisbelow(True)

    n_vars = g.axes.shape[0]
    for row in range(n_vars):
        # Tick label sizes for the diagonal and lower-triangle panels of this row.
        for col in range(row + 1):
            for which in ("major", "minor"):
                g.axes[row, col].tick_params(axis="both", which=which, labelsize=tick_fontsize)

        # Label text stays variable_names; only the font size is adjusted.
        g.axes[row, 0].set_ylabel(variable_names[row], fontsize=label_fontsize)
        g.axes[n_vars - 1, row].set_xlabel(variable_names[row], fontsize=label_fontsize)

    g.tight_layout()
    return g


def histplot_twinx(x, **kwargs):
    """
    Draw a seaborn histogram of ``x`` on a twin y-axis of the current axes.

    Using a twin axis keeps the y scaling of the diagonal plots independent
    of the y scaling of the off-diagonal plots. All ``kwargs`` are handed to
    ``sns.histplot``. Always returns None.
    """
    twin = plt.gca().twinx()
    sns.histplot(x, **kwargs, ax=twin)

    # NOTE(review): plt.gca() is deliberately queried again for each spine
    # instead of being cached before the twinx() call, since creating the
    # twin may change which axes is "current" - confirm before simplifying.
    for side in ("right", "top"):
        plt.gca().spines[side].set_visible(False)

    # Strip label and ticks from the twin axis.
    twin.set_ylabel("")
    twin.set_yticks([])
    twin.set_yticklabels([])

    return None