Source code for bayesflow.utils.plot_utils

from typing import Sequence, Any, Mapping

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from .validators import check_estimates_prior_shapes
from .dict_utils import dicts_to_arrays


[docs] def prepare_plot_data( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, num_col: int = None, num_row: int = None, figsize: tuple = None, stacked: bool = False, default_name: str = "v", ) -> Mapping[str, Any]: """ Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name generation, layout configuration, figure initialization, and collapsing of axes. Parameters ---------- estimates : dict[str, ndarray] or ndarray The model-generated predictions or estimates, which can take the following forms: - ndarray of shape (num_datasets, num_variables) Point estimates for each dataset, where `num_datasets` is the number of datasets and `num_variables` is the number of variables per dataset. - ndarray of shape (num_datasets, num_draws, num_variables) Posterior samples for each dataset, where `num_datasets` is the number of datasets, `num_draws` is the number of posterior draws, and `num_variables` is the number of variables. targets : dict[str, ndarray] or ndarray, optional (default = None) Ground truth values corresponding to the estimates. Must match the structure and dimensionality of `estimates` in terms of first and last axis. variable_keys : list or None, optional, default: None Select keys from the dictionary provided in samples. By default, select all keys. variable_names : Sequence[str], optional (default = None) Optional variable names to act as a filter if dicts provided or actual variable names in case of array args num_col : int Number of columns for the visualization layout num_row : int Number of rows for the visualization layout figsize : tuple, optional, default: None Size of the figure adjusting to the display resolution stacked : bool, optional, default: False Whether the plots are stacked horizontally default_name : str, optional (default = "v") The default name to use for estimates if None provided """ plot_data = dicts_to_arrays( estimates=estimates, targets=targets, variable_keys=variable_keys, variable_names=variable_names, default_name=default_name, ) check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"]) # store variable information at top level for easy access variable_names = plot_data["estimates"].variable_names num_variables = len(variable_names) plot_data["variable_names"] = variable_names plot_data["num_variables"] = num_variables # Configure layout num_row, num_col = set_layout(num_variables, num_row, num_col, stacked) # Initialize figure fig, axes = make_figure(num_row, num_col, figsize=figsize) plot_data["fig"] = fig plot_data["axes"] = axes plot_data["num_row"] = num_row plot_data["num_col"] = num_col return plot_data
def set_layout(num_total: int, num_row: int = None, num_col: int = None, stacked: bool = False): """ Determine the number of rows and columns in diagnostics visualizations. Parameters ---------- num_total : int Total number of parameters num_row : int, default = None Number of rows for the visualization layout num_col : int, default = None Number of columns for the visualization layout stacked : bool, default = False Boolean that determines whether to stack the plot or not. Returns ------- num_row : int Number of rows for the visualization layout num_col : int Number of columns for the visualization layout """ if stacked: num_row, num_col = 1, 1 else: if num_row is None and num_col is None: num_row = int(np.ceil(num_total / 6)) num_col = int(np.ceil(num_total / num_row)) elif num_row is None and num_col is not None: num_row = int(np.ceil(num_total / num_col)) elif num_row is not None and num_col is None: num_col = int(np.ceil(num_total / num_row)) return num_row, num_col def make_figure(num_row: int = None, num_col: int = None, figsize: tuple = None): """ Initialize a set of figures Parameters ---------- num_row : int Number of rows in a figure num_col : int Number of columns in a figure figsize : tuple Size of the figure adjusting to the display resolution or the user's choice Returns ------- f, axes Initialized figures """ if num_row == 1 and num_col == 1: f, axes = plt.subplots(1, 1, figsize=figsize) else: if figsize is None: figsize = (int(5 * num_col), int(5 * num_row)) f, axes = plt.subplots(num_row, num_col, figsize=figsize) axes = np.atleast_1d(axes) return f, axes
[docs] def add_metric( ax, metric_text: str = None, metric_value: float = None, position: tuple = (0.1, 0.9), metric_fontsize: int = 12, ): """TODO: docstring""" if metric_text is None or metric_value is None: raise ValueError("Metric text and values must be provided to be add this metric.") metric_label = f"{metric_text} = {metric_value:.3f}" ax.text( position[0], position[1], metric_label, ha="left", va="center", transform=ax.transAxes, size=metric_fontsize, )
def add_x_labels( axes: np.ndarray, num_row: int = None, num_col: int = None, xlabel: Sequence[str] | str = None, label_fontsize: int = None, ): """TODO: docstring""" if num_row == 1: bottom_row = axes else: bottom_row = axes[num_row - 1, :] if num_col > 1 else axes for i, ax in enumerate(bottom_row): # If labels are in sequence, set them sequentially. Otherwise, one label fits all. ax.set_xlabel(xlabel if isinstance(xlabel, str) else xlabel[i], fontsize=label_fontsize) def add_y_labels(axes: np.ndarray, num_row: int = None, ylabel: Sequence[str] | str = None, label_fontsize: int = None): """TODO: docstring""" if num_row == 1: # if there is only one row, the ax array is 1D axes[0].set_ylabel(ylabel, fontsize=label_fontsize) # If there is more than one row, the ax array is 2D else: for i, ax in enumerate(axes[:, 0]): # If labels are in sequence, set them sequentially. Otherwise, one label fits all. ax.set_ylabel(ylabel if isinstance(ylabel, str) else ylabel[i], fontsize=label_fontsize) def add_titles(axes: np.ndarray, title: Sequence[str] | str = None, title_fontsize: int = None): for t, ax in zip(title, axes.flat): ax.set_title(t, fontsize=title_fontsize)
[docs] def add_titles_and_labels( axes: np.ndarray, num_row: int = None, num_col: int = None, title: Sequence[str] | str = None, xlabel: Sequence[str] | str = None, ylabel: Sequence[str] | str = None, title_fontsize: int = None, label_fontsize: int = None, ): """ Wrapper function for configuring labels for both axes. """ if title is not None: add_titles(axes, title, title_fontsize) if xlabel is not None: add_x_labels(axes, num_row, num_col, xlabel, label_fontsize) if ylabel is not None: add_y_labels(axes, num_row, ylabel, label_fontsize)
[docs] def prettify_subplots(axes: np.ndarray, num_subplots: int, tick: bool = True, tick_fontsize: int = 12): """TODO: docstring""" for ax in axes.flat: sns.despine(ax=ax) ax.grid(alpha=0.5) if tick: ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) # Remove unused axes entirely for _ax in axes.flat[num_subplots:]: _ax.remove()
[docs] def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray): """ Utility to make a subplots quadratic in order to avoid visual illusions in, e.g., recovery plots. """ lower = min(x_data.min(), y_data.min()) upper = max(x_data.max(), y_data.max()) eps = (upper - lower) * 0.1 ax.set_xlim((lower - eps, upper + eps)) ax.set_ylim((lower - eps, upper + eps)) ax.plot( [ax.get_xlim()[0], ax.get_xlim()[1]], [ax.get_ylim()[0], ax.get_ylim()[1]], color="black", alpha=0.9, linestyle="dashed", )