Source code for bayesflow.diagnostics.plots.recovery_from_estimates

from collections.abc import Callable, Sequence, Mapping

import matplotlib.pyplot as plt
import numpy as np


from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric


[docs] def recovery_from_estimates( estimates: Mapping[str, Mapping[str, np.ndarray]], targets: Mapping[str, np.ndarray], marker_mapping: dict[str, str], variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, add_corr: bool = True, corr_point_agg: Callable = np.median, figsize: Sequence[int] = None, label_fontsize: int = 16, title_fontsize: int = 18, metric_fontsize: int = 16, tick_fontsize: int = 12, color: str = "#132a70", num_col: int = None, num_row: int = None, xlabel: str = "Ground truth", ylabel: str = "Estimate", **kwargs, ) -> plt.Figure: """ Creates and plots publication-ready recovery plot of estimates vs. targets. This plot yields similar information as the "posterior z-score", but allows for generic point and uncertainty estimates: https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html Important: Posterior aggregates play no special role in Bayesian inference and should only be used heuristically. For instance, in the case of multi-modal posteriors, common point estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) need to be interpreted carefully. Parameters ---------- estimates : dict[str, dict[str, np.ndarray]] The model-generated estimates in a nested dictionary, e.g. as returned by approximator.estimate(conditions=...). - The outer keys identify the inference variable. - The inner keys identify point estimates. - The inner value is an ndarray of shape (num_datasets, point_estimate_size, variable_block_size) targets : dict[str, np.ndarray] The prior draws (true parameters) used for generating the num_datasets marker_mapping : dict[str, str] Define how to mark different point estimates by their key, e.g. {"quantiles":"_", "mean":"*"}. Only point estimates whose key appears in the dictionary, will be plotted. variable_keys : list or None, optional, default: None Select keys from the dictionaries provided in estimates and targets. By default, select all keys. variable_names : list or None, optional, default: None The individual parameter names for nice plot titles. Inferred if None add_corr : boolean, default: True Should correlations between estimates and ground truth values be shown? corr_point_agg : Callable Function producing a central point estimate from the whole list of point estimates in case correlations should be computed. Default: median figsize : tuple or None, optional, default : None The figure size passed to the matplotlib constructor. Inferred if None. label_fontsize : int, optional, default: 16 The font size of the y-label text. title_fontsize : int, optional, default: 18 The font size of the title text. metric_fontsize : int, optional, default: 16 The font size of the metrics shown as text. tick_fontsize : int, optional, default: 12 The font size of the axis ticklabels. color : str, optional, default: '#8f2727' The color for the true vs. estimated scatter points and error bars. num_row : int, optional, default: None The number of rows for the subplots. Dynamically determined if None. num_col : int, optional, default: None The number of columns for the subplots. Dynamically determined if None. xlabel: ylabel: Returns ------- f : plt.Figure - the figure instance for optional saving Raises ------ ShapeError If there is a deviation from the expected shapes of ``estimates`` and ``targets``. """ point_estimates = {} markers = [] for var_key, var_value in estimates.items(): estimate_list = [] for estimate_key in marker_mapping.keys(): e = estimates[var_key][estimate_key] # move axis of flexible shape to the end e = np.moveaxis(e, 1, -1) # add dimensions of size (1,) if not 3D yet e = np.atleast_3d(e) # undo reordering e = np.moveaxis(e, -1, 1) estimate_list.append(e) markers += [marker_mapping[estimate_key]] * e.shape[1] point_estimates[var_key] = np.concatenate(estimate_list, axis=1) # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( estimates=point_estimates, targets=targets, variable_keys=variable_keys, variable_names=variable_names, num_col=num_col, num_row=num_row, figsize=figsize, ) estimates = plot_data.pop("estimates") targets = plot_data.pop("targets") # Compute point estimates and uncertainties point_estimate = corr_point_agg(estimates, axis=1) num_quantiles = estimates.shape[1] for i, ax in enumerate(plot_data["axes"].flat): if i >= plot_data["num_variables"]: break # Add scatter and error bars for q_idx in range(num_quantiles): _ = ax.scatter( targets[:, i], estimates[:, q_idx, i], marker=markers[q_idx], alpha=0.5, color=color, **kwargs, ) connecting_lines_x = np.tile(targets[:, i], (2, 1)) connecting_lines_y = np.array([estimates[:, :, i].min(axis=1), estimates[:, :, i].max(axis=1)]) _ = ax.plot( connecting_lines_x, connecting_lines_y, alpha=0.5, color=color, ) make_quadratic(ax, targets[:, i], point_estimate[:, i]) if add_corr: corr = np.corrcoef(targets[:, i], point_estimate[:, i])[0, 1] add_metric(ax=ax, metric_text="$r$", metric_value=corr, metric_fontsize=metric_fontsize) ax.set_title(plot_data["variable_names"][i], fontsize=title_fontsize) # Add custom schmuck prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize) add_titles_and_labels( axes=plot_data["axes"], num_row=plot_data["num_row"], num_col=plot_data["num_col"], xlabel=xlabel, ylabel=ylabel, label_fontsize=label_fontsize, ) plot_data["fig"].tight_layout() return plot_data["fig"]