Source code for bayesflow.utils.classification.confusion_matrix

from typing import Sequence

import numpy as np


def confusion_matrix(targets: np.ndarray, estimates: np.ndarray, labels: Sequence = None, normalize: str = None):
    """
    Compute confusion matrix to evaluate the accuracy of a classification or model comparison setting.

    Code inspired by:
    https://github.com/scikit-learn/scikit-learn/blob/98ed9dc73/sklearn/metrics/_classification.py

    Parameters
    ----------
    targets : np.ndarray
        Ground truth (correct) target values.
    estimates : np.ndarray
        Estimated targets as returned by a classifier. Must have the same length as ``targets``.
    labels : Sequence, optional
        List of labels to index the matrix. This may be used to reorder or select a subset of labels.
        If None, labels that appear at least once in ``targets`` or ``estimates`` are used in sorted order.
        (target, estimate) pairs in which either value is not among ``labels`` are not counted.
    normalize : {'true', 'pred', 'all'}, optional
        Normalizes confusion matrix over the true (rows), predicted (columns) conditions or all the population.
        Rows, columns, or matrices whose total count is zero are left as zeros.
        If None, no normalization is applied.

    Returns
    -------
    cm : np.ndarray of shape (num_labels, num_labels)
        Confusion matrix. Rows represent true classes, columns represent predicted classes.
        Contains int64 counts if ``normalize`` is None, otherwise float64 fractions.

    Raises
    ------
    ValueError
        If ``normalize`` is not one of {'true', 'pred', 'all', None}, or if ``targets`` and
        ``estimates`` have different lengths.
    """
    if normalize not in (None, "true", "pred", "all"):
        raise ValueError(f"normalize must be one of {{'true', 'pred', 'all', None}}, got {normalize!r}.")
    # zip() would otherwise silently drop the unmatched tail of the longer input.
    if len(targets) != len(estimates):
        raise ValueError(
            f"targets and estimates must have the same length, got {len(targets)} and {len(estimates)}."
        )

    # Get unique labels
    if labels is None:
        labels = np.unique(np.concatenate((targets, estimates)))
    else:
        labels = np.asarray(labels)

    label_to_index = {label: i for i, label in enumerate(labels)}
    num_labels = len(labels)

    # Fill confusion matrix with raw counts; pairs involving labels outside the label set are skipped.
    cm = np.zeros((num_labels, num_labels), dtype=np.int64)
    for t, p in zip(targets, estimates):
        row = label_to_index.get(t)
        col = label_to_index.get(p)
        if row is not None and col is not None:
            cm[row, col] += 1

    if normalize is None:
        return cm

    cm = cm.astype(np.float64)
    if normalize == "all":
        totals = cm.sum()
    else:
        # "true" normalizes each row (true class), "pred" each column (predicted class).
        totals = cm.sum(axis=1 if normalize == "true" else 0, keepdims=True)

    # `out` must be pre-filled with zeros: np.divide leaves entries where `where` is False
    # uninitialized, so empty rows/columns (or an all-zero matrix) would otherwise hold garbage.
    return np.divide(cm, totals, out=np.zeros_like(cm), where=totals != 0)