Source code for bayesflow.diagnostics.plots.mc_calibration

from collections.abc import Mapping, Sequence

import numpy as np
import matplotlib.pyplot as plt

from bayesflow.utils import (
    prepare_plot_data,
    add_titles_and_labels,
    add_metric,
    prettify_subplots,
)

from bayesflow.diagnostics.metrics import expected_calibration_error


def mc_calibration(
    pred_models: Mapping[str, np.ndarray] | np.ndarray,
    true_models: Mapping[str, np.ndarray] | np.ndarray,
    model_names: Sequence[str] | None = None,
    num_bins: int = 10,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    metric_fontsize: int = 14,
    tick_fontsize: int = 12,
    epsilon: float = 0.02,
    figsize: Sequence[int] | None = None,
    color: str = "#132a70",
    num_col: int | None = None,
    num_row: int | None = None,
) -> plt.Figure:
    """Plot calibration curves, ECEs, and marginal PMP histograms for model comparison.

    One subplot is drawn per model. Each subplot shows the calibration curve
    (predicted vs. true probability), a histogram of the predicted posterior
    model probabilities (the fraction of predictions falling into each bin),
    the diagonal reference line, and the expected calibration error (ECE).

    Depends on the ``expected_calibration_error`` function for computing the ECE.

    Parameters
    ----------
    pred_models : np.ndarray of shape (num_data_sets, num_models) or Mapping
        The predicted posterior model probabilities (PMPs) per data set.
        Forwarded unchanged to ``prepare_plot_data`` and
        ``expected_calibration_error``.
    true_models : np.ndarray of shape (num_data_sets, num_models) or Mapping
        The one-hot-encoded true model indices per data set.
    model_names : list or None, optional, default: None
        The model names for nice plot titles. Inferred if None
        (``prepare_plot_data`` is called with ``default_name="M"``).
    num_bins : int, optional, default: 10
        The number of bins to use for the calibration curves (and marginal histograms).
    label_fontsize : int, optional, default: 16
        The font size of the x-label and y-label texts.
    title_fontsize : int, optional, default: 18
        The font size of the subplot title texts.
    metric_fontsize : int, optional, default: 14
        The font size of the metric (e.g., ECE).
    tick_fontsize : int, optional, default: 12
        The font size of the axis ticklabels.
    epsilon : float, optional, default: 0.02
        A small amount to pad the [0, 1]-bounded axes on both sides.
    figsize : tuple or None, optional, default: None
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``.
    color : str, optional, default: '#132a70'
        The color of the calibration curves.
    num_col : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.

    Returns
    -------
    fig : plt.Figure
        The figure instance for optional saving.
    """

    # Gather plot data and metadata into a dictionary
    plot_data = prepare_plot_data(
        estimates=pred_models,
        targets=true_models,
        variable_names=model_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
        default_name="M",
    )

    # Compute the ECE values together with the binned probabilities
    ece = expected_calibration_error(
        estimates=pred_models,
        targets=true_models,
        model_names=plot_data["variable_names"],
        num_bins=num_bins,
        return_probs=True,
    )

    # Loop-invariant quantities: bin edges, per-sample histogram weights
    # (so that bar heights are fractions of data sets), and tick positions.
    estimates = plot_data["estimates"]
    uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
    norm_weights = np.ones_like(estimates) / len(estimates)
    ticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    num_models = plot_data["num_variables"]

    for j, ax in enumerate(plot_data["axes"].flat):
        # The subplot grid may contain more axes than there are models;
        # the surplus axes have no data to index and are left for
        # ``prettify_subplots`` to handle.
        if j >= num_models:
            break

        # Plot calibration curve
        ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)

        # Plot PMP distribution over bins
        ax.hist(estimates[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

        # Plot the diagonal reference line (predicted == true probability)
        ax.plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

        # Tweak plot
        ax.set_xlim([0 - epsilon, 1 + epsilon])
        ax.set_ylim([0 - epsilon, 1 + epsilon])
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)

        # Add ECE label
        add_metric(
            ax,
            metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
            metric_value=ece["values"][j],
            metric_fontsize=metric_fontsize,
        )

    prettify_subplots(axes=plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

    # Add labels, titles, and set font sizes
    add_titles_and_labels(
        axes=plot_data["axes"],
        num_row=plot_data["num_row"],
        num_col=plot_data["num_col"],
        title=plot_data["variable_names"],
        xlabel="Predicted Probability",
        ylabel="True Probability",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    plot_data["fig"].tight_layout()
    return plot_data["fig"]