# Source code for bayesflow.networks.subnets.mlp.time_mlp

from typing import Literal, Callable, Sequence

import keras

from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs
from bayesflow.utils.serialization import serialize, serializable, deserialize

from ...helpers import FourierEmbedding
from ...helpers import ConditionalDenseBlock


@serializable("bayesflow.networks")
class TimeMLP(keras.Layer):
    """Time-conditioned multi-layer perceptron with FiLM modulation.

    Processes three inputs: a state variable ``x``, a scalar or vector-valued
    time ``t``, and an optional conditioning variable ``conditions``. The input
    and conditions are projected into a shared feature space, merged, and
    passed through residual blocks. A learned time embedding is injected via
    FiLM at every hidden layer.

    Parameters
    ----------
    widths : Sequence[int], optional
        Number of hidden units per layer. Default is ``(256, 256)``.
    time_embedding_dim : int, optional
        Dimensionality of the learned time embedding. Default is ``32``.
        Set to ``1`` to use time directly without embedding.
    time_emb : keras.Layer or None, optional
        Custom time embedding layer. If ``None``, uses random Fourier features.
    fourier_scale : float, optional
        Frequency scaling for the default Fourier embedding. Default is ``30.0``.
        Ignored when *time_emb* is provided.
    activation : str or callable, optional
        Activation function for hidden layers. Default is ``"mish"``.
    kernel_initializer : str or keras.Initializer, optional
        Weight initialization strategy. Default is ``"he_normal"``.
    residual : bool, optional
        Whether to use residual connections. Default is ``True``.
    dropout : float or None, optional
        Dropout rate for regularization. Default is ``0.05``.
    norm : ``"batch"``, ``"layer"``, ``"rms"``, keras.Layer, or None, optional
        Normalization applied after each hidden layer. Default is ``"layer"``.
    merge : ``"add"`` or ``"concat"``, optional
        How to merge input and conditions (``"add"`` or ``"concat"``).
        Default is ``"concat"``.
    film_use_gamma : bool, optional
        Whether film uses a learnable gamma. Default is ``False``.
    **kwargs
        Additional keyword arguments passed to ``keras.Layer``.
    """

    def __init__(
        self,
        widths: Sequence[int] = (256, 256),
        *,
        time_embedding_dim: int = 32,
        time_emb: keras.Layer | None = None,
        fourier_scale: float = 30.0,
        activation: str | Callable[[], keras.Layer] = "mish",
        kernel_initializer: str | keras.Initializer = "he_normal",
        residual: bool = True,
        dropout: Literal[0, None] | float = 0.05,
        norm: Literal["batch", "layer", "rms"] | keras.Layer = "layer",
        merge: Literal["add", "concat"] = "concat",
        film_use_gamma: bool = False,
        **kwargs,
    ):
        super().__init__(**layer_kwargs(kwargs))

        if len(widths) == 0:
            raise ValueError("TimeMLP requires at least one hidden width.")
        if merge not in ("add", "concat"):
            raise ValueError(f"Unknown merge mode: {merge!r} (expected 'add' or 'concat').")

        self.widths = widths
        self.time_embedding_dim = time_embedding_dim
        self.fourier_scale = fourier_scale
        self.kernel_initializer = kernel_initializer
        self.residual = residual
        self.dropout = dropout
        self.norm = norm
        self.merge = merge
        self.film_use_gamma = film_use_gamma

        # Time embedding: identity when time_embedding_dim == 1 (use t as-is),
        # random Fourier features otherwise, unless a custom layer is supplied.
        if time_emb is None:
            if self.time_embedding_dim == 1:
                self.time_emb = keras.layers.Identity()
            else:
                self.time_emb = FourierEmbedding(
                    embed_dim=self.time_embedding_dim,
                    scale=self.fourier_scale,
                    include_identity=True,
                )
        else:
            self.time_emb = time_emb

        # Projection of x into the shared feature space. The condition and
        # merge projections are created lazily in build(), once it is known
        # whether conditions are present.
        self.x_proj = keras.layers.Dense(self.widths[0], kernel_initializer=self.kernel_initializer, name="x_proj")
        self.c_proj = None
        self.merge_proj = None

        # Normalize the activation argument to a keras.Layer and store only
        # the normalized form (a redundant earlier assignment of the raw
        # argument was removed; it was always overwritten here).
        activation = keras.activations.get(activation)
        if not isinstance(activation, keras.Layer):
            activation = keras.layers.Activation(activation)
        self.activation = activation

        # Time-conditional residual blocks, each modulated via FiLM.
        self.blocks = [
            ConditionalDenseBlock(
                width=width,
                activation=activation,
                kernel_initializer=kernel_initializer,
                residual=residual,
                dropout=dropout,
                norm=norm,
                film_use_gamma=film_use_gamma,
            )
            for width in self.widths
        ]
    def call(
        self,
        inputs: tuple[Tensor, Tensor, Tensor | None],
        training: bool | None = None,
    ) -> Tensor:
        """Apply the time-conditioned MLP.

        Parameters
        ----------
        inputs : tuple of (x, t, conditions)
            State tensor ``x``, time tensor ``t``, and optional conditions
            (``None`` when unconditional).
        training : bool or None, optional
            Forwarded to the conditional blocks (e.g. for dropout behavior).

        Returns
        -------
        Tensor
            The hidden representation after all conditional blocks.
        """
        x, t, conditions = inputs

        # Project the state into the shared feature space.
        h = self.x_proj(x)

        # Merge projected conditions when present; c_proj only exists if
        # build() saw a non-None conditions shape.
        if conditions is not None and self.c_proj is not None:
            hc = self.c_proj(conditions)
            if self.merge == "concat":
                h = keras.ops.concatenate([h, hc], axis=-1)
            else:
                h = h + hc
            # Activation before the merge projection; note the conditional
            # path therefore applies self.activation twice in total.
            h = self.merge_proj(self.activation(h))
        h = self.activation(h)

        # Embed time once and inject it into every block via FiLM.
        t_emb = self.time_emb(t)
        for block in self.blocks:
            h = block((h, t_emb), training=training)

        return h
    def build(self, input_shape):
        """Build all sublayers from the (x, t, conditions) shape tuple.

        Creates ``c_proj`` and ``merge_proj`` only when a conditions shape is
        provided, then builds each conditional block while threading the
        feature shape through.
        """
        if self.built:
            return

        x_shape, t_shape, conditions_shape = input_shape

        # Time embedding
        self.time_emb.build(t_shape)
        t_emb_shape = self.time_emb.compute_output_shape(t_shape)

        # Input projection
        self.x_proj.build(x_shape)
        h_shape = self.x_proj.compute_output_shape(x_shape)

        # Condition projection and merge — only instantiated when conditions
        # are present, so an unconditional model carries no unused weights.
        if conditions_shape is not None:
            self.c_proj = keras.layers.Dense(self.widths[0], kernel_initializer=self.kernel_initializer, name="c_proj")
            self.c_proj.build(conditions_shape)
            if self.merge == "concat":
                # Concatenation widens the last axis by the projected width.
                merge_shape = list(h_shape)
                merge_shape[-1] = merge_shape[-1] + self.widths[0]
                merge_shape = tuple(merge_shape)
            else:
                merge_shape = h_shape
            self.merge_proj = keras.layers.Dense(
                self.widths[0], kernel_initializer=self.kernel_initializer, name="merge_proj"
            )
            self.merge_proj.build(merge_shape)

        # Time-conditional blocks: each consumes the previous feature shape
        # together with the time-embedding shape.
        for block in self.blocks:
            block.build((h_shape, t_emb_shape))
            h_shape = block.compute_output_shape((h_shape, t_emb_shape))
[docs] def compute_output_shape(self, input_shape): x_shape, t_shape, conditions_shape = input_shape h_shape = self.x_proj.compute_output_shape(x_shape) t_emb_shape = self.time_emb.compute_output_shape(t_shape) for block in self.blocks: h_shape = block.compute_output_shape((h_shape, t_emb_shape)) return h_shape
[docs] @classmethod def from_config(cls, config, custom_objects=None): return cls(**deserialize(config, custom_objects=custom_objects))
[docs] def get_config(self): base_config = super().get_config() base_config = layer_kwargs(base_config) config = { "widths": self.widths, "time_embedding_dim": self.time_embedding_dim, "time_emb": self.time_emb, "fourier_scale": self.fourier_scale, "activation": self.activation, "kernel_initializer": self.kernel_initializer, "residual": self.residual, "dropout": self.dropout, "norm": self.norm, "merge": self.merge, "film_use_gamma": self.film_use_gamma, } return base_config | serialize(config)