Source code for bayesflow.diagnostics.plots.recovery

from collections.abc import Sequence, Mapping, Callable

import matplotlib.pyplot as plt
import numpy as np

from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric
from bayesflow.utils.numpy_utils import credible_interval
from bayesflow.utils.dict_utils import compute_test_quantities


def recovery(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    test_quantities: dict[str, Callable] = None,
    point_agg: Callable = np.median,
    uncertainty_agg: Callable = credible_interval,
    point_agg_kwargs: dict = None,
    uncertainty_agg_kwargs: dict = None,
    add_corr: bool = True,
    figsize: Sequence[int] = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    metric_fontsize: int = 16,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    num_col: int = None,
    num_row: int = None,
    xlabel: str = "Ground truth",
    ylabel: str = "Estimate",
    markersize: float = None,
    **kwargs,
) -> plt.Figure:
    """
    Creates and plots a publication-ready recovery plot of ground truth vs. point estimate + uncertainty.

    The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
    can be controlled with the ``uncertainty_agg`` argument.
    This plot yields similar information as the "posterior z-score",
    but allows for generic point and uncertainty estimates:

    https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html

    Important: Posterior aggregates play no special role in Bayesian inference and should only be used
    heuristically. For instance, in the case of multi-modal posteriors, common point estimates, such as
    mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.

    Parameters
    ----------
    estimates           : np.ndarray of shape (num_datasets, num_post_draws, num_params), or a mapping of such arrays
        The posterior draws obtained from num_datasets
    targets             : np.ndarray of shape (num_datasets, num_params), or a mapping of such arrays
        The prior draws (true parameters) used for generating the num_datasets
    variable_keys       : list or None, optional, default: None
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names      : list or None, optional, default: None
        The individual parameter names for nice plot titles. Inferred if None
    test_quantities     : dict or None, optional, default: None
        A dict that maps plot titles to functions that compute
        test quantities based on estimate/target draws.

        The dict keys are automatically added to ``variable_keys``
        and ``variable_names``.
        Test quantity functions are expected to accept a dict of draws with
        shape ``(batch_size, ...)`` as the first (typically only)
        positional argument and return a NumPy array of shape
        ``(batch_size,)``.
        The functions do not have to deal with an additional
        sample dimension, as appropriate reshaping is done internally.
    point_agg           : callable, optional, default: median
        Function to compute point estimates. It is called as
        ``point_agg(estimates, axis=1, **point_agg_kwargs)``.
    uncertainty_agg     : callable or None, optional, default: credible_interval with coverage probability 95%
        Function to compute a measure of uncertainty. It is called as
        ``uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)``.
        Can either be the lower and upper uncertainty bounds provided with the
        shape (2, num_datasets, num_params) or a scalar measure of uncertainty
        (e.g., the median absolute deviation) with shape (num_datasets, num_params).
        The returned array is not modified. If None, a plain scatter plot
        without error bars is drawn.
    point_agg_kwargs    : dict or None, optional, default: None
        Optional dictionary of further arguments passed to point_agg.
    uncertainty_agg_kwargs : dict or None, optional, default: None
        Optional dictionary of further arguments passed to uncertainty_agg.
        For example, to change the coverage probability of credible_interval
        to 50%, use uncertainty_agg_kwargs = dict(prob=0.5)
    add_corr            : boolean, default: True
        Should correlations between estimates and ground truth values be shown?
    figsize             : tuple or None, optional, default : None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize      : int, optional, default: 16
        The font size of the axis label text.
    title_fontsize      : int, optional, default: 18
        The font size of the title text.
    metric_fontsize     : int, optional, default: 16
        The font size of the metrics shown as text.
    tick_fontsize       : int, optional, default: 12
        The font size of the axis ticklabels.
    color               : str, optional, default: '#132a70'
        The color for the true vs. estimated scatter points and error bars.
    num_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    num_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    xlabel              : str, optional, default: "Ground truth"
        The label shown on the x-axis.
    ylabel              : str, optional, default: "Estimate"
        The label shown on the y-axis.
    markersize          : float, optional, default: None
        The marker size in points.
    **kwargs
        Further keyword arguments forwarded to ``ax.errorbar`` (or to
        ``ax.scatter`` when ``uncertainty_agg`` is None).

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
    """

    # Optionally, compute and prepend test quantities from draws
    if test_quantities is not None:
        updated_data = compute_test_quantities(
            targets=targets,
            estimates=estimates,
            variable_keys=variable_keys,
            variable_names=variable_names,
            test_quantities=test_quantities,
        )
        variable_names = updated_data["variable_names"]
        variable_keys = updated_data["variable_keys"]
        estimates = updated_data["estimates"]
        targets = updated_data["targets"]

    # Gather plot data and metadata into a dictionary
    plot_data = prepare_plot_data(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
    )

    estimates = plot_data.pop("estimates")
    targets = plot_data.pop("targets")

    point_agg_kwargs = point_agg_kwargs or {}
    uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}

    # Compute point estimates by aggregating over the draw axis
    point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)

    # Compute the error bars; ``yerr`` stays None when no uncertainty is requested
    yerr = None
    if uncertainty_agg is not None:
        # asanyarray lets aggregators return e.g. a (lower, upper) tuple
        bounds = np.asanyarray(uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs))
        if bounds.ndim == 3:
            # Convert absolute (lower, upper) bounds into offsets from the point
            # estimate. A new array is built so that the aggregator's return value
            # is never mutated in place.
            # NOTE(review): an offset is negative whenever the point estimate lies
            # outside the interval; matplotlib documents ``yerr`` as non-negative.
            yerr = np.stack((point_estimate - bounds[0], bounds[1] - point_estimate))
        else:
            # Scalar (symmetric) measure of uncertainty, used as is
            yerr = bounds

    num_variables = plot_data["num_variables"]
    for i, ax in enumerate(plot_data["axes"].flat):
        # The grid may contain more axes than there are variables
        if i >= num_variables:
            break

        # Add scatter and error bars
        if yerr is not None:
            _ = ax.errorbar(
                targets[:, i],
                point_estimate[:, i],
                yerr=yerr[..., i],
                fmt="o",
                alpha=0.5,
                color=color,
                markersize=markersize,
                **kwargs,
            )
        else:
            _ = ax.scatter(
                targets[:, i],
                point_estimate[:, i],
                alpha=0.5,
                color=color,
                # scatter sizes are given in points**2, errorbar marker sizes in points
                s=None if markersize is None else markersize**2,
                **kwargs,
            )

        make_quadratic(ax, targets[:, i], point_estimate[:, i])

        if add_corr:
            corr = np.corrcoef(targets[:, i], point_estimate[:, i])[0, 1]
            add_metric(ax=ax, metric_text="$r$", metric_value=corr, metric_fontsize=metric_fontsize)

        ax.set_title(plot_data["variable_names"][i], fontsize=title_fontsize)

    # Apply shared styling, axis labels and layout
    prettify_subplots(plot_data["axes"], num_subplots=num_variables, tick_fontsize=tick_fontsize)

    add_titles_and_labels(
        axes=plot_data["axes"],
        num_row=plot_data["num_row"],
        num_col=plot_data["num_col"],
        xlabel=xlabel,
        ylabel=ylabel,
        label_fontsize=label_fontsize,
    )

    plot_data["fig"].tight_layout()
    return plot_data["fig"]