Source code for bayesflow.diagnostics.plots.mc_confusion_matrix

from collections.abc import Sequence, Mapping

import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np

from ...utils.plot_utils import make_figure
from ...utils.classification import confusion_matrix


[docs] def mc_confusion_matrix( pred_models: Mapping[str, np.ndarray] | np.ndarray, true_models: Mapping[str, np.ndarray] | np.ndarray, model_names: Sequence[str] = None, fig_size: tuple = (5, 5), label_fontsize: int = 16, title_fontsize: int = 18, value_fontsize: int = 10, tick_fontsize: int = 12, xtick_rotation: int = None, ytick_rotation: int = None, normalize: str = None, cmap: matplotlib.colors.Colormap | str = None, title: bool = True, ) -> plt.Figure: """Plots a confusion matrix for validating a neural network trained for Bayesian model comparison. Parameters ---------- pred_models : np.ndarray of shape (num_data_sets, num_models) The predicted posterior model probabilities (PMPs) per data set. true_models : np.ndarray of shape (num_data_sets, num_models) The one-hot-encoded true model indices per data set. model_names : list or None, optional, default: None The model names for nice plot titles. Inferred if None. fig_size : tuple or None, optional, default: (5, 5) The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` label_fontsize : int, optional, default: 16 The font size of the y-label and y-label texts title_fontsize : int, optional, default: 18 The font size of the title text. value_fontsize : int, optional, default: 10 The font size of the text annotations and the colorbar tick labels. tick_fontsize : int, optional, default: 12 The font size of the axis label and model name texts. xtick_rotation: int, optional, default: None Rotation of x-axis tick labels (helps with long model names). ytick_rotation: int, optional, default: None Rotation of y-axis tick labels (helps with long model names). normalize : {'true', 'pred', 'all'}, default=None Passed to confusion matrix. Normalizes confusion matrix over the true (rows), predicted (columns) conditions or all the population. If None, confusion matrix will not be normalized. cmap : matplotlib.colors.Colormap or str, optional, default: None Colormap to be used for the cells. If a str, it should be the name of a registered colormap, e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red. title : bool, optional, default True A flag for adding 'Confusion Matrix' above the matrix. Returns ------- fig : plt.Figure - the figure instance for optional saving """ if model_names is None: num_models = true_models.shape[-1] model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)] if cmap is None: cmap = LinearSegmentedColormap.from_list("", ["white", "#132a70"]) # Flatten input true_models = np.argmax(true_models, axis=1) pred_models = np.argmax(pred_models, axis=1) # Compute confusion matrix cm = confusion_matrix(true_models, pred_models, normalize=normalize) # Initialize figure fig, ax = make_figure(1, 1, figsize=fig_size) ax = ax[0] im = ax.imshow(cm, interpolation="nearest", cmap=cmap) cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75) cbar.ax.tick_params(labelsize=value_fontsize) ax.set_xticks(range(cm.shape[0])) ax.set_xticklabels(model_names, fontsize=tick_fontsize) if xtick_rotation: plt.xticks(rotation=xtick_rotation, ha="right") ax.set_yticks(range(cm.shape[1])) ax.set_yticklabels(model_names, fontsize=tick_fontsize) if ytick_rotation: plt.yticks(rotation=ytick_rotation) ax.set_xlabel("Predicted model", fontsize=label_fontsize) ax.set_ylabel("True model", fontsize=label_fontsize) # Loop over data dimensions and create text annotations fmt = ".2f" if normalize else "d" thresh = cm.max() / 2.0 for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text( j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black", ) if title: ax.set_title("Confusion Matrix", fontsize=title_fontsize) return fig