from collections.abc import Sequence, Mapping
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from bayesflow.utils.dict_utils import dicts_to_arrays
from .pairs_samples import _pairs_samples
[docs]
def pairs_posterior(
estimates: Mapping[str, np.ndarray] | np.ndarray,
targets: Mapping[str, np.ndarray] | np.ndarray = None,
priors: Mapping[str, np.ndarray] | np.ndarray = None,
dataset_id: int = None,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
height: int = 3,
post_color: str | tuple = "#132a70",
prior_color: str | tuple = "gray",
alpha: float = 0.9,
label_fontsize: int = 14,
tick_fontsize: int = 12,
legend_fontsize: int = 14,
**kwargs,
) -> sns.PairGrid:
"""Generates a bivariate pair plot given posterior draws and optional prior or prior draws.
Parameters
----------
estimates : np.ndarray of shape (n_post_draws, n_params)
The posterior draws obtained for a SINGLE observed data set.
targets : np.ndarray of shape (n_params,) or None, optional, default: None
Optional true parameter values that have generated the observed dataset.
priors : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
Optional prior samples obtained from the prior.
dataset_id: Optional ID of the dataset for whose posterior the pairs plot shall be generated.
Should only be specified if estimates contains posterior draws from multiple datasets.
variable_keys : list or None, optional, default: None
Select keys from the dictionary provided in samples.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
height : float, optional, default: 3
The height of the pairplot
label_fontsize : int, optional, default: 14
The font size of the x and y-label texts (parameter names)
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
legend_fontsize : int, optional, default: 16
The font size of the legend text
post_color : str, optional, default: '#132a70'
The color for the posterior histograms and KDEs
prior_color : str, optional, default: gray
The color for the optional prior histograms and KDEs
alpha : float in [0, 1], optional, default: 0.9
The opacity of the posterior plots
**kwargs : dict, optional, default: {}
Further optional keyword arguments propagated to `_pairs_samples`
Returns
-------
f : plt.Figure - the figure instance for optional saving
Raises
------
AssertionError
If the shape of posterior_draws is not 2-dimensional.
"""
plot_data = dicts_to_arrays(
estimates=estimates,
targets=targets,
priors=priors,
dataset_ids=dataset_id,
variable_keys=variable_keys,
variable_names=variable_names,
)
# dicts_to_arrays will keep dataset axis even if it is of length 1
# however, pairs plotting requires the dataset axis to be removed
estimates_shape = plot_data["estimates"].shape
if len(estimates_shape) == 3 and estimates_shape[0] == 1:
plot_data["estimates"] = np.squeeze(plot_data["estimates"], axis=0)
# plot posterior first
g = _pairs_samples(
plot_data=plot_data,
height=height,
color=post_color,
color2=prior_color,
alpha=alpha,
label_fontsize=label_fontsize,
tick_fontsize=tick_fontsize,
legend_fontsize=legend_fontsize,
**kwargs,
)
targets = plot_data.get("targets")
if targets is not None:
# Ensure targets is at least 2D
if targets.ndim == 1:
targets = np.atleast_2d(targets)
# Create DataFrame with variable names as columns
g.data = pd.DataFrame(targets, columns=targets.variable_names)
g.data["_source"] = "True Parameter"
g.map_diag(plot_true_params)
return g
def plot_true_params(x, hue=None, **kwargs):
"""Custom function to plot true parameters on the diagonal."""
# hue needs to be added to handle the case of plotting both posterior and prior
param = x.iloc[0] # Get the single true value for the diagonal
# only plot on the diagonal a vertical line for the true parameter
plt.axvline(param, color="black", linestyle="--")