Source code for bayesflow.utils.plot_utils

from typing import Sequence, Any, Mapping

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerPatch

from .validators import check_estimates_prior_shapes
from .dict_utils import dicts_to_arrays


def prepare_plot_data(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    num_col: int = None,
    num_row: int = None,
    figsize: tuple = None,
    stacked: bool = False,
    default_name: str = "v",
) -> Mapping[str, Any]:
    """
    Procedural wrapper that runs all preprocessing steps for a diagnostics plot:
    conversion of the inputs to arrays, shape-checking, variable name generation,
    layout configuration, and figure initialization.

    Parameters
    ----------
    estimates : dict[str, ndarray] or ndarray
        The model-generated predictions or estimates, which can take the following forms:

        - ndarray of shape (num_datasets, num_variables)
            Point estimates for each dataset, where `num_datasets` is the number of datasets
            and `num_variables` is the number of variables per dataset.
        - ndarray of shape (num_datasets, num_draws, num_variables)
            Posterior samples for each dataset, where `num_datasets` is the number of datasets,
            `num_draws` is the number of posterior draws, and `num_variables` is the number of
            variables.
    targets : dict[str, ndarray] or ndarray
        Ground truth values corresponding to the estimates. Must match the structure and
        dimensionality of `estimates` in terms of first and last axis. This argument is required.
    variable_keys : Sequence[str] or None, optional, default: None
        Keys to select from dictionary inputs. By default, select all keys.
    variable_names : Sequence[str], optional, default: None
        Optional variable names to act as a filter if dicts are provided, or the actual
        variable names in case of array arguments.
    num_col : int, optional, default: None
        Number of columns for the visualization layout (inferred by `set_layout` when None).
    num_row : int, optional, default: None
        Number of rows for the visualization layout (inferred by `set_layout` when None).
    figsize : tuple, optional, default: None
        Size of the figure; when None, `make_figure` chooses a size from the layout.
    stacked : bool, optional, default: False
        If True, the layout is a single (1 x 1) axes shared by all variables.
    default_name : str, optional, default: "v"
        The default name to use for variables if none are provided.

    Returns
    -------
    plot_data : dict
        The mapping produced by `dicts_to_arrays` (including its "estimates" and "targets"
        entries), extended with the keys "variable_names", "num_variables", "fig", "axes",
        "num_row" and "num_col".
    """
    plot_data = dicts_to_arrays(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        default_name=default_name,
    )
    check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"])

    # store variable information at top level for easy access
    variable_names = plot_data["estimates"].variable_names
    num_variables = len(variable_names)
    plot_data["variable_names"] = variable_names
    plot_data["num_variables"] = num_variables

    # Configure layout
    num_row, num_col = set_layout(num_variables, num_row, num_col, stacked)

    # Initialize figure
    fig, axes = make_figure(num_row, num_col, figsize=figsize)

    plot_data["fig"] = fig
    plot_data["axes"] = axes
    plot_data["num_row"] = num_row
    plot_data["num_col"] = num_col

    return plot_data
def set_layout(num_total: int, num_row: int = None, num_col: int = None, stacked: bool = False):
    """
    Determine the number of rows and columns in diagnostics visualizations.

    Parameters
    ----------
    num_total : int
        Total number of subplots (e.g., one per parameter) to lay out.
    num_row : int, default = None
        Number of rows for the visualization layout. Inferred when None.
    num_col : int, default = None
        Number of columns for the visualization layout. Inferred when None.
    stacked : bool, default = False
        If True, everything goes into a single axes and the layout is always (1, 1).

    Returns
    -------
    num_row : int
        Number of rows for the visualization layout
    num_col : int
        Number of columns for the visualization layout

    Notes
    -----
    When both `num_row` and `num_col` are None, the fewest rows with at most 6 columns
    each are used (at least one row and one column). When both are given, they are
    returned unchanged.
    """
    if stacked:
        return 1, 1

    if num_row is None and num_col is None:
        # Clamp to at least 1 so that num_total == 0 does not divide by zero below.
        num_row = max(1, int(np.ceil(num_total / 6)))
        num_col = max(1, int(np.ceil(num_total / num_row)))
    elif num_row is None:
        num_row = int(np.ceil(num_total / num_col))
    elif num_col is None:
        num_col = int(np.ceil(num_total / num_row))

    return num_row, num_col


def make_figure(num_row: int = None, num_col: int = None, figsize: tuple = None):
    """
    Initialize a figure with a grid of subplots.

    Parameters
    ----------
    num_row : int, default = None
        Number of rows in the figure (treated as 1 when None).
    num_col : int, default = None
        Number of columns in the figure (treated as 1 when None).
    figsize : tuple, default = None
        Size of the figure. When None, a single axes uses matplotlib's default size and a
        grid uses 5 inches per column and per row.

    Returns
    -------
    f, axes
        The figure and a numpy array of its axes (always at least 1-D).
    """
    # An unspecified dimension means a single row/column.
    num_row = 1 if num_row is None else num_row
    num_col = 1 if num_col is None else num_col

    if num_row == 1 and num_col == 1:
        f, axes = plt.subplots(1, 1, figsize=figsize)
    else:
        if figsize is None:
            figsize = (int(5 * num_col), int(5 * num_row))
        f, axes = plt.subplots(num_row, num_col, figsize=figsize)

    # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so callers can always index/iterate.
    axes = np.atleast_1d(axes)
    return f, axes
def add_metric(
    ax,
    metric_text: str = None,
    metric_value: float = None,
    position: tuple = (0.1, 0.9),
    metric_fontsize: int = 12,
):
    """
    Annotate an axes with a ``"<metric_text> = <metric_value>"`` label.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to annotate.
    metric_text : str
        Name of the metric, shown before the equals sign.
    metric_value : float
        Value of the metric, formatted with three decimals.
    position : tuple, default: (0.1, 0.9)
        (x, y) position of the label in axes coordinates; the label is left-aligned
        and vertically centered at this point.
    metric_fontsize : int, default: 12
        Font size of the label.

    Raises
    ------
    ValueError
        If `metric_text` or `metric_value` is None.
    """
    # Check against None explicitly so a legitimate value of 0 is still accepted.
    if metric_text is None or metric_value is None:
        raise ValueError("Metric text and values must be provided to add this metric.")

    metric_label = f"{metric_text} = {metric_value:.3f}"

    ax.text(
        position[0],
        position[1],
        metric_label,
        ha="left",
        va="center",
        transform=ax.transAxes,
        size=metric_fontsize,
    )
def add_x_labels( axes: np.ndarray, num_row: int = None, num_col: int = None, xlabel: Sequence[str] | str = None, label_fontsize: int = None, ): """TODO: docstring""" if num_row == 1: bottom_row = axes else: bottom_row = axes[num_row - 1, :] if num_col > 1 else axes for i, ax in enumerate(bottom_row): # If labels are in sequence, set them sequentially. Otherwise, one label fits all. ax.set_xlabel(xlabel if isinstance(xlabel, str) else xlabel[i], fontsize=label_fontsize) def add_y_labels(axes: np.ndarray, num_row: int = None, ylabel: Sequence[str] | str = None, label_fontsize: int = None): """TODO: docstring""" if num_row == 1: # if there is only one row, the ax array is 1D axes[0].set_ylabel(ylabel, fontsize=label_fontsize) # If there is more than one row, the ax array is 2D else: for i, ax in enumerate(axes[:, 0]): # If labels are in sequence, set them sequentially. Otherwise, one label fits all. ax.set_ylabel(ylabel if isinstance(ylabel, str) else ylabel[i], fontsize=label_fontsize) def add_titles(axes: np.ndarray, title: Sequence[str] | str = None, title_fontsize: int = None): for t, ax in zip(title, axes.flat): ax.set_title(t, fontsize=title_fontsize)
[docs] def add_titles_and_labels( axes: np.ndarray, num_row: int = None, num_col: int = None, title: Sequence[str] | str = None, xlabel: Sequence[str] | str = None, ylabel: Sequence[str] | str = None, title_fontsize: int = None, label_fontsize: int = None, ): """ Wrapper function for configuring labels for both axes. """ if title is not None: add_titles(axes, title, title_fontsize) if xlabel is not None: add_x_labels(axes, num_row, num_col, xlabel, label_fontsize) if ylabel is not None: add_y_labels(axes, num_row, ylabel, label_fontsize)
def prettify_subplots(axes: np.ndarray, num_subplots: int, tick: bool = True, tick_fontsize: int = 12):
    """
    Apply a uniform look to every subplot and delete the unused trailing ones.

    Every axes is despined and given a light grid (alpha 0.5); when `tick` is True, both
    major and minor tick labels get `tick_fontsize`. Axes beyond the first `num_subplots`
    (in row-major order) are then removed from the figure.
    """
    all_axes = list(axes.flat)

    for current in all_axes:
        sns.despine(ax=current)
        current.grid(alpha=0.5)
        if not tick:
            continue
        for level in ("major", "minor"):
            current.tick_params(axis="both", which=level, labelsize=tick_fontsize)

    # Delete surplus axes entirely
    for leftover in all_axes[num_subplots:]:
        leftover.remove()
def make_quadratic(ax: "plt.Axes", x_data: np.ndarray, y_data: np.ndarray):
    """
    Utility to make a subplot quadratic in order to avoid visual illusions in, e.g.,
    recovery plots.

    Both axes get the same limits, which span the combined range of `x_data` and `y_data`
    padded by 10% on each side. A dashed black identity (y = x) line is drawn across
    these limits. NaN entries are ignored when computing the range.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to adjust.
    x_data, y_data : np.ndarray
        The data shown on the x and y axes.
    """
    lower = min(np.nanmin(x_data), np.nanmin(y_data))
    upper = max(np.nanmax(x_data), np.nanmax(y_data))
    span = upper - lower

    # Pad by 10% of the range; if all values coincide, fall back to a non-zero pad so the
    # lower and upper limits are not identical.
    eps = 0.1 * span if span > 0 else 0.1 * max(abs(lower), 1.0)
    limits = (lower - eps, upper + eps)

    ax.set_xlim(limits)
    ax.set_ylim(limits)
    ax.plot(
        [limits[0], limits[1]],
        [limits[0], limits[1]],
        color="black",
        alpha=0.9,
        linestyle="dashed",
    )
def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
    """
    Plot a 1D line with color gradient determined by `c` (same shape as x and y).

    Parameters
    ----------
    x, y : array-like
        Coordinates of the line's points.
    c : array-like, optional
        Values mapped through `cmap` to color the line. Defaults to `y`.
    cmap : str or matplotlib colormap, default: "viridis"
        Colormap used for the gradient.
    lw : float, default: 2.0
        Line width.
    alpha : float, default: 1
        Line transparency.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes (``plt.gca()``).

    Returns
    -------
    LineCollection
        The collection that was added to `ax`.

    Notes
    -----
    The axis limits are set to exactly ``[min(x), max(x)]`` and ``[min(y), max(y)]``.
    """
    if ax is None:
        ax = plt.gca()

    # Default color value = y
    if c is None:
        c = y

    # Each consecutive pair of points becomes one segment: shape (n - 1, 2, 2).
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # NOTE(review): `c` has one value per point but there are n - 1 segments; confirm that
    # matplotlib's mapping of the surplus value is acceptable.
    norm = Normalize(np.min(c), np.max(c))
    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)
    ax.add_collection(lc)

    ax.set_xlim(np.min(x), np.max(x))
    ax.set_ylim(np.min(y), np.max(y))
    return lc


def gradient_legend(ax, label, cmap, norm, loc="upper right"):
    """
    Adds a single gradient swatch to the legend of the given Axes.

    Existing legend entries of `ax` are kept and the swatch is appended after them.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose legend receives the swatch.
    label : str
        Label to display next to the swatch.
    cmap : str or matplotlib colormap
        Colormap (or the name of a registered colormap) rendered in the swatch.
    norm : matplotlib Normalize or None
        If given and scaled (``vmin`` and ``vmax`` set), the colormap is sampled across
        ``[norm.vmin, norm.vmax]`` through `norm`; otherwise the full colormap is shown.
    loc : str, default: "upper right"
        Legend location.
    """
    num_slices = 64
    colormap = plt.get_cmap(cmap)

    # Sample across the norm's own range: feeding fixed 0..1 values through a norm scaled to
    # the data range would clip everything to a single color.
    if norm is not None and norm.scaled():
        values = np.asarray(norm(np.linspace(norm.vmin, norm.vmax, num_slices)))
    else:
        values = np.linspace(0.0, 1.0, num_slices)
    slice_colors = colormap(values)

    # Custom dummy handle to represent the gradient
    class _GradientSwatch(Rectangle):
        pass

    # Custom legend handler that draws a horizontal gradient
    class _HandlerGradient(HandlerPatch):
        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
            # Tile the handle box with thin rectangles that live only in the legend. The
            # previous ax.imshow approach also added the image to the Axes itself and
            # updated its data limits.
            slice_width = width / num_slices
            return [
                Rectangle(
                    (-xdescent + i * slice_width, -ydescent),
                    slice_width,
                    height,
                    facecolor=color,
                    edgecolor="none",
                    linewidth=0,
                    antialiased=False,
                    transform=trans,
                )
                for i, color in enumerate(slice_colors)
            ]

    # Add to existing legend entries
    handles, labels = ax.get_legend_handles_labels()
    handles.append(_GradientSwatch((0, 0), 1, 1))
    labels.append(label)

    ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})


def add_gradient_plot(
    x,
    y,
    ax,
    cmap: str = "viridis",
    lw: float = 3.0,
    marker: bool = True,
    marker_type: str = "o",
    marker_size: int = 34,
    alpha: float = 1,
    label: str = "Validation",
):
    """
    Draw `y` against `x` as a gradient-colored line whose colors follow `x`.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the curve.
    ax : matplotlib Axes
        Target axes.
    cmap : str or matplotlib colormap, default: "viridis"
        Colormap for the line and the markers.
    lw : float, default: 3.0
        Line width.
    marker : bool, default: True
        Whether to overlay scatter markers at the data points.
    marker_type : str, default: "o"
        Marker style.
    marker_size : int, default: 34
        Marker size.
    alpha : float, default: 1
        Transparency of the line (not of the markers).
    label : str, default: "Validation"
        Legend label attached to the markers.
    """
    gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)

    # Optionally add markers
    if marker:
        # NOTE(review): markers use a hard-coded alpha of 0.01 (nearly invisible) instead of
        # `alpha`; confirm this is intended.
        ax.scatter(
            x,
            y,
            c=x,
            cmap=cmap,
            marker=marker_type,
            s=marker_size,
            zorder=10,
            edgecolors="none",
            label=label,
            alpha=0.01,
        )