# Source code for bayesflow.networks.inference.inference_network

from collections.abc import Sequence

import keras

from bayesflow.types import Shape, Tensor
from bayesflow.utils import layer_kwargs, find_distribution
from bayesflow.utils.decorators import allow_batch_size
from bayesflow.utils.serialization import deserialize


class InferenceNetwork(keras.Layer):
    """Abstract base class for all inference networks in BayesFlow.

    An inference network learns a mapping between a data space and a latent
    space, optionally conditioned on external variables. Concrete subclasses
    power the different approximation strategies (normalizing flows, diffusion
    models, flow matching, consistency models, ...).

    Subclassing guide
    -----------------
    To implement a custom inference network, inherit from this class and
    override **at minimum** the following methods:

    ``_forward(x, conditions, density, training, **kwargs)``
        Map data *x* to latent *z*. When *density* is ``True`` the method must
        return a tuple ``(z, log_prob)``; otherwise just *z*.

    ``_inverse(z, conditions, density, training, **kwargs)``
        Map latent *z* to data *x*. Same density convention as ``_forward``.

    ``compute_metrics(x, conditions, sample_weight, stage)``
        Compute and return a ``dict[str, Tensor]`` of training metrics. The
        dict **must** contain at least a ``"loss"`` key. This is where you
        implement the training objective for your custom inference network.

    Optionally override:

    ``build(xz_shape, conditions_shape)``
        Allocate weights that depend on the concrete tensor shapes. Call
        ``super().build(...)`` to build the ``base_distribution`` and trigger
        a forward pass for shape inference.

    ``sample(batch_shape, conditions, **kwargs)``
        Draw samples from the learned distribution. The default implementation
        samples from ``base_distribution`` and passes the result through
        ``_inverse``.

    ``log_prob(samples, conditions, **kwargs)``
        Evaluate the log-density of *samples* under the learned distribution.
        The default implementation calls ``_forward`` with ``density=True``.

    Parameters
    ----------
    base_distribution : str, optional
        Identifier for the base (latent) distribution, resolved via
        :func:`~bayesflow.utils.find_distribution`. Default is ``"normal"``.
    **kwargs
        Forwarded to ``keras.Layer`` after filtering with
        :func:`~bayesflow.utils.layer_kwargs`.
    """

    # Names of mask keyword arguments that may be forwarded to a subnet.
    _SUBNET_MASK_KEYS = {"attention_mask", "mask"}

    def __init__(self, base_distribution: str = "normal", **kwargs):
        # Strip non-layer kwargs before handing the rest to keras.Layer.
        super().__init__(**layer_kwargs(kwargs))
        # Resolve the string identifier into a concrete distribution object.
        self.base_distribution = find_distribution(base_distribution)

    @staticmethod
    def _collect_mask_kwargs(keys: Sequence[str], source: dict) -> dict:
        """Extract mask kwargs from source dict.

        Looks up each key in *keys* and includes it in the result if its
        value is not ``None``.
        """
        collected = {}
        for key in keys:
            if source.get(key) is not None:
                collected[key] = source[key]
        return collected
[docs] def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: if self.built: # building when the network is already built can cause issues with serialization # see https://github.com/keras-team/keras/issues/21147 return self.base_distribution.build(xz_shape) x = keras.ops.zeros(xz_shape) conditions = keras.ops.zeros(conditions_shape) if conditions_shape is not None else None self.call(x, conditions, training=True)
[docs] def call( self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, density: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs)
    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        # Abstract: map data ``x`` to latent ``z``; when ``density`` is True,
        # subclasses must return ``(z, log_prob)``. Must be overridden.
        raise NotImplementedError

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        # Abstract: map latent ``z`` back to data ``x``; same density
        # convention as ``_forward``. Must be overridden.
        raise NotImplementedError
[docs] @allow_batch_size def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample(batch_shape) samples = self(samples, conditions=conditions, inverse=True, density=False, **kwargs) return samples
[docs] def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: _, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs) return log_density
    def compute_metrics(
        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training", **kwargs
    ) -> dict[str, Tensor]:
        # Abstract: subclasses implement the training objective here and
        # return a metrics dict that must contain at least a "loss" key
        # (contract stated in the class docstring).
        raise NotImplementedError
[docs] @classmethod def from_config(cls, config, custom_objects=None): return cls(**deserialize(config, custom_objects=custom_objects))