Source code for bayesflow.networks.deep_set.deep_set

from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

from .equivariant_module import EquivariantModule
from .invariant_module import InvariantModule
from ..summary_network import SummaryNetwork


@serializable(package="bayesflow.networks")
class DeepSet(SummaryNetwork):
    """Implements a deep set encoder introduced in [1] for learning permutation-invariant
    representations of set-based data, as generated by exchangeable models.

    [1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., &
        Smola, A. J. (2017). Deep sets. Advances in neural information processing systems, 30.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        depth: int = 2,
        inner_pooling: str = "mean",
        output_pooling: str = "mean",
        mlp_widths_equivariant: Sequence[int] = (64, 64),
        mlp_widths_invariant_inner: Sequence[int] = (64, 64),
        mlp_widths_invariant_outer: Sequence[int] = (64, 64),
        mlp_widths_invariant_last: Sequence[int] = (64, 64),
        activation: str = "gelu",
        kernel_initializer: str = "he_normal",
        dropout: int | float | None = 0.05,
        spectral_normalization: bool = False,
        **kwargs,
    ):
        """
        Initializes a fully customizable deep learning model for learning permutation-invariant
        representations of sets (i.e., exchangeable or IID data). Do not use this model for
        non-IID data (e.g., time series).

        Important: Prefer a SetTransformer to a DeepSet, especially if the simulation budget
        is high.

        The model consists of multiple stacked equivariant transformation modules followed by
        an invariant pooling module to produce a compact set representation. The equivariant
        layers perform many-to-many transformations, preserving structural information, while
        the final invariant module aggregates the set into a lower-dimensional summary.

        The model supports various activation functions, kernel initializations, and optional
        spectral normalization for stability. Pooling mechanisms can be specified for both
        intermediate and final aggregation steps.

        Parameters
        ----------
        summary_dim : int, optional
            Dimensionality of the final learned summary statistics. Default is 16.
        depth : int, optional
            Number of stacked equivariant modules. Default is 2.
        inner_pooling : str, optional
            Type of pooling operation applied within equivariant modules, such as "mean".
            Default is "mean".
        output_pooling : str, optional
            Type of pooling operation applied in the final invariant module, such as "mean".
            Default is "mean".
        mlp_widths_equivariant : Sequence[int], optional
            Widths of the MLP layers inside the equivariant modules. Default is (64, 64).
        mlp_widths_invariant_inner : Sequence[int], optional
            Widths of the inner MLP layers within the invariant module. Default is (64, 64).
        mlp_widths_invariant_outer : Sequence[int], optional
            Widths of the outer MLP layers within the invariant module. Default is (64, 64).
        mlp_widths_invariant_last : Sequence[int], optional
            Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
        activation : str, optional
            Activation function used throughout the network, such as "gelu". Default is "gelu".
        kernel_initializer : str, optional
            Initialization strategy for kernel weights, such as "he_normal".
            Default is "he_normal".
        dropout : int, float, or None, optional
            Dropout rate applied within MLP layers. Default is 0.05.
        spectral_normalization : bool, optional
            Whether to apply spectral normalization to stabilize training. Default is False.
        **kwargs
            Additional keyword arguments passed to the equivariant and invariant modules.
""" super().__init__(**kwargs) # Stack of equivariant modules for a many-to-many learnable transformation self.equivariant_modules = [] for _ in range(depth): equivariant_module = EquivariantModule( mlp_widths_equivariant=mlp_widths_equivariant, mlp_widths_invariant_inner=mlp_widths_invariant_inner, mlp_widths_invariant_outer=mlp_widths_invariant_outer, activation=activation, kernel_initializer=kernel_initializer, spectral_normalization=spectral_normalization, dropout=dropout, pooling=inner_pooling, **filter_kwargs(kwargs, EquivariantModule), ) self.equivariant_modules.append(equivariant_module) # Invariant module for a many-to-one transformation self.invariant_module = InvariantModule( mlp_widths_inner=mlp_widths_invariant_last, mlp_widths_outer=mlp_widths_invariant_last, activation=activation, kernel_initializer=kernel_initializer, dropout=dropout, pooling=output_pooling, spectral_normalization=spectral_normalization, **filter_kwargs(kwargs, InvariantModule), ) # Output linear layer to project set representation down to "summary_dim" learned summary statistics self.output_projector = keras.layers.Dense(summary_dim, activation="linear") self.summary_dim = summary_dim
[docs] @sanitize_input_shape def build(self, input_shape): super().build(input_shape) self.call(keras.ops.zeros(input_shape))
[docs] def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: """ Performs the forward pass of a hierarchical deep invariant transformation. This function applies a sequence of equivariant transformations to the input tensor, preserving structural relationships while refining representations. After passing through the equivariant modules, the data is processed by an invariant transformation, which aggregates information into a lower-dimensional representation. The final output is projected to the specified summary dimension using a linear layer. Parameters ---------- x : Tensor Input tensor representing a set or collection of elements to be transformed. training : bool, optional Whether the model is in training mode, affecting layers like dropout. Default is False. **kwargs Additional keyword arguments passed to the transformation layers. Returns ------- output : Tensor Transformed tensor with a reduced dimensionality, representing the learned summary of the input set. """ for em in self.equivariant_modules: x = em(x, training=training) x = self.invariant_module(x, training=training) return self.output_projector(x)
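
A minimal usage sketch (not part of the module above), assuming a configured Keras backend; the top-level import path `bayesflow.networks.DeepSet` and the (batch, set_size, num_features) input convention follow the library's summary-network interface:

import numpy as np
import keras
from bayesflow.networks import DeepSet

# Batch of 8 exchangeable sets, each containing 32 elements of dimension 4.
x = np.random.standard_normal((8, 32, 4)).astype("float32")

deep_set = DeepSet(summary_dim=16, depth=2)
summaries = deep_set(keras.ops.convert_to_tensor(x), training=False)
print(keras.ops.shape(summaries))  # (8, 16)

# Shuffling elements within each set should not change the summaries, since the
# equivariant stack preserves element order symmetry and the final pooling
# ("mean" by default) is permutation-invariant.
perm = np.random.permutation(32)
summaries_perm = deep_set(keras.ops.convert_to_tensor(x[:, perm, :]), training=False)
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(summaries),
    keras.ops.convert_to_numpy(summaries_perm),
    atol=1e-5,
)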