import keras
from keras.saving import register_keras_serializable as serializable
import math
import numpy as np
from bayesflow.types import Shape, Tensor
from bayesflow.utils import expand_tile
from bayesflow.utils.decorators import allow_batch_size
from .distribution import Distribution
@serializable(package="bayesflow.distributions")
class DiagonalStudentT(Distribution):
"""Implements a backend-agnostic diagonal Student-t distribution."""
def __init__(
self,
df: int | float,
loc: int | float | np.ndarray | Tensor = 0.0,
scale: int | float | np.ndarray | Tensor = 1.0,
use_learnable_parameters: bool = False,
        seed_generator: keras.random.SeedGenerator | None = None,
**kwargs,
):
"""
Initializes a backend-agnostic Student's t-distribution with optional learnable parameters.
This class represents a Student's t-distribution, which is useful for modeling heavy-tailed data.
The distribution is parameterized by degrees of freedom (`df`), location (`loc`), and scale (`scale`).
These parameters can either be fixed or learned during training.
The class also supports random number generation with an optional seed for reproducibility.
Parameters
----------
df : int or float
Degrees of freedom for the Student's t-distribution. Lower values result in
heavier tails, making it more robust to outliers.
loc : int, float, np.ndarray, or Tensor, optional
The location parameter (mean) of the distribution. Default is 0.0.
        scale : int, float, np.ndarray, or Tensor, optional
            The scale parameter of the distribution. Note that this is not the
            standard deviation, which is only finite for ``df > 2``. Default is 1.0.
use_learnable_parameters : bool, optional
Whether to treat `loc` and `scale` as learnable parameters. Default is False.
seed_generator : keras.random.SeedGenerator, optional
A Keras seed generator for reproducible random sampling. If None, a new seed
generator is created. Default is None.
**kwargs
Additional keyword arguments passed to the base `Distribution` class.
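
        Examples
        --------
        A minimal usage sketch outside a full BayesFlow workflow (here ``build`` is
        called manually with the event shape; inside the framework it is invoked
        automatically):

        >>> dist = DiagonalStudentT(df=5.0)
        >>> dist.build((2,))
        >>> samples = dist.sample((4,))           # shape: (4, 2)
        >>> log_density = dist.log_prob(samples)  # shape: (4,)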
"""
super().__init__(**kwargs)
self.df = df
self.loc = loc
self.scale = scale
self.dim = None
self.log_normalization_constant = None
self.use_learnable_parameters = use_learnable_parameters
if seed_generator is None:
seed_generator = keras.random.SeedGenerator()
self.seed_generator = seed_generator
def build(self, input_shape: Shape) -> None:
self.dim = int(input_shape[-1])
# convert to tensor and broadcast if necessary
self.loc = keras.ops.broadcast_to(self.loc, (self.dim,))
self.loc = keras.ops.cast(self.loc, "float32")
self.scale = keras.ops.broadcast_to(self.scale, (self.dim,))
self.scale = keras.ops.cast(self.scale, "float32")
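        # Log normalization constant of a multivariate Student-t with diagonal scale:
        # log Gamma((df + dim) / 2) - log Gamma(df / 2) - (dim / 2) * log(df * pi) - sum(log(scale))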
self.log_normalization_constant = (
-0.5 * self.dim * math.log(self.df)
- 0.5 * self.dim * math.log(math.pi)
- math.lgamma(0.5 * self.df)
+ math.lgamma(0.5 * (self.df + self.dim))
- keras.ops.sum(keras.ops.log(self.scale))
)
if self.use_learnable_parameters:
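            # Replace the fixed tensors with trainable weights, initialized to the values set above.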
loc = self.loc
self.loc = self.add_weight(
shape=keras.ops.shape(loc),
initializer="zeros",
dtype="float32",
)
self.loc.assign(loc)
scale = self.scale
self.scale = self.add_weight(
shape=keras.ops.shape(scale),
initializer="ones",
dtype="float32",
)
self.scale.assign(scale)
def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
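        # Squared Mahalanobis distance for a diagonal scale matrix: sum_i ((x_i - loc_i) / scale_i)^2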
mahalanobis_term = keras.ops.sum((samples - self.loc) ** 2 / self.scale**2, axis=-1)
result = -0.5 * (self.df + self.dim) * keras.ops.log1p(mahalanobis_term / self.df)
if normalize:
result += self.log_normalization_constant
return result
@allow_batch_size
def sample(self, batch_shape: Shape) -> Tensor:
# As of writing this code, keras does not support the chi-square distribution
# nor does it support a scale or rate parameter in Gamma. Hence, we use the relation:
# chi-square(df) = Gamma(shape = 0.5 * df, scale = 2) = Gamma(shape = 0.5 * df, scale = 1) * 2
chi2_samples = keras.random.gamma(batch_shape, alpha=0.5 * self.df, seed=self.seed_generator) * 2.0
        # The chi-square samples need to be repeated across self.dim
# since for each element of batch_shape only one sample is created.
chi2_samples = expand_tile(chi2_samples, n=self.dim, axis=-1)
normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator)
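        # Location-scale representation of the Student-t: loc + scale * Z * sqrt(df / chi2),
        # with Z ~ Normal(0, I) and chi2 ~ Chi-square(df).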
return self.loc + self.scale * normal_samples * keras.ops.sqrt(self.df / chi2_samples)