Source code for bayesflow.networks.mlp.mlp

from collections.abc import Sequence
from typing import Literal, Callable

import keras

from bayesflow.utils import sequential_kwargs
from bayesflow.utils.serialization import deserialize, serializable, serialize

from ..residual import Residual


@serializable("bayesflow.networks")
class MLP(keras.Sequential):
    """
    Implements a simple configurable MLP with optional residual connections and dropout.

    If used in conjunction with a coupling net, a diffusion model, or a flow matching model,
    it assumes that the input and conditions are already concatenated (i.e., this is a
    single-input model).
    """

    def __init__(
        self,
        widths: Sequence[int] = (256, 256),
        *,
        activation: str | Callable[[], keras.Layer] = "mish",
        kernel_initializer: str | keras.Initializer = "he_normal",
        residual: bool = True,
        dropout: Literal[0, None] | float = 0.05,
        norm: Literal["batch", "layer"] | keras.Layer = None,
        spectral_normalization: bool = False,
        **kwargs,
    ):
        """
        Implements a flexible multi-layer perceptron (MLP) with optional residual connections,
        dropout, and spectral normalization.

        This MLP can be used as a general-purpose feature extractor or function approximator,
        supporting configurable depth, width, activation functions, and weight initializations.
        If `residual` is enabled, each layer includes a skip connection for improved gradient
        flow. The model also supports dropout for regularization and spectral normalization
        for stability when learning smooth functions.

        Parameters
        ----------
        widths : Sequence[int], optional
            Defines the number of hidden units per layer, as well as the number of layers
            to be used. Default is (256, 256).
        activation : str or Callable, optional
            Activation function applied in the hidden layers, such as "mish".
            Default is "mish".
        kernel_initializer : str or keras.Initializer, optional
            Initialization strategy for kernel weights, such as "he_normal".
            Default is "he_normal".
        residual : bool, optional
            Whether to use residual connections for improved training stability.
            Default is True.
        dropout : float or None, optional
            Dropout rate applied within the MLP layers for regularization.
            Default is 0.05.
        norm : str or keras.Layer, optional
            Normalization applied after each hidden layer: "batch", "layer", a custom
            keras.Layer instance, or None for no normalization. Default is None.
        spectral_normalization : bool, optional
            Whether to apply spectral normalization to stabilize training.
            Default is False.
        **kwargs
            Additional keyword arguments passed to the Keras layer initialization.
        """
        self.widths = list(widths)
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.residual = residual
        self.dropout = dropout
        self.norm = norm
        self.spectral_normalization = spectral_normalization

        layers = []

        for width in widths:
            layer = self._make_layer(
                width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization
            )
            layers.append(layer)

        super().__init__(layers, **sequential_kwargs(kwargs))
    def build(self, input_shape=None):
        if self.built:
            # building when the network is already built can cause issues with serialization
            # see https://github.com/keras-team/keras/issues/21147
            return

        # we only care about the last dimension, and using ... signifies to keras.Sequential
        # that any number of batch dimensions is valid (which is what we want for all sublayers)
        # we also have to avoid calling super().build() because this causes
        # shape errors when building on non-sets but doing inference on sets
        # this is a work-around for https://github.com/keras-team/keras/issues/21158
        input_shape = (..., input_shape[-1])

        for layer in self._layers:
            layer.build(input_shape)
            input_shape = layer.compute_output_shape(input_shape)
    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**deserialize(config, custom_objects=custom_objects))
    def get_config(self):
        base_config = super().get_config()
        base_config = sequential_kwargs(base_config)

        config = {
            "widths": self.widths,
            "activation": self.activation,
            "kernel_initializer": self.kernel_initializer,
            "residual": self.residual,
            "dropout": self.dropout,
            "norm": self.norm,
            "spectral_normalization": self.spectral_normalization,
        }

        return base_config | serialize(config)
    @staticmethod
    def _make_layer(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
        layers = []

        # dense projection, optionally wrapped in spectral normalization
        dense = keras.layers.Dense(width, kernel_initializer=kernel_initializer)
        if spectral_normalization:
            dense = keras.layers.SpectralNormalization(dense)
        layers.append(dense)

        if dropout is not None and dropout > 0:
            layers.append(keras.layers.Dropout(dropout))

        # resolve the activation into a layer so it can live inside the block
        activation = keras.activations.get(activation)
        if not isinstance(activation, keras.Layer):
            activation = keras.layers.Activation(activation)
        layers.append(activation)

        if norm == "batch":
            layers.append(keras.layers.BatchNormalization())
        elif norm == "layer":
            layers.append(keras.layers.LayerNormalization())
        elif isinstance(norm, str):
            raise ValueError(f"Unknown normalization strategy: {norm!r}.")
        elif isinstance(norm, keras.Layer):
            layers.append(norm)
        elif norm is None:
            pass
        else:
            raise TypeError(f"Cannot infer norm from {norm!r} of type {type(norm)}.")

        # wrap the whole block in a skip connection if requested
        if residual:
            return Residual(*layers)

        return keras.Sequential(layers)
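

# --- Usage sketch (illustrative; not part of the library source) ---
# Builds a small MLP and applies it to a random batch. The widths, batch size,
# and feature dimension below are assumptions chosen for demonstration only.
def _demo_forward_pass():
    mlp = MLP(widths=(128, 128), dropout=0.1, norm="layer")
    x = keras.random.normal((32, 7))  # assumed: 32 samples with 7 features each
    y = mlp(x)  # output shape (32, 128), matching the last entry of `widths`
    return y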
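

# --- Serialization sketch (illustrative; not part of the library source) ---
# Round-trips the network through get_config() / from_config(). This relies on
# the @serializable registration above; the concrete arguments and shapes are
# assumptions for demonstration only.
def _demo_config_roundtrip():
    mlp = MLP(widths=(64, 64), residual=False, dropout=None)
    mlp.build((None, 3))  # assumed: a 3-dimensional feature axis

    config = mlp.get_config()
    restored = MLP.from_config(config)  # an equivalent, not-yet-built copy
    restored.build((None, 3))
    return restored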