from collections.abc import Sequence
import keras
from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape
from .equivariant_module import EquivariantModule
from .invariant_module import InvariantModule
from ..summary_network import SummaryNetwork
@serializable(package="bayesflow.networks")
class DeepSet(SummaryNetwork):
"""Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of
set-based data, as generated by exchangeable models.

[1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., & Smola, A. J. (2017).
Deep sets. Advances in neural information processing systems, 30.
"""
def __init__(
self,
summary_dim: int = 16,
depth: int = 2,
inner_pooling: str = "mean",
output_pooling: str = "mean",
mlp_widths_equivariant: Sequence[int] = (64, 64),
mlp_widths_invariant_inner: Sequence[int] = (64, 64),
mlp_widths_invariant_outer: Sequence[int] = (64, 64),
mlp_widths_invariant_last: Sequence[int] = (64, 64),
activation: str = "gelu",
kernel_initializer: str = "he_normal",
dropout: int | float | None = 0.05,
spectral_normalization: bool = False,
**kwargs,
):
"""
Initializes a fully customizable deep learning model for learning permutation-invariant representations of
sets (i.e., exchangeable or IID data). Do not use this model for non-IID data (e.g., time series).
Important: Prefer a SetTransformer to a DeepSet, especially if the simulation budget is high.
The model consists of multiple stacked equivariant transformation modules followed by an invariant pooling
module to produce a compact set representation.
The equivariant layers perform many-to-many transformations, preserving structural information, while
the final invariant module aggregates the set into a lower-dimensional summary.
The model supports various activation functions, kernel initializations, and optional spectral normalization
for stability. Pooling mechanisms can be specified for both intermediate and final aggregation steps.

Parameters
----------
summary_dim : int, optional
Dimensionality of the final learned summary statistics. Default is 16.
depth : int, optional
Number of stacked equivariant modules. Default is 2.
inner_pooling : str, optional
Type of pooling operation applied within equivariant modules, such as "mean".
Default is "mean".
output_pooling : str, optional
Type of pooling operation applied in the final invariant module, such as "mean".
Default is "mean".
mlp_widths_equivariant : Sequence[int], optional
Widths of the MLP layers inside the equivariant modules. Default is (64, 64).
mlp_widths_invariant_inner : Sequence[int], optional
Widths of the inner MLP layers within the invariant module. Default is (64, 64).
mlp_widths_invariant_outer : Sequence[int], optional
Widths of the outer MLP layers within the invariant module. Default is (64, 64).
mlp_widths_invariant_last : Sequence[int], optional
Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
activation : str, optional
Activation function used throughout the network, such as "gelu". Default is "gelu".
kernel_initializer : str, optional
Initialization strategy for kernel weights, such as "he_normal". Default is "he_normal".
dropout : int, float, or None, optional
Dropout rate applied within MLP layers. Default is 0.05.
spectral_normalization : bool, optional
Whether to apply spectral normalization to stabilize training. Default is False.
**kwargs
Additional keyword arguments passed to the equivariant and invariant modules.
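
Examples
--------
A minimal usage sketch; the batch size, set size, and feature dimension
below are illustrative assumptions rather than requirements:

>>> import keras
>>> summary_net = DeepSet(summary_dim=16, depth=2)
>>> x = keras.ops.ones((32, 100, 4))  # (batch_size, set_size, input_dim)
>>> summaries = summary_net(x)  # -> shape (32, 16)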
"""
super().__init__(**kwargs)
# Stack of equivariant modules for a many-to-many learnable transformation
self.equivariant_modules = []
for _ in range(depth):
equivariant_module = EquivariantModule(
mlp_widths_equivariant=mlp_widths_equivariant,
mlp_widths_invariant_inner=mlp_widths_invariant_inner,
mlp_widths_invariant_outer=mlp_widths_invariant_outer,
activation=activation,
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
dropout=dropout,
pooling=inner_pooling,
**filter_kwargs(kwargs, EquivariantModule),
)
self.equivariant_modules.append(equivariant_module)
# Invariant module for a many-to-one transformation
self.invariant_module = InvariantModule(
mlp_widths_inner=mlp_widths_invariant_last,
mlp_widths_outer=mlp_widths_invariant_last,
activation=activation,
kernel_initializer=kernel_initializer,
dropout=dropout,
pooling=output_pooling,
spectral_normalization=spectral_normalization,
**filter_kwargs(kwargs, InvariantModule),
)
# Output linear layer to project set representation down to "summary_dim" learned summary statistics
self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim
@sanitize_input_shape
def build(self, input_shape):
super().build(input_shape)
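# Perform a forward pass on a dummy zero tensor so that all nested
# equivariant and invariant modules are built with concrete shapes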
self.call(keras.ops.zeros(input_shape))
def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
"""
Performs the forward pass of a hierarchical deep invariant transformation.

This function applies a sequence of equivariant transformations to the input tensor,
preserving structural relationships while refining representations. After passing
through the equivariant modules, the data is processed by an invariant transformation,
which aggregates information into a lower-dimensional representation. The final output
is projected to the specified summary dimension using a linear layer.

Parameters
----------
x : Tensor
Input tensor representing a set or collection of elements to be transformed.
training : bool, optional
Whether the model is in training mode, affecting layers like dropout. Default is False.
**kwargs
Additional keyword arguments passed to the transformation layers.

Returns
-------
output : Tensor
Transformed tensor with a reduced dimensionality, representing the learned summary
of the input set.
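
Notes
-----
As a shape sketch: for a rank-3 input of shape ``(batch_size, set_size, input_dim)``,
the equivariant modules preserve the set axis, the invariant module pools over it,
and the final linear projection yields an output of shape ``(batch_size, summary_dim)``.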
"""
for em in self.equivariant_modules:
x = em(x, training=training)
x = self.invariant_module(x, training=training)
return self.output_projector(x)