bayesflow.attention module#

class bayesflow.attention.MultiHeadAttentionBlock(*args, **kwargs)[source]#

Bases: Model

Implements the MAB block from [1] which represents learnable cross-attention.

[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019). Set transformer: A framework for attention-based permutation-invariant neural networks. In International Conference on Machine Learning (pp. 3744-3753). PMLR.

__init__(input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs)[source]#

Creates a multihead attention block which will typically be used as part of a set transformer architecture according to [1]. Corresponds to standard cross-attention.

Parameters:
input_dim : int

The dimensionality of the input data (last axis).

attention_settings : dict

A dictionary which will be unpacked as the arguments for the MultiHeadAttention layer. See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.

num_dense_fc : int

The number of hidden layers for the internal feedforward network.

dense_settings : dict

A dictionary which will be unpacked as the arguments for the Dense layer.

use_layer_norm : boolean

Whether to apply layer normalization before and after attention + feedforward.

**kwargs : dict, optional, default: {}

Optional keyword arguments passed to the __init__() method of tf.keras.Model.

call(x, y, **kwargs)[source]#

Performs the forward pass through the attention layer.

Parameters:
x : tf.Tensor

Input of shape (batch_size, set_size_x, input_dim)

y : tf.Tensor

Input of shape (batch_size, set_size_y, input_dim)

Returns:
out : tf.Tensor

Output of shape (batch_size, set_size_x, input_dim)
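
A minimal usage sketch is shown below. The concrete values for num_heads, key_dim, units, and activation are illustrative choices, not defaults; the settings dictionaries are simply unpacked into the underlying MultiHeadAttention and Dense layers.

import tensorflow as tf
from bayesflow.attention import MultiHeadAttentionBlock

# Illustrative settings; the keys follow tf.keras.layers.MultiHeadAttention and Dense.
block = MultiHeadAttentionBlock(
    input_dim=32,
    attention_settings=dict(num_heads=4, key_dim=32),
    num_dense_fc=2,
    dense_settings=dict(units=64, activation="relu"),
    use_layer_norm=True,
)

x = tf.random.normal((16, 10, 32))  # (batch_size, set_size_x, input_dim)
y = tf.random.normal((16, 50, 32))  # (batch_size, set_size_y, input_dim)
out = block(x, y)                   # (16, 10, 32), per the Returns section above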

class bayesflow.attention.SelfAttentionBlock(*args, **kwargs)[source]#

Bases: Model

Implements the SAB block from [1] which represents learnable self-attention.

[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019). Set transformer: A framework for attention-based permutation-invariant neural networks. In International Conference on Machine Learning (pp. 3744-3753). PMLR.

__init__(input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs)[source]#

Creates a self-attention block (SAB) which will typically be used as part of a set transformer architecture according to [1].

Parameters:
input_dim : int

The dimensionality of the input data (last axis).

attention_settings : dict

A dictionary which will be unpacked as the arguments for the MultiHeadAttention layer. See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.

num_dense_fc : int

The number of hidden layers for the internal feedforward network.

dense_settings : dict

A dictionary which will be unpacked as the arguments for the Dense layer.

use_layer_norm : boolean

Whether to apply layer normalization before and after attention + feedforward.

**kwargs : dict, optional, default: {}

Optional keyword arguments passed to the __init__() method of tf.keras.Model.

call(x, **kwargs)[source]#

Performs the forward pass through the self-attention layer.

Parameters:
x : tf.Tensor

Input of shape (batch_size, set_size, input_dim)

Returns:
out : tf.Tensor

Output of shape (batch_size, set_size, input_dim)
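
A minimal usage sketch, with illustrative settings as in the MultiHeadAttentionBlock example above.

import tensorflow as tf
from bayesflow.attention import SelfAttentionBlock

block = SelfAttentionBlock(
    input_dim=32,
    attention_settings=dict(num_heads=4, key_dim=32),
    num_dense_fc=2,
    dense_settings=dict(units=64, activation="relu"),
    use_layer_norm=True,
)

x = tf.random.normal((16, 10, 32))  # (batch_size, set_size, input_dim)
out = block(x)                      # (16, 10, 32)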

class bayesflow.attention.InducedSelfAttentionBlock(*args, **kwargs)[source]#

Bases: Model

Implements the ISAB block from [1] which represents learnable self-attention specifically designed to deal with large sets via a learnable set of “inducing points”.

[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019). Set transformer: A framework for attention-based permutation-invariant neural networks. In International Conference on Machine Learning (pp. 3744-3753). PMLR.

__init__(input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_inducing_points, **kwargs)[source]#

Creates a self-attention block with inducing points (ISAB) which will typically be used as part of a set transformer architecture according to [1].

Parameters:
input_dim : int

The dimensionality of the input data (last axis).

attention_settings : dict

A dictionary which will be unpacked as the arguments for the MultiHeadAttention layer. See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.

num_dense_fc : int

The number of hidden layers for the internal feedforward network.

dense_settings : dict

A dictionary which will be unpacked as the arguments for the Dense layer.

use_layer_norm : boolean

Whether to apply layer normalization before and after attention + feedforward.

num_inducing_points : int

The number of inducing points. Should be lower than the smallest set size.

**kwargs : dict, optional, default: {}

Optional keyword arguments passed to the __init__() method of tf.keras.Model.

call(x, **kwargs)[source]#

Performs the forward pass through the self-attention layer.

Parameters:
x : tf.Tensor

Input of shape (batch_size, set_size, input_dim)

Returns:
out : tf.Tensor

Output of shape (batch_size, set_size, input_dim)
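
A minimal usage sketch with illustrative settings; note that num_inducing_points is chosen smaller than the set size, as recommended above.

import tensorflow as tf
from bayesflow.attention import InducedSelfAttentionBlock

block = InducedSelfAttentionBlock(
    input_dim=32,
    attention_settings=dict(num_heads=4, key_dim=32),
    num_dense_fc=2,
    dense_settings=dict(units=64, activation="relu"),
    use_layer_norm=True,
    num_inducing_points=16,  # lower than the smallest expected set size
)

x = tf.random.normal((16, 128, 32))  # (batch_size, set_size, input_dim)
out = block(x)                       # (16, 128, 32)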

class bayesflow.attention.PoolingWithAttention(*args, **kwargs)[source]#

Bases: Model

Implements the pooling with multihead attention (PMA) block from [1] which represents a permutation-invariant encoder for set-based inputs.

[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019). Set transformer: A framework for attention-based permutation-invariant neural networks. In International Conference on Machine Learning (pp. 3744-3753). PMLR.

__init__(summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_seeds=1, **kwargs)[source]#

Creates a pooling-with-multihead-attention (PMA) block which performs cross-attention between an input set and a set of learnable seed vectors (typically one for a single summary) with summary_dim output dimensions.

Could also be used as part of a DeepSet to provide learnable instead of fixed pooling.

Parameters:
summary_dim : int

The dimensionality of the learned permutation-invariant representation.

attention_settings : dict

A dictionary which will be unpacked as the arguments for the MultiHeadAttention layer. See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.

num_dense_fc : int

The number of hidden layers for the internal feedforward network.

dense_settings : dict

A dictionary which will be unpacked as the arguments for the Dense layer.

use_layer_norm : boolean

Whether to apply layer normalization before and after attention + feedforward.

num_seeds : int, optional, default: 1

The number of “seed vectors” to use. Each seed vector represents a permutation-invariant summary of the entire set. If you use num_seeds > 1, the resulting seeds will be flattened into a 2-dimensional output, which will have a dimensionality of num_seeds * summary_dim.

**kwargs : dict, optional, default: {}

Optional keyword arguments passed to the __init__() method of tf.keras.Model.

call(x, **kwargs)[source]#

Performs the forward pass through the PMA block.

Parameters:
x : tf.Tensor

Input of shape (batch_size, set_size, input_dim)

Returns:
out : tf.Tensor

Output of shape (batch_size, num_seeds * summary_dim)
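
A minimal usage sketch with illustrative settings: pooling a set of 50 elements into a single permutation-invariant summary vector of dimension summary_dim.

import tensorflow as tf
from bayesflow.attention import PoolingWithAttention

pooler = PoolingWithAttention(
    summary_dim=10,
    attention_settings=dict(num_heads=4, key_dim=32),
    num_dense_fc=2,
    dense_settings=dict(units=64, activation="relu"),
    use_layer_norm=True,
    num_seeds=1,
)

x = tf.random.normal((16, 50, 32))  # (batch_size, set_size, input_dim)
summary = pooler(x)                 # (16, num_seeds * summary_dim) = (16, 10)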