Source code for bayesflow.networks.time_series_network.time_series_network

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils.decorators import sanitize_input_shape
from .skip_recurrent import SkipRecurrentNet
from ..summary_network import SummaryNetwork


@serializable(package="bayesflow.networks")
class TimeSeriesNetwork(SummaryNetwork):
    """
    Implements a LSTNet Architecture as described in [1]

    The network chains a stack of 1D convolutions (each optionally followed by
    group normalization), a ``SkipRecurrentNet`` and a dense projection to
    ``summary_dim`` units.

    [1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
    2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
    Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        filters: int | list | tuple = 32,
        kernel_sizes: int | list | tuple = 3,
        strides: int | list | tuple = 1,
        activation: str = "mish",
        kernel_initializer: str = "glorot_uniform",
        groups: int | None = None,
        recurrent_type: str = "gru",
        recurrent_dim: int = 128,
        bidirectional: bool = True,
        dropout: float = 0.05,
        skip_steps: int = 4,
        **kwargs,
    ):
        """
        Initializes a hybrid convolutional-recurrent network for summarizing multivariate time series.

        A convolutional backbone is applied first, its output is passed to a recurrent
        module (``SkipRecurrentNet``), and the result is projected to ``summary_dim``
        units by a dense layer.

        Parameters
        ----------
        summary_dim : int, optional
            Dimensionality of the final summary representation. Default is 16.
        filters : int, list, or tuple, optional
            Number of filters in each convolutional layer. If an integer is provided,
            all layers will have the same number of filters. Default is 32.
        kernel_sizes : int, list, or tuple, optional
            Size of the convolutional kernels. If an integer is provided, all layers
            will have the same kernel size. Default is 3.
        strides : int, list, or tuple, optional
            Stride length for convolutional layers. If an integer is provided, all
            layers will have the same stride. Default is 1.
        activation : str, optional
            Activation function applied in the convolutional layers. Default is "mish".
        kernel_initializer : str, optional
            Initialization strategy for convolutional kernels, such as "glorot_uniform".
            Default is "glorot_uniform".
        groups : int or None, optional
            Number of groups for the group normalization layer added after each
            convolutional layer. If None, no normalization layer is added.
            Default is None.
        recurrent_type : str, optional
            Type of recurrent layer used for sequence modeling, such as "gru" or "lstm".
            Default is "gru".
        recurrent_dim : int, optional
            Number of hidden units in the recurrent layer. Default is 128.
        bidirectional : bool, optional
            Whether to use a bidirectional recurrent network. Default is True.
        dropout : float, optional
            Dropout rate forwarded to the recurrent module. Default is 0.05.
        skip_steps : int, optional
            Number of steps to skip in the recurrent network; forwarded unchanged
            to ``SkipRecurrentNet``. Default is 4.
        **kwargs
            Additional keyword arguments passed to the parent class.

        Raises
        ------
        ValueError
            If any of ``filters``, ``kernel_sizes`` or ``strides`` is an empty
            sequence, or if those given as sequences differ in length.
        """
        super().__init__(**kwargs)

        # Convolutional backbone -> can be extended with inception-like structure.
        # Scalars are expanded so that every layer receives a value; without this,
        # zipping a sequence with a wrapped scalar would silently build one layer only.
        filters, kernel_sizes, strides = self._match_conv_settings(filters, kernel_sizes, strides)

        self.conv_blocks = []
        for f, k, s in zip(filters, kernel_sizes, strides, strict=True):
            self.conv_blocks.append(
                keras.layers.Conv1D(
                    filters=f,
                    kernel_size=k,
                    strides=s,
                    activation=activation,
                    kernel_initializer=kernel_initializer,
                    padding="same",
                )
            )
            if groups is not None:
                self.conv_blocks.append(keras.layers.GroupNormalization(groups=groups))

        # Recurrent and feedforward backbones.
        # filters[-1] is the filter count of the last convolution actually built,
        # because all three settings now share one length.
        self.recurrent = SkipRecurrentNet(
            hidden_dim=recurrent_dim,
            recurrent_type=recurrent_type,
            bidirectional=bidirectional,
            input_channels=filters[-1],
            skip_steps=skip_steps,
            dropout=dropout,
        )
        self.output_projector = keras.layers.Dense(summary_dim)
        self.summary_dim = summary_dim

    @staticmethod
    def _match_conv_settings(filters, kernel_sizes, strides) -> tuple[tuple, tuple, tuple]:
        """Return ``filters``, ``kernel_sizes`` and ``strides`` as equally long tuples.

        Values given as a list or tuple define the number of convolutional layers
        and must agree in length; scalar values are repeated for every layer.
        If all three are scalars, a single layer is described.
        """
        settings = {"filters": filters, "kernel_sizes": kernel_sizes, "strides": strides}
        lengths = {name: len(value) for name, value in settings.items() if isinstance(value, (list, tuple))}

        if any(length == 0 for length in lengths.values()):
            raise ValueError(f"filters, kernel_sizes and strides must not be empty sequences, got lengths {lengths}.")
        if len(set(lengths.values())) > 1:
            raise ValueError(
                "filters, kernel_sizes and strides must have the same length when given as "
                f"sequences, got lengths {lengths}."
            )

        depth = next(iter(lengths.values()), 1)
        expanded = tuple(
            tuple(value) if isinstance(value, (list, tuple)) else (value,) * depth for value in settings.values()
        )
        return expanded

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Performs the forward pass of the hybrid convolutional-recurrent network.

        The input is passed through every convolutional block in order, then through
        the recurrent module, and finally through the dense output projection.

        Parameters
        ----------
        x : Tensor
            Input tensor representing the sequence data.
            NOTE(review): presumably shaped (batch, time, channels) as required by
            ``Conv1D`` - confirm against callers.
        training : bool, optional
            Whether the model is in training mode. Forwarded to the convolutional
            blocks and the recurrent module (relevant for layers such as dropout).
            Default is False.
        **kwargs
            Accepted for interface compatibility; not forwarded to any layer.

        Returns
        -------
        output : Tensor
            Output of the dense projection, i.e. the summary representation with
            ``summary_dim`` units in its last axis.
        """
        for c in self.conv_blocks:
            x = c(x, training=training)

        x = self.recurrent(x, training=training)
        x = self.output_projector(x)
        return x

    @sanitize_input_shape
    def build(self, input_shape):
        """Build the network by delegating to the parent class ``build``.

        ``input_shape`` is pre-processed by the ``sanitize_input_shape`` decorator
        before it reaches this method.
        """
        super().build(input_shape)