Source code for bayesflow.networks.transformers.set_transformer

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import check_lengths_same
from bayesflow.utils.decorators import sanitize_input_shape

from ..summary_network import SummaryNetwork

from .sab import SetAttentionBlock
from .isab import InducedSetAttentionBlock
from .pma import PoolingByMultiHeadAttention


@serializable(package="bayesflow.networks")
class SetTransformer(SummaryNetwork):
    """Implements the set transformer architecture from [1], which ultimately represents a learnable
    permutation-invariant function. Designed to naturally model interactions in the input set, which may
    be hard to capture with the simpler ``DeepSet`` architecture.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019). Set transformer: A framework
        for attention-based permutation-invariant neural networks. In International Conference on Machine
        Learning (pp. 3744-3753). PMLR.

    Note: Currently works only on 3D inputs but can easily be expanded by using ``keras.layers.TimeDistributed``.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        embed_dims: tuple = (64, 64),
        num_heads: tuple = (4, 4),
        mlp_depths: tuple = (2, 2),
        mlp_widths: tuple = (128, 128),
        num_seeds: int = 1,
        dropout: float = 0.05,
        mlp_activation: str = "gelu",
        kernel_initializer: str = "he_normal",
        use_bias: bool = True,
        layer_norm: bool = True,
        num_inducing_points: int = None,
        seed_dim: int = None,
        **kwargs,
    ):
        """
        Creates a many-to-one permutation-invariant encoder, typically used as a summary network for embedding
        set-based (i.e., exchangeable or IID) data. Use a ``TimeSeriesTransformer`` or a ``FusionTransformer``
        for non-IID data.

        The number of multi-head attention blocks is inferred from the length of the `embed_dims` tuple.

        Parameters
        ----------
        summary_dim : int, optional (default - 16)
            Dimensionality of the final summary output.
        embed_dims : tuple of int, optional (default - (64, 64))
            Dimensions of the keys, values, and queries for each attention block.
        num_heads : tuple of int, optional (default - (4, 4))
            Number of attention heads for each embedding dimension.
        mlp_depths : tuple of int, optional (default - (2, 2))
            Depth of the multi-layer perceptron (MLP) blocks for each component.
        mlp_widths : tuple of int, optional (default - (128, 128))
            Width of each MLP layer in each block for each component.
        num_seeds : int, optional (default - 1)
            Number of seed vectors used for pooling by multi-head attention.
        dropout : float, optional (default - 0.05)
            Dropout rate applied to the attention and MLP layers. If set to None, no dropout is applied.
        mlp_activation : str, optional (default - 'gelu')
            Activation function used in the dense layers. Common choices include "relu", "elu", and "gelu".
        kernel_initializer : str, optional (default - 'he_normal')
            Initializer for the kernel weights matrix. Common choices include "glorot_uniform", "he_normal", etc.
        use_bias : bool, optional (default - True)
            Whether to include a bias term in the dense layers.
        layer_norm : bool, optional (default - True)
            Whether to apply layer normalization after the attention and MLP layers.
        num_inducing_points : int or None, optional (default - None)
            Number of inducing points for the induced set-attention blocks. If set to None, regular
            set-attention blocks are used instead.
        seed_dim : int or None, optional (default - None)
            Dimensionality of the seed embeddings. If None, it defaults to `summary_dim`.
        **kwargs : dict
            Additional keyword arguments passed to the base layer.
        """
        super().__init__(**kwargs)

        check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)

        num_attention_layers = len(embed_dims)

        # Construct a series of set-attention blocks
        self.attention_blocks = keras.Sequential()

        global_attention_settings = dict(
            dropout=dropout,
            mlp_activation=mlp_activation,
            kernel_initializer=kernel_initializer,
            use_bias=use_bias,
            layer_norm=layer_norm,
        )

        for i in range(num_attention_layers):
            layer_attention_settings = dict(
                num_heads=num_heads[i],
                embed_dim=embed_dims[i],
                mlp_depth=mlp_depths[i],
                mlp_width=mlp_widths[i],
            )

            if num_inducing_points is None:
                block = SetAttentionBlock(**(global_attention_settings | layer_attention_settings))
            else:
                isab_settings = dict(num_inducing_points=num_inducing_points)
                block = InducedSetAttentionBlock(
                    **(global_attention_settings | layer_attention_settings | isab_settings)
                )

            self.attention_blocks.add(block)

        # Pooling will be applied as a final step to the abstract representations obtained from set attention
        pooling_settings = dict(
            num_heads=num_heads[-1],
            embed_dim=embed_dims[-1],
            mlp_depth=mlp_depths[-1],
            mlp_width=mlp_widths[-1],
            seed_dim=seed_dim,
            num_seeds=num_seeds,
        )
        self.pooling_by_attention = PoolingByMultiHeadAttention(**(global_attention_settings | pooling_settings))

        self.output_projector = keras.layers.Dense(summary_dim)
        self.summary_dim = summary_dim

    def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
        """Compresses the input set into a summary vector of size `summary_dim`.

        Parameters
        ----------
        input_set : Tensor (e.g., np.ndarray, tf.Tensor, ...)
            Input of shape (batch_size, set_size, input_dim)
        training : boolean, optional (default - False)
            Passed to the optional internal dropout and spectral normalization
            layers to distinguish between train and test time behavior.
        **kwargs : dict, optional (default - {})
            Additional keyword arguments passed to the internal attention layer,
            such as ``attention_mask`` or ``return_attention_scores``

        Returns
        -------
        out : Tensor
            Output (summary vector) of shape (batch_size, summary_dim)
        """
        summary = self.attention_blocks(input_set, training=training, **kwargs)
        summary = self.pooling_by_attention(summary, training=training, **kwargs)
        summary = self.output_projector(summary)
        return summary

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
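
A minimal usage sketch (not part of the module source above), assuming ``SetTransformer`` is importable from ``bayesflow.networks``: it constructs the network with its default two attention blocks, passes a dummy 3D set of shape (batch_size, set_size, input_dim), and receives a permutation-invariant summary of shape (batch_size, summary_dim). The shapes and keyword values below are illustrative assumptions, not prescribed defaults.

import keras

from bayesflow.networks import SetTransformer  # assumed public import path

# Two set-attention blocks (len(embed_dims) == 2), pooled into a 16-dimensional summary
summary_net = SetTransformer(summary_dim=16, embed_dims=(64, 64), num_heads=(4, 4))

# Dummy exchangeable data: 8 sets, 100 elements each, 3 features per element
dummy_set = keras.ops.zeros((8, 100, 3))

# The layer builds on first call; output shape is (8, 16)
summaries = summary_net(dummy_set)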