Source code for bayesflow.distributions.mixture

from collections.abc import Sequence

import numpy as np

import keras
from keras import ops

from bayesflow.types import Shape, Tensor
from bayesflow.utils.decorators import allow_batch_size
from bayesflow.utils.serialization import serializable, serialize
from bayesflow.distributions import Distribution


@serializable
class Mixture(Distribution):
    """Utility class for backend-agnostic mixture distributions."""

    def __init__(
        self,
        distributions: Sequence[Distribution],
        mixture_logits: Sequence[float] = None,
        trainable_mixture: bool = False,
        **kwargs,
    ):
        """
        Initializes a mixture of distributions as a latent distribution.

        Parameters
        ----------
        distributions : Sequence[Distribution]
            A sequence of `Distribution` instances to form the mixture components.
        mixture_logits : Sequence[float], optional
            Initial unnormalized log-weights for each component. If `None`, all
            components are assigned equal weight. Default is `None`.
        trainable_mixture : bool, optional
            Whether the mixture weights (`mixture_logits`) should be trainable.
            Default is `False`.
        **kwargs
            Additional keyword arguments passed to the base `Distribution` class.

        Attributes
        ----------
        distributions : Sequence[Distribution]
            The list of component distributions.
        mixture_logits : Tensor
            Trainable or fixed logits representing the mixture weights.
        dim : int or None
            Dimensionality of the output samples; set when the layer is built.
        """
        super().__init__(**kwargs)

        self.distributions = distributions

        if mixture_logits is None:
            self.mixture_logits = ops.ones(shape=(len(distributions),))
        else:
            self.mixture_logits = ops.convert_to_tensor(mixture_logits)

        self.trainable_mixture = trainable_mixture
        self.dim = None
        self._mixture_logits = None
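    # Usage sketch (illustrative, not part of the library source). A mixture of
    # two Gaussian components might be constructed as follows; `DiagonalNormal`
    # is assumed to be available from `bayesflow.distributions`, and its exact
    # constructor arguments may differ:
    #
    #   from bayesflow.distributions import DiagonalNormal, Mixture
    #
    #   mixture = Mixture(
    #       distributions=[DiagonalNormal(mean=-2.0, std=1.0), DiagonalNormal(mean=2.0, std=1.0)],
    #       trainable_mixture=True,
    #   )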
    @allow_batch_size
    def sample(self, batch_shape: Shape) -> Tensor:
        """
        Draws samples from the mixture distribution.

        For each entry in `batch_shape`, a component index is first sampled from
        the categorical distribution given by the softmax of `mixture_logits`;
        a sample is then drawn from the corresponding component distribution.

        Parameters
        ----------
        batch_shape : Shape
            The desired sample batch shape (tuple of ints), not including the
            event dimension.

        Returns
        -------
        samples : Tensor
            A tensor of shape `batch_shape + (dim,)` containing samples drawn
            from the mixture.
        """
        # Use numpy until keras adds support for N-D categorical sampling
        pvals = ops.convert_to_numpy(ops.softmax(self._mixture_logits))
        cat_samples = np.random.multinomial(n=1, pvals=pvals, size=batch_shape)
        cat_samples = cat_samples.argmax(axis=-1)

        # Prepare the output array and a dtype to infer from the components
        samples = np.zeros(batch_shape + (self.dim,))
        dtype = None

        # Fill in the array with vectorized sampling per component
        for i in range(len(self.distributions)):
            dist_mask = cat_samples == i
            dist_indices = np.where(dist_mask)
            num_dist_samples = int(np.sum(dist_mask))
            dist_samples = ops.convert_to_numpy(self.distributions[i].sample(num_dist_samples))
            samples[dist_indices] = dist_samples
            dtype = dtype or ops.dtype(dist_samples)

        # Convert back to a backend tensor for compatibility
        return ops.convert_to_tensor(samples, dtype=dtype)
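    # Sampling sketch (illustrative): `sample` relies on `self.dim` and the
    # built mixture weights, so the layer must be built first. Assuming the
    # hypothetical `mixture` from above with 2-dimensional components:
    #
    #   mixture.build((None, 2))      # sets dim = 2 and creates the weights
    #   z = mixture.sample((128,))    # shape: (128, 2)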
    def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
        """
        Computes the log probability of the given samples under the mixture.

        For each input sample, computes the log-sum-exp of the component
        log-probabilities shifted by the mixture log-weights.

        Parameters
        ----------
        samples : Tensor
            A tensor of samples with shape `batch_shape + (dim,)`.
        normalize : bool, optional
            If `True`, returns normalized log-probabilities (i.e., includes the
            log normalization constant). Default is `True`.

        Returns
        -------
        Tensor
            A tensor of shape `batch_shape` containing the log probability of
            each sample under the mixture distribution.
        """
        log_prob = [distribution.log_prob(samples, normalize=normalize) for distribution in self.distributions]
        log_prob = ops.stack(log_prob, axis=-1)

        return ops.logsumexp(log_prob + ops.log_softmax(self._mixture_logits), axis=-1)
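    # Density sketch (illustrative): `log_prob` reduces the event dimension, so
    # samples of shape (128, 2) yield log-densities of shape (128,):
    #
    #   lp = mixture.log_prob(z)      # shape: (128,)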
    def build(self, input_shape: Shape) -> None:
        if self.built:
            return

        self.dim = input_shape[-1]

        for distribution in self.distributions:
            distribution.build(input_shape)

        # Pass a callable initializer: keras.initializers.get only accepts
        # identifiers or callables, not a tensor of initial values.
        self._mixture_logits = self.add_weight(
            shape=(len(self.distributions),),
            initializer=lambda shape, dtype: ops.cast(self.mixture_logits, dtype),
            dtype="float32",
            trainable=self.trainable_mixture,
        )
    def get_config(self):
        base_config = super().get_config()

        config = {
            "distributions": self.distributions,
            "mixture_logits": self.mixture_logits,
            "trainable_mixture": self.trainable_mixture,
        }

        return base_config | serialize(config)
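# Weight inspection sketch (illustrative): the mixture weights are stored as
# unnormalized logits, so the normalized component weights can be recovered
# with a softmax once the layer is built (note that `_mixture_logits` is a
# private attribute):
#
#   weights = ops.softmax(mixture._mixture_logits)   # non-negative, sums to 1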