Source code for bayesflow.scores.normed_difference_score

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor

from .scoring_rule import ScoringRule


@serializable(package="bayesflow.scores")
class NormedDifferenceScore(ScoringRule):
    r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k`

    Scoring rule for a point estimate: the element-wise absolute error between
    the estimate and the target, raised to the power ``k``.
    """

    def __init__(self, k: int, **kwargs):
        """
        Parameters
        ----------
        k : int
            Exponent applied to the absolute difference.
        **kwargs
            Passed through unchanged to the parent constructor.
        """
        super().__init__(**kwargs)

        #: Exponent to absolute difference
        self.k = k

        # Entries this class contributes to the serialized config (see get_config).
        self.config = dict(k=k)

    def get_head_shapes_from_target_shape(self, target_shape: Shape):
        """
        Map each head name to its output shape.

        The only head is ``"value"``; its shape is ``target_shape`` with the
        first entry removed.
        """
        # keras.saving.load_model sometimes passes target_shape as a list,
        # so convert first to make sure the slice below yields a tuple.
        head_shape = tuple(target_shape)[1:]
        return {"value": head_shape}

    def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = None) -> Tensor:
        r"""
        Score point estimates by their absolute deviation from the targets, raised to the power :py:attr:`k`.

        :math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k`

        The tensor stored under ``"value"`` in **estimates** is compared element-wise with **targets**.
        The resulting per-element scores are handed, together with **weights**, to
        :py:func:`aggregate()`, whose result is returned.

        Parameters
        ----------
        estimates : dict[str, Tensor]
            Estimated quantities by name. Must contain the key ``"value"``.
        targets : Tensor
            True target values.
        weights : Tensor, optional
            Weights forwarded as-is to :py:func:`aggregate()`. Defaults to ``None``.

        Returns
        -------
        Tensor
            The output of :py:func:`aggregate()` applied to the element-wise
            absolute differences raised to the power ``self.k``.
        """
        point_estimate = estimates["value"]
        abs_error = keras.ops.absolute(point_estimate - targets)
        return self.aggregate(abs_error**self.k, weights)

    def get_config(self):
        """
        Return the parent's config merged with ``self.config``.

        The merge uses the ``|`` operator, so for duplicate keys the value
        from ``self.config`` is the one kept.
        """
        return super().get_config() | self.config