Source code for bayesflow.networks.summary.deep_set.deep_set

from collections.abc import Sequence

import keras

from bayesflow.types import Tensor
from bayesflow.utils.serialization import serializable

from .equivariant_layer import EquivariantLayer

from ...summary import SummaryNetwork
from ..transformers.attention import PoolingByMultiHeadAttention


@serializable("bayesflow.networks")
class DeepSet(SummaryNetwork):
    """(SN) Implements a deep set encoder introduced in [1, 2] for learning
    permutation-invariant representations of set-based data, as generated by
    exchangeable models. It extends the original architecture with features
    not present in the papers, such as RMSNorm, pooling by attention, and
    residual connections, for improved expressiveness and stability.

    A stack of equivariant layers (each injecting a pooled invariant summary
    back into every set element via a Pre-LN residual block) is followed by a
    PoolingByMultiHeadAttention (PMA) layer and a linear output projection.

    [1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov,
        R. R., & Smola, A. J. (2017). Deep sets. Advances in Neural
        Information Processing Systems, 30.

    [2] Bloem-Reddy, B., & Teh, Y. W. (2020). Probabilistic Symmetries and
        Invariant Neural Networks. J. Mach. Learn. Res., 21, 90-1.
        https://www.jmlr.org/papers/volume21/19-322/19-322.pdf

    Parameters
    ----------
    summary_dim : int, optional
        Dimensionality of the final learned summary statistics. Default is 16.
    embed_dim : int, optional
        Working dimensionality shared across all equivariant layers.
        Default is 64.
    depth : int, optional
        Number of stacked equivariant modules. Default is 2.
    mlp_widths : Sequence[int], optional
        Hidden layer widths for the MLPs inside each equivariant layer.
        The output of each MLP is always ``embed_dim``. Default is ``(64,)``.
    inner_pooling : str, optional
        Pooling used inside the equivariant modules. Default is ``"mean"``.
    num_heads : int, optional
        Number of attention heads in the attention pooling block. Default is 4.
    num_seeds : int, optional
        Number of seed vectors in the attention pooling block. Default is 4.
    seed_dim : int or None, optional
        Dimensionality of attention pooling seed vectors. If None, defaults
        to ``embed_dim``. Default is None.
    expansion_factor : float, optional
        FFN width multiplier in the PMA block. Default is 4.0.
    glu_variant : str, optional
        GLU activation variant for the PMA FFN. Default is ``"swiglu"``.
    use_bias : bool, optional
        Whether to include bias terms in the PMA dense layers.
        Default is False.
    layer_norm : bool, optional
        Whether to apply Pre-LN RMSNorm in equivariant and invariant layers.
        Default is True.
    activation : str, optional
        Activation function used throughout. Default is ``"silu"``.
    kernel_initializer : str, optional
        Weight initializer for Dense projections. Default is ``"he_normal"``.
    dropout : float or None, optional
        Dropout rate. Default is 0.05.
    **kwargs
        Additional keyword arguments passed to the base class.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        embed_dim: int = 64,
        depth: int = 2,
        mlp_widths: Sequence[int] = (64,),
        inner_pooling: str = "mean",
        num_heads: int = 4,
        num_seeds: int = 4,
        seed_dim: int | None = None,
        expansion_factor: float = 4.0,
        glu_variant: str = "swiglu",
        use_bias: bool = False,
        layer_norm: bool = True,
        activation: str = "silu",
        kernel_initializer: str = "he_normal",
        dropout: float | None = 0.05,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.summary_dim = summary_dim

        # Equivariant backbone: each module refines the per-element embeddings
        # while remaining equivariant to permutations of the set.
        self.equivariant_modules = [
            EquivariantLayer(
                embed_dim=embed_dim,
                mlp_widths=mlp_widths,
                pooling=inner_pooling,
                activation=activation,
                kernel_initializer=kernel_initializer,
                dropout=dropout,
                layer_norm=layer_norm,
            )
            for _ in range(depth)
        ]

        # Invariant head: attention pooling collapses the set dimension.
        self.pooling_by_attention = PoolingByMultiHeadAttention(
            num_heads=num_heads,
            embed_dim=mlp_widths[-1],
            num_seeds=num_seeds,
            seed_dim=seed_dim,
            dropout=dropout,
            kernel_initializer=kernel_initializer,
            expansion_factor=expansion_factor,
            glu_variant=glu_variant,
            use_bias=use_bias,
            layer_norm=layer_norm,
        )

        self.output_projector = keras.layers.Dense(units=summary_dim)
    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        for em in self.equivariant_modules:
            x = em(x, training=training)
        x = self.pooling_by_attention(x, training=training)
        return self.output_projector(x)
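
A minimal usage sketch follows (not part of the module above). The import path and input shape are assumptions: the class is registered under ``bayesflow.networks``, and summary networks consume batches of exchangeable sets shaped (batch, set_size, num_features).

import keras
import numpy as np

from bayesflow.networks import DeepSet  # assumed public import path

summary_net = DeepSet(summary_dim=16, depth=2)

# A batch of 8 sets, each with 50 exchangeable elements of 3 features.
x = keras.random.normal((8, 50, 3))
summaries = summary_net(x)  # expected shape (8, 16): one summary per set

# Permutation invariance: reordering the elements within each set should
# leave the learned summaries unchanged (up to floating-point tolerance).
# Dropout is inactive here because call() defaults to training=False.
x_shuffled = keras.random.shuffle(x, axis=1)
assert np.allclose(
    keras.ops.convert_to_numpy(summaries),
    keras.ops.convert_to_numpy(summary_net(x_shuffled)),
    atol=1e-5,
)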