Source code for bayesflow.attention

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Sequential


class MultiHeadAttentionBlock(tf.keras.Model):
    """Implements the MAB block from [1] which represents learnable cross-attention.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(self, input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs):
        """Creates a multihead attention block which will typically be used as part of a
        set transformer architecture according to [1]. Corresponds to standard cross-attention.

        Parameters
        ----------
        input_dim : int
            The dimensionality of the input data (last axis).
        attention_settings : dict
            A dictionary which will be unpacked as the arguments for the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            The number of hidden layers for the internal feedforward network.
        dense_settings : dict
            A dictionary which will be unpacked as the arguments for the ``Dense`` layer.
        use_layer_norm : bool
            Whether to apply layer normalization before and after attention + feedforward.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of ``tf.keras.Model``.
        """

        super().__init__(**kwargs)

        self.att = MultiHeadAttention(**attention_settings)
        self.ln_pre = LayerNormalization() if use_layer_norm else None
        self.fc = Sequential([Dense(**dense_settings) for _ in range(num_dense_fc)])
        self.fc.add(Dense(input_dim))
        self.ln_post = LayerNormalization() if use_layer_norm else None

    def call(self, x, y, **kwargs):
        """Performs the forward pass through the attention layer.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size_x, input_dim)
        y : tf.Tensor
            Input of shape (batch_size, set_size_y, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, set_size_x, input_dim)
        """

        # Residual cross-attention, optionally followed by layer normalization
        h = x + self.att(x, y, y, **kwargs)
        if self.ln_pre is not None:
            h = self.ln_pre(h, **kwargs)
        # Residual feedforward network, optionally followed by layer normalization
        out = h + self.fc(h, **kwargs)
        if self.ln_post is not None:
            out = self.ln_post(out, **kwargs)
        return out
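
# Illustrative usage sketch, not part of the original module: the settings dictionaries
# below are assumed example values; any valid keyword arguments for ``MultiHeadAttention``
# (e.g., num_heads, key_dim) and ``Dense`` (e.g., units, activation) could be used instead.
def _demo_multihead_attention_block():
    """Hypothetical helper showing cross-attention between two sets of different sizes."""
    mab = MultiHeadAttentionBlock(
        input_dim=32,
        attention_settings=dict(num_heads=4, key_dim=32),
        num_dense_fc=2,
        dense_settings=dict(units=64, activation="relu"),
        use_layer_norm=True,
    )
    x = tf.random.normal((8, 10, 32))  # (batch_size, set_size_x, input_dim)
    y = tf.random.normal((8, 20, 32))  # (batch_size, set_size_y, input_dim)
    return mab(x, y)  # shape: (8, 10, 32)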

class SelfAttentionBlock(tf.keras.Model):
    """Implements the SAB block from [1] which represents learnable self-attention.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(self, input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs):
        """Creates a self-attention block (SAB) which will typically be used as part of a
        set transformer architecture according to [1].

        Parameters
        ----------
        input_dim : int
            The dimensionality of the input data (last axis).
        attention_settings : dict
            A dictionary which will be unpacked as the arguments for the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            The number of hidden layers for the internal feedforward network.
        dense_settings : dict
            A dictionary which will be unpacked as the arguments for the ``Dense`` layer.
        use_layer_norm : bool
            Whether to apply layer normalization before and after attention + feedforward.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of ``tf.keras.Model``.
        """

        super().__init__(**kwargs)

        self.mab = MultiHeadAttentionBlock(
            input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm
        )

    def call(self, x, **kwargs):
        """Performs the forward pass through the self-attention layer.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, set_size, input_dim)
        """

        return self.mab(x, x, **kwargs)
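
# Illustrative usage sketch, not part of the original module; the settings below are
# assumed example values. Self-attention simply applies the MAB with the set as both
# query and key/value input, so input and output shapes coincide.
def _demo_self_attention_block():
    """Hypothetical helper showing self-attention over a single set."""
    sab = SelfAttentionBlock(
        input_dim=32,
        attention_settings=dict(num_heads=4, key_dim=32),
        num_dense_fc=2,
        dense_settings=dict(units=64, activation="relu"),
        use_layer_norm=True,
    )
    x = tf.random.normal((8, 10, 32))  # (batch_size, set_size, input_dim)
    return sab(x)  # shape-preserving: (8, 10, 32)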

class InducedSelfAttentionBlock(tf.keras.Model):
    """Implements the ISAB block from [1] which represents learnable self-attention specifically
    designed to deal with large sets via a learnable set of "inducing points".

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(
        self, input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_inducing_points, **kwargs
    ):
        """Creates a self-attention block with inducing points (ISAB) which will typically be used
        as part of a set transformer architecture according to [1].

        Parameters
        ----------
        input_dim : int
            The dimensionality of the input data (last axis).
        attention_settings : dict
            A dictionary which will be unpacked as the arguments for the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            The number of hidden layers for the internal feedforward network.
        dense_settings : dict
            A dictionary which will be unpacked as the arguments for the ``Dense`` layer.
        use_layer_norm : bool
            Whether to apply layer normalization before and after attention + feedforward.
        num_inducing_points : int
            The number of inducing points. Should be lower than the smallest set size.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of ``tf.keras.Model``.
        """

        super().__init__(**kwargs)

        init = tf.keras.initializers.GlorotUniform()
        self.I = tf.Variable(init(shape=(num_inducing_points, input_dim)), name="I", trainable=True)
        self.mab0 = MultiHeadAttentionBlock(
            input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm
        )
        self.mab1 = MultiHeadAttentionBlock(
            input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm
        )

    def call(self, x, **kwargs):
        """Performs the forward pass through the self-attention layer.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, set_size, input_dim)
        """

        # Tile the learnable inducing points across the batch dimension
        batch_size = tf.shape(x)[0]
        I_expanded = self.I[None, ...]
        I_tiled = tf.tile(I_expanded, [batch_size, 1, 1])
        # Compress the set into the inducing points, then attend back to the full set
        h = self.mab0(I_tiled, x, **kwargs)
        return self.mab1(x, h, **kwargs)
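
# Illustrative usage sketch, not part of the original module; the settings below are assumed
# example values, and num_inducing_points is chosen to be much smaller than the set size, which
# is the regime where ISAB reduces the quadratic cost of full self-attention.
def _demo_induced_self_attention_block():
    """Hypothetical helper showing self-attention through a small set of inducing points."""
    isab = InducedSelfAttentionBlock(
        input_dim=32,
        attention_settings=dict(num_heads=4, key_dim=32),
        num_dense_fc=2,
        dense_settings=dict(units=64, activation="relu"),
        use_layer_norm=True,
        num_inducing_points=4,
    )
    x = tf.random.normal((8, 100, 32))  # (batch_size, set_size, input_dim)
    return isab(x)  # shape-preserving: (8, 100, 32)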

class PoolingWithAttention(tf.keras.Model):
    """Implements the pooling with multihead attention (PMA) block from [1] which represents
    a permutation-invariant encoder for set-based inputs.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(
        self, summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_seeds=1, **kwargs
    ):
        """Creates a multihead attention block (MAB) which will perform cross-attention between an input set
        and a set of seed vectors (typically one for a single summary) with ``summary_dim`` output dimensions.

        Could also be used as part of a ``DeepSet`` for representing learnable instead of fixed pooling.

        Parameters
        ----------
        summary_dim : int
            The dimensionality of the learned permutation-invariant representation.
        attention_settings : dict
            A dictionary which will be unpacked as the arguments for the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            The number of hidden layers for the internal feedforward network.
        dense_settings : dict
            A dictionary which will be unpacked as the arguments for the ``Dense`` layer.
        use_layer_norm : bool
            Whether to apply layer normalization before and after attention + feedforward.
        num_seeds : int, optional, default: 1
            The number of "seed vectors" to use. Each seed vector represents a permutation-invariant
            summary of the entire set. If you use ``num_seeds > 1``, the resulting seeds will be flattened
            into a 2-dimensional output, which will have a dimensionality of ``num_seeds * summary_dim``.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of ``tf.keras.Model``.
        """

        super().__init__(**kwargs)

        self.mab = MultiHeadAttentionBlock(
            summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs
        )
        init = tf.keras.initializers.GlorotUniform()
        self.seed_vec = tf.Variable(init(shape=(num_seeds, summary_dim)), name="seed_vec", trainable=True)
        self.fc = Sequential([Dense(**dense_settings) for _ in range(num_dense_fc)])
        self.fc.add(Dense(summary_dim))

    def call(self, x, **kwargs):
        """Performs the forward pass through the PMA block.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, num_seeds * summary_dim)
        """

        # Project the set elements to summary_dim before pooling
        out = self.fc(x)
        # Tile the learnable seed vectors across the batch dimension
        batch_size = tf.shape(x)[0]
        seed_expanded = self.seed_vec[None, ...]
        seed_tiled = tf.tile(seed_expanded, [batch_size, 1, 1])
        # Attend from the seeds to the set and flatten the pooled seeds
        out = self.mab(seed_tiled, out, **kwargs)
        return tf.reshape(out, (tf.shape(out)[0], -1))
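
# Illustrative usage sketch, not part of the original module; the settings below are assumed
# example values. Note that the last axis of the input set need not equal summary_dim, since
# the internal feedforward network projects the set elements to summary_dim before pooling.
def _demo_pooling_with_attention():
    """Hypothetical helper showing permutation-invariant pooling of a set into a fixed-size summary."""
    pma = PoolingWithAttention(
        summary_dim=16,
        attention_settings=dict(num_heads=4, key_dim=16),
        num_dense_fc=2,
        dense_settings=dict(units=64, activation="relu"),
        use_layer_norm=True,
        num_seeds=1,
    )
    x = tf.random.normal((8, 10, 32))  # (batch_size, set_size, input_dim)
    return pma(x)  # shape: (8, num_seeds * summary_dim) = (8, 16)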