# Source code for bayesflow.distributions.diagonal_normal
import keras
from keras.saving import register_keras_serializable as serializable
import math
import numpy as np
from bayesflow.types import Shape, Tensor
from bayesflow.utils.decorators import allow_batch_size
from .distribution import Distribution
@serializable(package="bayesflow.distributions")
class DiagonalNormal(Distribution):
    """Implements a backend-agnostic diagonal Gaussian distribution.

    The density factorizes over the last axis: each of the ``dim`` components is an
    independent normal with its own mean and standard deviation. ``dim`` is taken
    from the last entry of the input shape passed to :meth:`build`.
    """

    def __init__(
        self,
        mean: int | float | np.ndarray | Tensor = 0.0,
        std: int | float | np.ndarray | Tensor = 1.0,
        use_learnable_parameters: bool = False,
        seed_generator: keras.random.SeedGenerator = None,
        **kwargs,
    ):
        """
        Initializes a backend-agnostic diagonal Gaussian distribution with optional learnable parameters.

        This class represents a Gaussian distribution with a diagonal covariance matrix, allowing for efficient
        sampling and density evaluation.

        The mean and standard deviation can be specified as fixed values or learned during training. The class also
        supports random number generation with an optional seed for reproducibility.

        Parameters
        ----------
        mean : int, float, np.ndarray, or Tensor, optional
            The mean of the Gaussian distribution. Can be a scalar or a tensor. Default is 0.0.
        std : int, float, np.ndarray, or Tensor, optional
            The standard deviation of the Gaussian distribution. Can be a scalar or a tensor.
            Default is 1.0.
        use_learnable_parameters : bool, optional
            Whether to treat the mean and standard deviation as learnable parameters. Default is False.
        seed_generator : keras.random.SeedGenerator, optional
            A Keras seed generator for reproducible random sampling. If None, a new seed
            generator is created. Default is None.
        **kwargs
            Additional keyword arguments passed to the base `Distribution` class.
        """
        super().__init__(**kwargs)
        # mean/std are stored as given here; build() broadcasts them to shape (dim,) and casts to float32.
        self.mean = mean
        self.std = std
        self.dim = None
        self.log_normalization_constant = None
        self.use_learnable_parameters = use_learnable_parameters

        if seed_generator is None:
            seed_generator = keras.random.SeedGenerator()

        self.seed_generator = seed_generator

    def _compute_log_normalization_constant(self) -> Tensor:
        """Return ``-dim/2 * log(2*pi) - sum(log(std))`` evaluated at the *current* ``std``."""
        return -0.5 * self.dim * math.log(2.0 * math.pi) - keras.ops.sum(keras.ops.log(self.std))

    def build(self, input_shape: Shape) -> None:
        """Fix the event dimension and materialize the mean and standard deviation.

        Parameters
        ----------
        input_shape : Shape
            Shape whose last entry is the event dimension ``dim``.

        Notes
        -----
        ``mean`` and ``std`` are broadcast to shape ``(dim,)`` and cast to ``float32``.
        If ``use_learnable_parameters`` is True, both are replaced by weights created
        with ``add_weight`` and initialized to those broadcast values.
        """
        self.dim = int(input_shape[-1])

        # convert to tensor and broadcast if necessary
        self.mean = keras.ops.cast(keras.ops.broadcast_to(self.mean, (self.dim,)), "float32")
        self.std = keras.ops.cast(keras.ops.broadcast_to(self.std, (self.dim,)), "float32")

        if self.use_learnable_parameters:
            mean = self.mean
            self.mean = self.add_weight(
                shape=keras.ops.shape(mean),
                initializer="zeros",
                dtype="float32",
            )
            self.mean.assign(mean)

            # NOTE(review): the learnable std is not constrained to be positive; log(std) in the
            # normalization constant and the division in log_prob assume std > 0 — consider a
            # positive parameterization (e.g. softplus / log-std) if this becomes an issue.
            std = self.std
            self.std = self.add_weight(
                shape=keras.ops.shape(std),
                initializer="ones",
                dtype="float32",
            )
            self.std.assign(std)

        # Build-time snapshot kept for callers that read this attribute. log_prob does NOT rely on it,
        # because with learnable parameters it would go stale as soon as std is updated.
        self.log_normalization_constant = self._compute_log_normalization_constant()

    def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
        """Evaluate the log-density of ``samples``.

        Parameters
        ----------
        samples : Tensor
            Points to evaluate; the last axis is the event dimension and is summed over.
        normalize : bool, optional
            If True (default), add the Gaussian log-normalization constant so the result
            is a proper log-density. If False, only the quadratic term is returned.

        Returns
        -------
        Tensor
            Log-density with the last axis reduced.
        """
        result = -0.5 * keras.ops.sum((samples - self.mean) ** 2 / self.std**2, axis=-1)

        if normalize:
            # Recompute from the current std so the value (and its gradient) stays correct
            # when std is a trainable weight.
            result += self._compute_log_normalization_constant()

        return result

    @allow_batch_size
    def sample(self, batch_shape: Shape) -> Tensor:
        """Draw samples of shape ``batch_shape + (dim,)``.

        Uses ``mean + std * eps`` with ``eps`` drawn from ``keras.random.normal`` and seeded by
        ``self.seed_generator``. ``allow_batch_size`` presumably lets an integer batch size be
        passed instead of a shape — confirm against that decorator.
        """
        # tuple() lets list-valued shapes through; `list + tuple` would raise a TypeError.
        shape = tuple(batch_shape) + (self.dim,)
        return self.mean + self.std * keras.random.normal(shape=shape, seed=self.seed_generator)