Source code for bayesflow.networks.summary.transformers.set_transformer

import keras

from bayesflow.types import Tensor
from bayesflow.utils import check_lengths_same
from bayesflow.utils.serialization import serializable

from .transformer import Transformer
from .attention import SetAttention, InducedSetAttention, PoolingByMultiHeadAttention


@serializable("bayesflow.networks")
class SetTransformer(Transformer):
    """Implements the set transformer architecture from [1], which ultimately represents
    a learnable permutation-invariant function. Designed to naturally model interactions
    in the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International Conference on Machine Learning (pp. 3744-3753). PMLR.

    Note: Currently works only on 3D inputs but can easily be expanded by using
    ``keras.layers.TimeDistributed``.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        embed_dims: tuple = (64, 64),
        num_heads: tuple = (4, 4),
        num_seeds: int = 4,
        dropout: float = 0.05,
        expansion_factor: float = 4.0,
        glu_variant: str = "swiglu",
        kernel_initializer: str = "glorot_uniform",
        use_bias: bool = False,
        layer_norm: bool = True,
        num_inducing_points: int = None,
        seed_dim: int = None,
        **kwargs,
    ):
        """
        Creates a many-to-one permutation-invariant encoder, typically used as a summary network
        for embedding set-based (i.e., exchangeable or IID) data.

        Parameters
        ----------
        summary_dim : int, optional
            Dimensionality of the final summary output, by default 16.
        embed_dims : tuple of int, optional
            Embedding dimensionality for each attention block, by default (64, 64).
        num_heads : tuple of int, optional
            Number of attention heads for each block, by default (4, 4).
        num_seeds : int, optional
            Number of seed vectors used for PMA pooling. Increase if performance appears subpar.
            By default 4.
        dropout : float, optional
            Dropout rate applied inside the attention sublayers, by default 0.05.
        expansion_factor : float, optional
            FFN intermediate width multiplier (before the 2/3 GLU correction), by default 4.0.
        glu_variant : str, optional
            GLU activation variant for the FFN. One of ``"swiglu"``, ``"geglu"``, ``"reglu"``,
            or ``"liglu"``, by default ``"swiglu"``.
        kernel_initializer : str, optional
            Initializer for kernel weights, by default ``"glorot_uniform"``.
        use_bias : bool, optional
            Whether to include bias terms in dense layers, by default False.
        layer_norm : bool, optional
            Whether to apply Pre-LN RMSNorm before each sublayer, by default True.
        num_inducing_points : int or None, optional
            If set, uses InducedSetAttention (ISAB) blocks with this many inducing points
            instead of standard SetAttention (SAB) blocks.
        seed_dim : int or None, optional
            Dimensionality of the PMA seed vectors. If None, defaults to ``embed_dims[-1]``.
        **kwargs
            Additional keyword arguments passed to the base layer.
        """
        super().__init__(**kwargs)

        check_lengths_same(embed_dims, num_heads)

        shared_kwargs = dict(
            dropout=dropout,
            expansion_factor=expansion_factor,
            glu_variant=glu_variant,
            kernel_initializer=kernel_initializer,
            use_bias=use_bias,
            layer_norm=layer_norm,
        )

        self.attention_blocks = []
        for i in range(len(embed_dims)):
            block_kwargs = shared_kwargs | dict(num_heads=num_heads[i], embed_dim=embed_dims[i])
            if num_inducing_points is None:
                block = SetAttention(**block_kwargs)
            else:
                block = InducedSetAttention(num_inducing_points=num_inducing_points, **block_kwargs)
            self.attention_blocks.append(block)

        self.pooling_by_attention = PoolingByMultiHeadAttention(
            num_heads=num_heads[-1],
            embed_dim=embed_dims[-1],
            num_seeds=num_seeds,
            seed_dim=seed_dim,
            **shared_kwargs,
        )

        self.output_projector = keras.layers.Dense(units=summary_dim)
        self.summary_dim = summary_dim
    def call(self, x: Tensor, training: bool = False, attention_mask: Tensor = None) -> Tensor:
        """Compresses the input set into a summary vector of size ``summary_dim``.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(batch_size, set_size, input_dim)``.
        training : bool, optional
            Passed to dropout and norm layers, by default False.
        attention_mask : Tensor, optional
            Boolean mask of shape ``(batch_size, set_size, set_size)`` where 1 = attend, 0 = mask.

        Returns
        -------
        Tensor
            Output of shape ``(batch_size, summary_dim)``.
        """
        for layer in self.attention_blocks:
            x = layer(x, training=training, attention_mask=attention_mask)

        x = self.pooling_by_attention(x, training=training)
        x = self.output_projector(x)
        return x
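
The snippet below is a minimal usage sketch, not part of the module source. It assumes the class is re-exported as ``bayesflow.networks.SetTransformer`` (adjust the import to your installation) and uses ``keras.random.normal`` from Keras 3 to create a dummy batch of exchangeable sets; the printed output shape follows the ``call`` docstring above.

if __name__ == "__main__":
    import keras
    from bayesflow.networks import SetTransformer  # assumed export path

    # A batch of 8 exchangeable sets, each with 100 observations of dimension 3.
    x = keras.random.normal((8, 100, 3))

    # Default configuration uses two SAB blocks; pass num_inducing_points to
    # switch to ISAB blocks when set sizes become large.
    summary_net = SetTransformer(summary_dim=16)

    # Permutation-invariant compression of each set into a 16-dimensional summary.
    summary = summary_net(x)
    print(summary.shape)  # (8, 16) per the call docstring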