from typing import Literal, Callable, Sequence
import keras
from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs
from bayesflow.utils.serialization import deserialize, serializable, serialize
from ...helpers import DenseBlock
@serializable("bayesflow.networks")
class MLP(keras.Layer):
    """Multi-layer perceptron built from :class:`DenseBlock` layers.

    Accepts a single tensor input. When used inside a coupling layer the
    caller (e.g. :class:`SingleCoupling`) concatenates any non-time
    conditions onto the input before passing it here.

    Parameters
    ----------
    widths : Sequence[int], optional
        Number of hidden units per layer. Default is ``(256, 256)``.
    activation : str or callable, optional
        Activation function for hidden layers. Default is ``"mish"``.
    kernel_initializer : str or keras.Initializer, optional
        Weight initialization strategy. Default is ``"he_normal"``.
    residual : bool, optional
        Whether to use residual (skip) connections. Default is ``True``.
    dropout : float or None, optional
        Dropout rate for regularization. Default is ``0.05``.
    norm : ``"batch"``, ``"layer"``, ``"rms"``, keras.Layer, or None, optional
        Normalization applied after each hidden layer. Default is ``None``.
    **kwargs
        Additional keyword arguments passed to ``keras.Layer``.
    """

    def __init__(
        self,
        widths: Sequence[int] = (256, 256),
        *,
        activation: str | Callable[[], keras.Layer] = "mish",
        kernel_initializer: str | keras.Initializer = "he_normal",
        residual: bool = True,
        dropout: Literal[0, None] | float = 0.05,
        norm: Literal["batch", "layer", "rms"] | keras.Layer | None = None,
        **kwargs,
    ):
        super().__init__(**layer_kwargs(kwargs))

        if len(widths) == 0:
            raise ValueError("MLP requires at least one hidden width.")

        # Keep the raw constructor arguments so get_config() can round-trip them.
        self.widths = widths
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.residual = residual
        self.dropout = dropout
        self.norm = norm

        # One DenseBlock per requested width; all blocks share the same settings.
        self.blocks = [
            DenseBlock(
                width=width,
                activation=activation,
                kernel_initializer=kernel_initializer,
                residual=residual,
                dropout=dropout,
                norm=norm,
            )
            for width in self.widths
        ]

    def call(self, x: Tensor, training: bool | None = None) -> Tensor:
        """Apply all hidden blocks to ``x`` in sequence.

        Parameters
        ----------
        x : Tensor
            Input tensor; the last axis is the feature axis.
        training : bool or None, optional
            Forwarded to each block (controls dropout/batch-norm behavior).

        Returns
        -------
        Tensor
            Output of the final hidden block.
        """
        h = x
        for block in self.blocks:
            h = block(h, training=training)
        return h

    def build(self, input_shape):
        """Build each block with the output shape of its predecessor."""
        h_shape = input_shape
        for block in self.blocks:
            block.build(h_shape)
            h_shape = block.compute_output_shape(h_shape)

    def compute_output_shape(self, input_shape):
        """Propagate ``input_shape`` through every block without creating weights."""
        h_shape = input_shape
        for block in self.blocks:
            h_shape = block.compute_output_shape(h_shape)
        return h_shape

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Reconstruct an :class:`MLP` from a config produced by :meth:`get_config`."""
        return cls(**deserialize(config, custom_objects=custom_objects))

    def get_config(self):
        """Return a serializable config that round-trips through :meth:`from_config`."""
        base_config = super().get_config()
        # Strip keras-internal keys so only plain Layer kwargs remain.
        base_config = layer_kwargs(base_config)

        config = {
            "widths": self.widths,
            "activation": self.activation,
            "kernel_initializer": self.kernel_initializer,
            "residual": self.residual,
            "dropout": self.dropout,
            "norm": self.norm,
        }

        return base_config | serialize(config)