Source code for bayesflow.diagnostics.plots.recovery

from collections.abc import Sequence, Mapping

import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import median_abs_deviation

from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric


def recovery(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] | None = None,
    variable_names: Sequence[str] | None = None,
    point_agg=np.median,
    uncertainty_agg=median_abs_deviation,
    add_corr: bool = True,
    figsize: Sequence[int] | None = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    metric_fontsize: int = 16,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    num_col: int | None = None,
    num_row: int | None = None,
    xlabel: str = "Ground truth",
    ylabel: str = "Estimate",
    **kwargs,
) -> plt.Figure:
    """
    Creates and plots a publication-ready recovery plot of ground truth vs. point estimate + uncertainty.

    The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
    can be controlled with the ``uncertainty_agg`` argument.

    This plot yields similar information as the "posterior z-score", but allows for generic point and
    uncertainty estimates:

    https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html

    Important: Posterior aggregates play no special role in Bayesian inference and should only be used
    heuristically. For instance, in the case of multi-modal posteriors, common point estimates, such as
    mean, (geometric) median, or maximum a posteriori (MAP) may not be meaningful.

    Parameters
    ----------
    estimates         : dict or np.ndarray of shape (num_datasets, num_post_draws, num_params)
        The posterior draws obtained from num_datasets
    targets           : dict or np.ndarray of shape (num_datasets, num_params)
        The prior draws (true parameters) used for generating the num_datasets
    variable_keys     : list or None, optional, default: None
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names    : list or None, optional, default: None
        The individual parameter names for nice plot titles. Inferred if None
    point_agg         : callable, optional, default: np.median
        Function to compute point estimates. It is called as ``point_agg(estimates, axis=1)``.
    uncertainty_agg   : callable or None, optional, default: scipy.stats.median_abs_deviation
        Function to compute uncertainty estimates (error bar half-lengths). It is called as
        ``uncertainty_agg(estimates, axis=1)``. If None, a plain scatter plot without
        error bars is drawn.
    add_corr          : boolean, default: True
        Should correlations between estimates and ground truth values be shown?
    figsize           : tuple or None, optional, default : None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize    : int, optional, default: 16
        The font size of the axis label text.
    title_fontsize    : int, optional, default: 18
        The font size of the title text.
    metric_fontsize   : int, optional, default: 16
        The font size of the metrics shown as text.
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels.
    color             : str, optional, default: '#132a70'
        The color for the true vs. estimated scatter points and error bars.
    num_col           : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row           : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    xlabel            : str, optional, default: 'Ground truth'
        The label of the x-axis (ground-truth values).
    ylabel            : str, optional, default: 'Estimate'
        The label of the y-axis (point estimates).
    **kwargs          : dict, optional
        Additional keyword arguments forwarded to ``ax.errorbar`` (if ``uncertainty_agg`` is
        not None) or ``ax.scatter`` (otherwise). May be used to override the default
        ``alpha`` (0.5) and, for error bars, the default ``fmt`` ("o").

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
    """

    # Gather plot data and metadata into a dictionary
    plot_data = prepare_plot_data(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
    )

    estimates = plot_data.pop("estimates")
    targets = plot_data.pop("targets")

    # Compute point estimates and uncertainties by aggregating over the draws axis
    point_estimate = point_agg(estimates, axis=1)
    uncertainty = uncertainty_agg(estimates, axis=1) if uncertainty_agg is not None else None

    # Treat alpha/fmt as overridable defaults: passing them explicitly next to **kwargs
    # would raise "got multiple values for keyword argument" when the caller sets them.
    scatter_kwargs = {"alpha": 0.5, **kwargs}
    errorbar_kwargs = {"fmt": "o", **scatter_kwargs}

    num_variables = plot_data["num_variables"]
    for i, ax in enumerate(plot_data["axes"].flat):
        # The grid may contain more axes than variables; leave the surplus untouched
        if i >= num_variables:
            break

        truth = targets[:, i]
        estimate = point_estimate[:, i]

        # Add scatter and error bars
        if uncertainty_agg is not None:
            ax.errorbar(truth, estimate, yerr=uncertainty[:, i], color=color, **errorbar_kwargs)
        else:
            ax.scatter(truth, estimate, color=color, **scatter_kwargs)

        make_quadratic(ax, truth, estimate)

        if add_corr:
            # Off-diagonal entry of the 2x2 correlation matrix is the Pearson r
            corr = np.corrcoef(truth, estimate)[0, 1]
            add_metric(ax=ax, metric_text="$r$", metric_value=corr, metric_fontsize=metric_fontsize)

        ax.set_title(plot_data["variable_names"][i], fontsize=title_fontsize)

    # Add custom schmuck
    prettify_subplots(plot_data["axes"], num_subplots=num_variables, tick_fontsize=tick_fontsize)
    add_titles_and_labels(
        axes=plot_data["axes"],
        num_row=plot_data["num_row"],
        num_col=plot_data["num_col"],
        xlabel=xlabel,
        ylabel=ylabel,
        label_fontsize=label_fontsize,
    )

    plot_data["fig"].tight_layout()
    return plot_data["fig"]