Source code for bayesflow.scores.multivariate_normal_score

import math

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.links import PositiveDefinite

from .parametric_distribution_score import ParametricDistributionScore


@serializable(package="bayesflow.scores")
class MultivariateNormalScore(ParametricDistributionScore):
    r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`

    Scores a predicted mean and covariance matrix with the log-score of the probability of the materialized value.
    """

    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("covariance",)
    """
    Marks head for covariance matrix as an exception for adapter transformations.

    This variable contains names of prediction heads that should lead to a warning when the adapter is applied
    in inverse direction to them.

    For more information see :py:class:`ScoringRule`.
    """

    def __init__(self, dim: int = None, links: dict = None, **kwargs):
        """
        Parameters
        ----------
        dim : int, optional
            Dimensionality of the target. If ``None``, it is set later by
            :py:meth:`get_head_shapes_from_target_shape`.
        links : dict, optional
            Mapping from head name to link layer. When falsy (``None`` or empty), defaults to
            ``{"covariance": PositiveDefinite()}``.
        **kwargs
            Forwarded to :py:class:`ParametricDistributionScore`.
        """
        super().__init__(links=links, **kwargs)

        self.dim = dim
        self.links = links or {"covariance": PositiveDefinite()}

        # Only the constructor argument is serialized; a dim inferred later is not written back here.
        self.config = {"dim": dim}

    def get_config(self):
        """Return the parent config merged with this score's constructor arguments."""
        base_config = super().get_config()
        return base_config | self.config

    def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
        """
        Derive the prediction-head shapes from the target shape.

        The size ``D`` of the last target axis is stored in ``self.dim`` as a side effect.

        Returns
        -------
        dict[str, Shape]
            ``{"mean": (D,), "covariance": (D, D)}``.
        """
        self.dim = target_shape[-1]
        return dict(mean=(self.dim,), covariance=(self.dim, self.dim))

    def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
        """
        Compute the log probability density of a multivariate Gaussian distribution.

        Evaluates, for each sample in `x`,
        ``-0.5 * (D * log(2 * pi) + log|covariance| + (x - mean)^T covariance^{-1} (x - mean))``,
        where ``D`` is the size of the last axis of ``x - mean``.

        Parameters
        ----------
        x : Tensor
            A tensor of input samples for which the log probability density is computed.
            The shape should be compatible with broadcasting against `mean`.
        mean : Tensor
            A tensor representing the mean of the multivariate Gaussian distribution.
        covariance : Tensor
            A tensor representing the covariance matrix of the multivariate Gaussian distribution.

        Returns
        -------
        Tensor
            A tensor containing the log probability densities for each sample in `x`
            under the given Gaussian distribution (the last axis is reduced).
        """
        diff = x - mean

        # Read the event dimensionality from the data instead of `self.dim`: `self.dim` stays None
        # until `get_head_shapes_from_target_shape` runs (e.g. for an instance built with dim=None).
        dim = keras.ops.shape(diff)[-1]

        precision = keras.ops.inv(covariance)
        # slogdet returns (sign, log|det|); only the log-magnitude enters the density.
        log_det_covariance = keras.ops.slogdet(covariance)[1]

        # Quadratic form diff^T Sigma^{-1} diff, contracted over the last axis.
        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)

        # A plain Python float avoids materializing an extra backend tensor for this constant.
        log_normalizer = dim * math.log(2.0 * math.pi)

        return -0.5 * (log_normalizer + log_det_covariance + quadratic_term)

    def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor:
        """
        Generate samples from a multivariate Gaussian distribution.

        This function samples from a multivariate Gaussian distribution with the given `mean`
        and `covariance` using the Cholesky decomposition method. Independent standard normal
        samples are transformed using the Cholesky factor of the covariance matrix to generate
        correlated samples.

        Parameters
        ----------
        batch_shape : Shape
            Either ``(batch_size, num_samples)`` or ``(num_samples,)``; the latter is treated
            as ``(1, num_samples)``.
        mean : Tensor
            A tensor representing the mean of the multivariate Gaussian distribution.
            Must have shape (batch_size, D), where D is the dimensionality of the distribution.
        covariance : Tensor
            A tensor representing the covariance matrix of the multivariate Gaussian distribution.
            Must have shape (batch_size, D, D), where D is the dimensionality.

        Returns
        -------
        Tensor
            A tensor of shape (batch_size, num_samples, D) containing the generated samples.

        Raises
        ------
        ValueError
            If `batch_shape` does not have one or two entries, or if `mean` / `covariance`
            do not match the requested batch size.
        """
        batch_shape = tuple(batch_shape)
        if len(batch_shape) == 1:
            batch_shape = (1,) + batch_shape
        if len(batch_shape) != 2:
            raise ValueError(
                f"batch_shape must be (num_samples,) or (batch_size, num_samples), but got {batch_shape}"
            )
        batch_size, num_samples = batch_shape

        mean_shape = tuple(keras.ops.shape(mean))
        covariance_shape = tuple(keras.ops.shape(covariance))
        dim = mean_shape[-1]

        if mean_shape != (batch_size, dim):
            raise ValueError(f"mean must have shape (batch_size, {dim}), but got {mean_shape}")
        if covariance_shape != (batch_size, dim, dim):
            raise ValueError(
                f"covariance must have shape (batch_size, {dim}, {dim}), but got {covariance_shape}"
            )

        # Reparameterization: x = mean + L z with covariance = L L^T and z ~ N(0, I).
        cholesky_factor = keras.ops.cholesky(covariance)
        normal_samples = keras.random.normal((batch_size, num_samples, dim))

        # Batched matrix-vector product L @ z per sample: (b, D, D) x (b, n, D) -> (b, n, D).
        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cholesky_factor, normal_samples)

        return mean[:, None, :] + scaled_normal