Source code for bayesflow.diagnostics.metrics.calibration_error

from collections.abc import Sequence, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def calibration_error(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    resolution: int = 20,
    aggregation: Callable = np.median,
    min_quantile: float = 0.005,
    max_quantile: float = 0.995,
) -> dict[str, any]:
    """
    Computes an aggregate score for the marginal calibration error over an ensemble of approximate
    posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation
    between an alpha-CI and the relative number of inliers from ``estimates`` over multiple alphas in
    (0, 1).

    Parameters
    ----------
    estimates : np.ndarray of shape (num_datasets, num_draws, num_variables)
        The random draws from the approximate posteriors over ``num_datasets``
    targets : np.ndarray of shape (num_datasets, num_variables)
        The corresponding ground-truth values sampled from the prior
    variable_keys : Sequence[str], optional (default = None)
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to show in the output.
    resolution : int, optional, default: 20
        The number of credibility intervals (CIs) to consider
    aggregation : callable or None, optional, default: np.median
        The function used to aggregate the marginal calibration errors.
        If ``None`` is provided, the per-alpha calibration errors will be returned.
    min_quantile : float in (0, 1), optional, default: 0.005
        The minimum posterior quantile to consider.
    max_quantile : float in (0, 1), optional, default: 0.995
        The maximum posterior quantile to consider.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The aggregated calibration error per variable.
        - "metric_name" : str
            The name of the metric ("Calibration Error").
        - "variable_names" : Sequence[str]
            The (inferred) variable names.
    """

    samples = dicts_to_arrays(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    # Define alpha values and the corresponding quantile bounds
    alphas = np.linspace(start=min_quantile, stop=max_quantile, num=resolution)
    regions = 1 - alphas
    lowers = regions / 2
    uppers = 1 - lowers

    # Compute quantiles for each alpha, for each dataset and parameter
    quantiles = np.quantile(samples["estimates"], [lowers, uppers], axis=1)

    # Shape: (2, resolution, num_datasets, num_params)
    lower_bounds, upper_bounds = quantiles[0], quantiles[1]

    # Compute masks for inliers
    lower_mask = lower_bounds <= samples["targets"][None, ...]
    upper_mask = upper_bounds >= samples["targets"][None, ...]

    # Logical AND to identify inliers for each alpha
    inlier_id = np.logical_and(lower_mask, upper_mask)

    # Compute the relative number of inliers for each alpha
    alpha_pred = np.mean(inlier_id, axis=1)

    # Calculate absolute error between predicted inliers and alpha
    absolute_errors = np.abs(alpha_pred - alphas[:, None])

    # Aggregate errors across alpha, or return the per-alpha errors if no aggregation is given
    if aggregation is None:
        error = absolute_errors
    else:
        error = aggregation(absolute_errors, axis=0)

    variable_names = samples["estimates"].variable_names

    return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names}
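
A minimal usage sketch (not part of the module source above), assuming synthetic posterior draws and
ground-truth values with the shapes described in the docstring, and that the function is importable
from ``bayesflow.diagnostics.metrics`` as the page title suggests:

    # Hypothetical usage sketch; array shapes follow the docstring above.
    import numpy as np

    from bayesflow.diagnostics.metrics import calibration_error

    rng = np.random.default_rng(seed=42)

    # Synthetic setup: 100 datasets, 500 posterior draws each, 3 inferred variables.
    targets = rng.normal(size=(100, 3))
    estimates = targets[:, None, :] + rng.normal(size=(100, 500, 3))

    result = calibration_error(
        estimates=estimates,
        targets=targets,
        variable_names=["alpha", "beta", "sigma"],
    )

    # One aggregated (median) calibration error per variable.
    print(result["variable_names"])
    print(result["values"])  # shape: (3,)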