Source code for bayesflow.networks.inference.coupling.coupling_flow

from typing import Any

import keras

from bayesflow.types import Tensor
from bayesflow.utils import (
    find_permutation,
    layer_kwargs,
    weighted_mean,
)
from bayesflow.utils.serialization import serializable, serialize

from .actnorm import ActNorm
from .layers import DualCoupling

from ...inference import InferenceNetwork


@serializable("bayesflow.networks")
class CouplingFlow(InferenceNetwork):
    """Coupling-based normalizing flow for simulation-based inference.

    Constructs a deep invertible architecture composed of multiple layers,
    including ActNorm, learned permutations, and dual coupling layers.
    Incorporates ideas from [1-5].

    The specific transformation applied in the coupling layers is determined by
    *transform*, while the subnet type can be either an MLP or another callable
    architecture specified by *subnet*. If *use_actnorm* is ``True``, an ActNorm
    layer is applied before each coupling layer.

    The model can be initialised with a base distribution, such as a standard
    normal, for density estimation. It can also use more flexible
    distributions, e.g., GMMs for highly multimodal, low-dimensional
    distributions or Multivariate Student-*t* for heavy-tailed distributions.

    Parameters
    ----------
    subnet : str or type, optional
        Architecture for the transformation network. Can be ``"mlp"``, a custom
        network class, or a ``Layer`` object, e.g.,
        ``bayesflow.networks.MLP(widths=[32, 32])``. Default is ``"mlp"``.
    depth : int, optional
        The number of invertible layers in the model. Default is 6.
    transform : str, optional
        The type of transformation used in the coupling layers, such as
        ``"affine"``. Default is ``"affine"``.
    permutation : str or None, optional
        The type of permutation applied between layers. Can be
        ``"orthogonal"``, ``"random"``, ``"swap"``, or ``None`` (no
        permutation). Default is ``"random"``.
    use_actnorm : bool, optional
        Whether to apply ActNorm before each coupling layer. Default is ``True``.
    base_distribution : str, optional
        The base probability distribution from which samples are drawn.
        Default is ``"normal"``.
    subnet_kwargs : dict[str, Any], optional
        Keyword arguments forwarded to the subnet (e.g., MLP) constructor
        within each coupling layer, such as hidden sizes or activation choices.
    transform_kwargs : dict[str, Any], optional
        Keyword arguments forwarded to the affine or spline transforms
        (e.g., number of bins for splines).
    **kwargs
        Additional keyword arguments passed to the base ``InferenceNetwork``.

    References
    ----------
    [1] Kingma, D. P., & Dhariwal, P. (2018). Glow: Generative flow with
        invertible 1x1 convolutions. NeurIPS, 31.
    [2] Durkan, C., et al. (2019). Neural spline flows. NeurIPS, 32.
    [3] Ardizzone, L., et al. (2020). Conditional invertible neural networks
        for diverse image-to-image translation. DAGM GCPR (pp. 373-387).
    [4] Radev, S. T., et al. (2020). BayesFlow: Learning complex stochastic
        models with invertible neural networks. IEEE TNNLS.
    [5] Alexanderson, S., & Henter, G. E. (2020). Robust model training and
        generalisation with Studentising flows. arXiv:2006.06599.
    """

    def __init__(
        self,
        subnet: str | type = "mlp",
        depth: int = 6,
        transform: str = "affine",
        permutation: str | None = "random",
        use_actnorm: bool = True,
        base_distribution: str = "normal",
        subnet_kwargs: dict[str, Any] | None = None,
        transform_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(base_distribution=base_distribution, **kwargs)

        self.subnet = subnet
        self.depth = depth
        self.transform = transform
        self.permutation = permutation
        self.use_actnorm = use_actnorm

        # One flow step per depth unit: (optional ActNorm) -> (optional
        # permutation) -> dual coupling. The loop index itself is unused.
        self.invertible_layers = []
        for _ in range(depth):
            if use_actnorm:
                self.invertible_layers.append(ActNorm())

            if (p := find_permutation(permutation)) is not None:
                self.invertible_layers.append(p)

            self.invertible_layers.append(
                DualCoupling(subnet, transform, subnet_kwargs=subnet_kwargs, transform_kwargs=transform_kwargs)
            )

        # We only need to do this for coupling flows, since we do not serialize invertible layers
        self.subnet_kwargs = subnet_kwargs
        self.transform_kwargs = transform_kwargs

    # noinspection PyMethodOverriding
[docs] def build(self, xz_shape, conditions_shape=None): for layer in self.invertible_layers: layer.build(xz_shape=xz_shape, conditions_shape=conditions_shape) self.base_distribution.build(xz_shape)
[docs] def get_config(self): base_config = super().get_config() base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, "depth": self.depth, "transform": self.transform, "permutation": self.permutation, "use_actnorm": self.use_actnorm, "base_distribution": self.base_distribution, "subnet_kwargs": self.subnet_kwargs, "transform_kwargs": self.transform_kwargs, } return base_config | serialize(config)
def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: z = x log_det = keras.ops.zeros(keras.ops.shape(x)[:-1]) for layer in self.invertible_layers: z, det = layer(z, conditions=conditions, inverse=False, training=training) log_det += det if density: log_density_latent = self.base_distribution.log_prob(z) log_density = log_density_latent + log_det return z, log_density return z def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: x = z log_det = keras.ops.zeros(keras.ops.shape(z)[:-1]) for layer in reversed(self.invertible_layers): x, det = layer(x, conditions=conditions, inverse=True, training=training) log_det += det if density: log_prob = self.base_distribution.log_prob(z) log_density = log_prob - log_det return x, log_density return x
[docs] def compute_metrics( self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training", **kwargs ) -> dict[str, Tensor]: _, log_density = self( x, conditions=conditions, inverse=False, density=True, training=stage == "training", **kwargs ) loss = weighted_mean(-log_density, sample_weight) return {"loss": loss}