Source code for bayesflow.distributions.distribution
import keras
from bayesflow.types import Shape, Tensor
from bayesflow.utils import layer_kwargs
from bayesflow.utils.serialization import serializable, deserialize
[docs]
@serializable("bayesflow.distributions")
class Distribution(keras.Layer):
    """Base Keras layer for distributions over event vectors.

    Concrete distributions override :meth:`log_prob` and :meth:`sample`; in
    this base class both raise ``NotImplementedError``. Invoking the layer on a
    tensor of samples returns ``exp(log_prob(samples))`` with the default
    ``normalize=True``.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are passed through `layer_kwargs` before reaching
        # keras.Layer -- presumably to drop keys the Layer constructor does not
        # accept; see bayesflow.utils.layer_kwargs to confirm.
        forwarded = layer_kwargs(kwargs)
        super().__init__(**forwarded)

    def call(self, samples: Tensor) -> Tensor:
        """Return the density of ``samples``.

        Computed by exponentiating :meth:`log_prob` evaluated with its default
        arguments (``normalize=True``).

        Parameters
        ----------
        samples : Tensor
            Points at which to evaluate the density.

        Returns
        -------
        Tensor
            ``exp(self.log_prob(samples))``.
        """
        log_density = self.log_prob(samples)
        return keras.ops.exp(log_density)

    def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
        """Return the log-density of ``samples``.

        Must be implemented by subclasses.

        Parameters
        ----------
        samples : Tensor
            Points at which to evaluate the log-density.
        normalize : bool, optional
            Keyword-only flag; its exact meaning (e.g. whether the
            normalizing constant is included) is defined by each subclass.
            Defaults to ``True``.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def sample(self, batch_shape: Shape, seed: int | keras.random.SeedGenerator | None = None) -> Tensor:
        """Generate random draws from the distribution.

        Must be implemented by subclasses.

        Parameters
        ----------
        batch_shape : Shape
            Tuple of ints giving the batch shape of the draws; the trailing
            event dimension is not part of it.
        seed : int, keras.random.SeedGenerator, or None, optional
            Controls reproducibility. An ``int`` is turned into a
            ``keras.random.SeedGenerator`` that every random draw within the
            call shares. A ``SeedGenerator`` is used directly, and each use
            advances its state. ``None`` (the default) falls back to the
            instance's own seed generator.

        Returns
        -------
        Tensor
            Draws of shape ``batch_shape + (event_dim,)``.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def compute_output_shape(self, input_shape: Shape) -> Shape:
        """Infer the layer's output shape by sampling.

        Draws samples whose batch shape is the first entry of ``input_shape``
        and returns the shape of those samples.

        NOTE(review): this performs a real draw via ``self.sample`` with
        ``seed=None``, which per the ``sample`` contract uses the instance seed
        generator -- presumably advancing its state as a side effect.
        NOTE(review): the returned value is the shape of *samples*
        (``batch + (event_dim,)``), whereas ``call`` returns
        ``exp(log_prob(...))``; if subclasses' ``log_prob`` reduces over the
        event axis these shapes disagree -- confirm against subclasses.
        NOTE(review): a symbolic ``None`` batch dimension may not be accepted
        by ``sample`` -- TODO confirm.
        """
        leading_batch = input_shape[:1]
        drawn = self.sample(leading_batch)
        return keras.ops.shape(drawn)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Reconstruct a distribution from its serialized configuration.

        The config is passed through ``deserialize`` (with ``custom_objects``)
        and the result is splatted into the constructor.
        """
        constructor_kwargs = deserialize(config, custom_objects=custom_objects)
        return cls(**constructor_kwargs)