from collections.abc import Sequence
import keras
from bayesflow.types import Tensor
from bayesflow.utils.serialization import serializable
from .equivariant_layer import EquivariantLayer
from ...summary import SummaryNetwork
from ..transformers.attention import PoolingByMultiHeadAttention
@serializable("bayesflow.networks")
class DeepSet(SummaryNetwork):
"""(SN) Implements a deep set encoder introduced in [1, 2] for learning permutation-invariant
representations of set-based data, as generated by exchangeable models. It applies many new
features not present in the original papers, such as RMS norm, pooling by attention, and residual
connections for expressiveness and stability.
A stack of equivariant layers (each injecting a pooled invariant summary back into
every set element via a Pre-LN residual block) is followed by a
PoolingByMultiHeadAttention (PMA) layer and a linear output projection.
[1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., & Smola, A. J.
(2017). Deep sets. Advances in neural information processing systems, 30.
[2] Bloem-Reddy, B., & Teh, Y. W. (2020). Probabilistic Symmetries and Invariant Neural Networks.
Journal of Machine Learning Research, 21(90), 1-61. https://www.jmlr.org/papers/volume21/19-322/19-322.pdf
Parameters
----------
summary_dim : int, optional
Dimensionality of the final learned summary statistics. Default is 16.
embed_dim : int, optional
Working dimensionality shared across all equivariant layers.
Default is 64.
depth : int, optional
Number of stacked equivariant modules. Default is 2.
mlp_widths : Sequence[int], optional
Hidden layer widths for the MLPs inside each equivariant layer.
The output of each MLP is always ``embed_dim``. Default is ``(64,)``.
inner_pooling : str, optional
Pooling used inside the equivariant modules. Default is ``"mean"``.
num_heads : int, optional
Number of attention heads in the attention pooling block. Default is 4.
num_seeds : int, optional
Number of seed vectors in the attention pooling block. Default is 4.
seed_dim : int or None, optional
Dimensionality of attention pooling seed vectors. If None, defaults to ``embed_dim``.
Default is None.
expansion_factor : float, optional
FFN width multiplier in the PMA block. Default is 4.0.
glu_variant : str, optional
GLU activation variant for the PMA FFN. Default is ``"swiglu"``.
use_bias : bool, optional
Whether to include bias terms in the PMA dense layers. Default is False.
layer_norm : bool, optional
Whether to apply Pre-LN RMSNorm in equivariant and invariant layers. Default is True.
activation : str, optional
Activation function used throughout. Default is ``"silu"``.
kernel_initializer : str, optional
Weight initializer for Dense projections. Default is ``"he_normal"``.
dropout : float or None, optional
Dropout rate. Default is 0.05.
**kwargs
Additional keyword arguments passed to the base class.
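Examples
--------
A minimal usage sketch; the shapes below are illustrative assumptions (a batch of 8 sets,
each containing 100 exchangeable observations of dimension 3), not requirements of the API:
>>> import keras
>>> from bayesflow.networks import DeepSet
>>> summary_net = DeepSet(summary_dim=16, depth=2)
>>> x = keras.random.normal((8, 100, 3))  # (batch_size, set_size, obs_dim)
>>> summaries = summary_net(x)  # last axis has size summary_dim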
"""
def __init__(
self,
summary_dim: int = 16,
embed_dim: int = 64,
depth: int = 2,
mlp_widths: Sequence[int] = (64,),
inner_pooling: str = "mean",
num_heads: int = 4,
num_seeds: int = 4,
seed_dim: int | None = None,
expansion_factor: float = 4.0,
glu_variant: str = "swiglu",
use_bias: bool = False,
layer_norm: bool = True,
activation: str = "silu",
kernel_initializer: str = "he_normal",
dropout: float | None = 0.05,
**kwargs,
):
super().__init__(**kwargs)
self.summary_dim = summary_dim
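# Stack of permutation-equivariant blocks that refine each set element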
self.equivariant_modules = [
EquivariantLayer(
embed_dim=embed_dim,
mlp_widths=mlp_widths,
pooling=inner_pooling,
activation=activation,
kernel_initializer=kernel_initializer,
dropout=dropout,
layer_norm=layer_norm,
)
for _ in range(depth)
]
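# Attention-based pooling (PMA) collapses the variable-size set into a fixed-size representation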
self.pooling_by_attention = PoolingByMultiHeadAttention(
num_heads=num_heads,
embed_dim=mlp_widths[-1],
num_seeds=num_seeds,
seed_dim=seed_dim,
dropout=dropout,
kernel_initializer=kernel_initializer,
expansion_factor=expansion_factor,
glu_variant=glu_variant,
use_bias=use_bias,
layer_norm=layer_norm,
)
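# Linear projection to the requested summary dimensionality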
self.output_projector = keras.layers.Dense(units=summary_dim)
def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
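"""Encode a batch of sets into fixed-size summary vectors.
Parameters
----------
x : Tensor
Set-based input, typically of shape (batch_size, set_size, input_dim).
training : bool, optional
Whether to run in training mode (affects dropout). Default is False.
**kwargs
Additional keyword arguments (currently unused).
Returns
-------
Tensor
Learned summary statistics whose last dimension equals ``summary_dim``.
"""
# Refine each set element with the stack of equivariant blocks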
for em in self.equivariant_modules:
x = em(x, training=training)
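# Pool over the set dimension, then project to the final summary dimension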
x = self.pooling_by_attention(x, training=training)
return self.output_projector(x)