Source code for bayesflow.networks.mlp.time_mlp

from typing import Literal, Callable, Sequence

import keras

from bayesflow.networks.embeddings import FourierEmbedding
from bayesflow.networks.residual import ConditionalResidual
from bayesflow.types import Tensor
from bayesflow.utils import concatenate_valid, layer_kwargs
from bayesflow.utils.serialization import serialize, serializable, deserialize
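
# The `ConditionalResidual` blocks imported above are what inject the time
# embedding into each hidden layer. Their implementation lives in
# bayesflow.networks.residual; as a rough, hypothetical sketch of the FiLM
# (feature-wise linear modulation) mechanism the class docstring refers to,
# each block maps the time embedding to a per-feature scale and shift and
# modulates the hidden activations with them:
#
#     scale_shift = keras.layers.Dense(2 * width)(t_emb)
#     scale, shift = keras.ops.split(scale_shift, 2, axis=-1)
#     h = h * (1.0 + scale) + shift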


@serializable("bayesflow.networks")
class TimeMLP(keras.Layer):
    """
    Implements a time-conditioned multi-layer perceptron (MLP).

    The model processes three separate inputs: the state variable `x`, a scalar or
    vector-valued time input `t`, and a conditioning variable `conditions`. The
    input and condition are projected into a shared feature space, merged, and
    passed through a deep residual MLP. A learned time embedding is injected via
    FiLM at every hidden layer. If `residual` is enabled, each layer includes a
    skip connection for improved gradient flow. The model also supports dropout
    for regularization and spectral normalization for stability when learning
    smooth functions.
    """

    def __init__(
        self,
        widths: Sequence[int] = (256, 256),
        *,
        time_embedding_dim: int = 32,
        time_emb: keras.Layer | None = None,
        activation: str | Callable[[], keras.Layer] = "mish",
        kernel_initializer: str | keras.Initializer = "he_normal",
        residual: bool = True,
        dropout: Literal[0, None] | float = 0.05,
        norm: Literal["batch", "layer"] | keras.Layer = "layer",
        spectral_normalization: bool = False,
        merge: Literal["add", "concat"] = "concat",
        **kwargs,
    ):
        """
        Implements a time-conditioned multi-layer perceptron (MLP).

        Parameters
        ----------
        widths : Sequence[int], optional
            Number of hidden units per layer; the length of the sequence
            determines the number of layers. Default is (256, 256).
        time_embedding_dim : int, optional
            Dimensionality of the learned time embedding. Default is 32. If set
            to 1, no embedding is applied and time is used directly.
        time_emb : keras.layers.Layer or None, optional
            Custom time embedding layer. If None, a random Fourier feature
            embedding is used.
        activation : str or callable, optional
            Activation function applied in the hidden layers, such as "mish".
            Default is "mish".
        kernel_initializer : str or keras.Initializer, optional
            Initialization strategy for kernel weights, such as "he_normal".
            Default is "he_normal".
        residual : bool, optional
            Whether to use residual connections for improved training stability.
            Default is True.
        dropout : float or None, optional
            Dropout rate applied within the MLP layers for regularization.
            Default is 0.05.
        norm : str or keras.layers.Layer or None, optional
            Normalization applied after each hidden layer ("batch", "layer", or
            None). Default is "layer".
        spectral_normalization : bool, optional
            Whether to apply spectral normalization to dense layers. Default is
            False.
        merge : str, optional
            Method to merge input and condition if available ("add" or
            "concat"). Default is "concat".
        **kwargs
            Additional keyword arguments passed to `keras.Layer`.
        """
        super().__init__(**layer_kwargs(kwargs))

        self.widths = list(widths)
        self.time_embedding_dim = int(time_embedding_dim)
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.residual = residual
        self.dropout = dropout
        self.norm = norm
        self.spectral_normalization = spectral_normalization
        self.merge = merge

        if len(self.widths) == 0:
            raise ValueError("TimeMLP requires at least one hidden width.")

        # Time embedding: identity for raw scalar time, Fourier features otherwise
        if time_emb is None:
            if self.time_embedding_dim == 1:
                self.time_emb = keras.layers.Identity()
            else:
                self.time_emb = FourierEmbedding(
                    embed_dim=self.time_embedding_dim,
                    scale=kwargs.pop("fourier_scale", 30.0),
                    include_identity=True,
                )
        else:
            self.time_emb = time_emb

        # Projections for x and conditions into a shared feature space
        self.x_proj = keras.layers.Dense(
            self.widths[0], kernel_initializer=self.kernel_initializer, name="x_proj"
        )
        self.c_proj = None

        if merge not in ("add", "concat"):
            raise ValueError(f"Unknown merge mode: {merge!r} (expected 'add' or 'concat').")
        self.merge_proj = None

        act = keras.activations.get(activation)
        if not isinstance(act, keras.Layer):
            act = keras.layers.Activation(act)
        self.act = act

        self.blocks = [
            ConditionalResidual(
                width,
                activation=activation,
                kernel_initializer=kernel_initializer,
                residual=residual,
                dropout=dropout,
                norm=norm,
                spectral_normalization=spectral_normalization,
                **kwargs,
            )
            for width in self.widths
        ]
    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**deserialize(config, custom_objects=custom_objects))
    def get_config(self):
        base = super().get_config()
        base = layer_kwargs(base)

        cfg = {
            "widths": self.widths,
            "time_embedding_dim": self.time_embedding_dim,
            "time_emb": self.time_emb,
            "activation": self.activation,
            "kernel_initializer": self.kernel_initializer,
            "residual": self.residual,
            "dropout": self.dropout,
            "norm": self.norm,
            "spectral_normalization": self.spectral_normalization,
            "merge": self.merge,
        }
        return base | serialize(cfg)
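
    # Sketch of the serialization round-trip that `get_config`/`from_config`
    # enable (hypothetical usage; `mlp` is an already-constructed TimeMLP):
    #
    #     cfg = mlp.get_config()
    #     clone = TimeMLP.from_config(cfg)
    #
    # `serialize`/`deserialize` handle nested Keras objects such as the
    # `time_emb` layer, so the clone rebuilds the same architecture; weights
    # are not copied, only the configuration.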
    def build(self, input_shape):
        if self.built:
            return

        x_shape, t_shape, conditions_shape = input_shape

        self.time_emb.build(t_shape)
        t_emb_shape = self.time_emb.compute_output_shape(t_shape)

        # Merge / input pathway
        self.x_proj.build(x_shape)
        h_shape = self.x_proj.compute_output_shape(x_shape)

        if conditions_shape is not None:
            self.c_proj = keras.layers.Dense(
                self.widths[0], kernel_initializer=self.kernel_initializer, name="c_proj"
            )
            self.c_proj.build(conditions_shape)

            if self.merge == "concat":
                merge_shape = list(h_shape)
                merge_shape[-1] = merge_shape[-1] + self.widths[0]
                merge_shape = tuple(merge_shape)
            else:
                merge_shape = h_shape

            self.merge_proj = keras.layers.Dense(
                self.widths[0], kernel_initializer=self.kernel_initializer, name="merge_proj"
            )
            self.merge_proj.build(merge_shape)

        # Conditional residual blocks
        for block in self.blocks:
            block.build((h_shape, t_emb_shape))
            h_shape = block.compute_output_shape((h_shape, t_emb_shape))

        super().build(input_shape)
    def compute_output_shape(self, input_shape):
        x_shape, t_shape, conditions_shape = input_shape

        h_shape = self.x_proj.compute_output_shape(x_shape)
        t_emb_shape = self.time_emb.compute_output_shape(t_shape)

        for block in self.blocks:
            h_shape = block.compute_output_shape((h_shape, t_emb_shape))

        return h_shape
    def call(
        self,
        inputs: tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor, None],
        training: bool = None,
        mask=None,
    ) -> Tensor:
        x, t, conditions = inputs

        h = self.x_proj(x)
        if conditions is not None:
            hc = self.c_proj(conditions)
            if self.merge == "concat":
                h = concatenate_valid([h, hc], axis=-1)
            else:
                h = h + hc
            h = self.merge_proj(self.act(h))
        h = self.act(h)

        t_emb = self.time_emb(t)
        for block in self.blocks:
            h = block((h, t_emb), training=training)

        return h
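

# Minimal usage sketch (an illustration, not part of the original module).
# The widths, batch size, and feature dimensions below are arbitrary
# assumptions; the three-tuple input follows the `call` signature above.
if __name__ == "__main__":
    mlp = TimeMLP(widths=(128, 128), time_embedding_dim=32, merge="concat")

    x = keras.random.normal((64, 10))          # state, shape (batch, dims)
    t = keras.random.uniform((64, 1))          # scalar time per sample
    conditions = keras.random.normal((64, 5))  # conditioning variables

    h = mlp((x, t, conditions))
    print(h.shape)  # (64, 128), i.e. the last entry of `widths`

    # The `tuple[Tensor, Tensor, None]` overload suggests an unconditional
    # pathway as well; a fresh instance keeps the built shapes consistent.
    mlp_uncond = TimeMLP(widths=(128, 128))
    h2 = mlp_uncond((x, t, None))
    print(h2.shape)  # (64, 128)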