from collections.abc import Sequence
import keras
from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils.serialization import serializable
from .mamba_block import MambaBlock
@serializable
class Mamba(SummaryNetwork):
"""
Wraps a sequence of Mamba modules using the simple Mamba module from:
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
Copyright (c) 2023, Tri Dao, Albert Gu.
Example usage in a BayesFlow workflow as a summary network:
`summary_net = bayesflow.wrappers.Mamba(summary_dim=32)`
"""
def __init__(
self,
summary_dim: int = 16,
feature_dims: Sequence[int] = (64, 64),
state_dims: Sequence[int] = (64, 64),
conv_dims: Sequence[int] = (64, 64),
expand_dims: Sequence[int] = (2, 2),
bidirectional: bool = True,
dt_min: float = 0.001,
dt_max: float = 0.1,
dropout: float = 0.05,
device: str = "cuda",
**kwargs,
):
"""
A time-series summarization network using Mamba-based State Space Models (SSM). This model processes
sequential input data using a sequence of Mamba SSM layers (determined by the length of the tuples),
followed by optional pooling, dropout, and a dense layer for extracting summary statistics.
Parameters
----------
summary_dim : Sequence[int], optional
The output dimensionality of the summary statistics layer (default is 16).
feature_dims : Sequence[int], optional
The feature dimension for each mamba block, default is (64, 64),
state_dims : Sequence[int], optional
The dimensionality of the internal state in each Mamba block, default is (64, 64)
conv_dims : Sequence[int], optional
The dimensionality of the convolutional layer in each Mamba block, default is (32, 32)
expand_dims : Sequence[int], optional
The expansion factors for the hidden state in each Mamba block, default is (2, 2)
dt_min : float, optional
Minimum dynamic state evolution over time (default is 0.001).
dt_max : float, optional
Maximum dynamic state evolution over time (default is 0.1).
pooling : bool, optional
Whether to apply global average pooling (default is True).
dropout : int, float, or None, optional
Dropout rate applied before the summary layer (default is 0.5).
dropout: float, optional
Dropout probability; dropout is applied to the pooled summary vector.
device : str, optional
The computing device. Currently, only "cuda" is supported (default is "cuda").
**kwargs :
Additional keyword arguments passed to the `SummaryNetwork` parent class.
"""
super().__init__(**kwargs)
if device != "cuda":
raise NotImplementedError("MambaSSM only supports cuda as `device`.")
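        # Build one Mamba block per entry in the dimension tuples; note that
        # `zip` truncates to the shortest tuple, so the tuples should have equal length.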
self.mamba_blocks = []
for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
mamba = MambaBlock(feature_dim, state_dim, conv_dim, expand, bidirectional, dt_min, dt_max, device)
self.mamba_blocks.append(mamba)
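        # Pooling, dropout, and a dense head map the block outputs to a
        # fixed-size summary vector of length `summary_dim`.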
self.pooling_layer = keras.layers.GlobalAveragePooling1D()
self.dropout = keras.layers.Dropout(dropout)
self.summary_stats = keras.layers.Dense(summary_dim)
def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
"""
Apply a sequence of Mamba blocks, followed by pooling, dropout, and summary statistics.
Parameters
----------
time_series : Tensor
Input tensor representing the time series data, typically of shape
(batch_size, sequence_length, feature_dim).
training : bool, optional
Whether the model is in training mode (default is False). Affects the behavior of
the inner dropout and norm layers.
**kwargs : dict
Additional keyword arguments (not used in this method).
Returns
-------
Tensor
Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics.
"""
summary = time_series
for mamba_block in self.mamba_blocks:
summary = mamba_block(summary, training=training)
summary = self.pooling_layer(summary)
summary = self.dropout(summary, training=training)
summary = self.summary_stats(summary)
return summary
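
# Minimal usage sketch (not part of the class): builds the wrapper and runs a
# dummy batch through it. Assumes a CUDA-capable device and that the
# `mamba-ssm` package backing `MambaBlock` is installed; the shapes below are
# illustrative only.
if __name__ == "__main__":
    summary_net = Mamba(summary_dim=32)
    # Dummy time series of shape (batch_size, sequence_length, feature_dim).
    x = keras.random.normal((8, 128, 3))
    summary = summary_net(x)
    print(summary.shape)  # expected: (8, 32)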