# Source code for bayesflow.networks.sequential.sequential

from collections.abc import Sequence
import keras

from bayesflow.utils import layer_kwargs
from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable("bayesflow.networks")
class Sequential(keras.Layer):
    """
    A custom sequential model for managing a sequence of Keras layers.

    This class extends `keras.Layer` and provides functionality for building,
    calling, and serializing a sequence of layers. Unlike `keras.Sequential`,
    this implementation does not eagerly check input shapes, meaning it is
    compatible with both single inputs and sets.

    Parameters
    ----------
    layers : keras.Layer | Sequence[keras.Layer]
        A sequence of Keras layers to be managed by this model.
        Can be passed by unpacking or as a single sequence. When no positional
        layers are given, a ``layers=...`` keyword argument is used instead;
        this is the form in which `from_config` supplies the layers.
    **kwargs :
        Additional keyword arguments passed to the base `keras.Layer` class
        (after being filtered through `layer_kwargs`).

    Notes
    -----
    - This class differs from `keras.Sequential` in that it does not eagerly
      check input shapes. This means that it is compatible with both single
      inputs and sets.
    """

    def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
        # A var-positional parameter can never be bound by keyword, so the
        # ``layers=[...]`` entry produced by `get_config` and forwarded by
        # `from_config` arrives in ``kwargs``. Without picking it up here the
        # deserialized instance would end up without any layers.
        if not layers and "layers" in kwargs:
            layers = (kwargs.pop("layers"),)

        super().__init__(**layer_kwargs(kwargs))

        # Allow both Sequential(a, b, c) and Sequential([a, b, c]).
        if len(layers) == 1 and isinstance(layers[0], Sequence):
            layers = layers[0]

        self._layers = layers

    def build(self, input_shape):
        """
        Build every managed layer in order, threading the shape through.

        Each layer is built with the output shape of the previous one, as
        reported by that layer's ``compute_output_shape``.

        Parameters
        ----------
        input_shape :
            Shape of the input to the first layer.
        """
        if self.built:
            # building when the network is already built can cause issues with serialization
            # see https://github.com/keras-team/keras/issues/21147
            return

        for layer in self._layers:
            layer.build(input_shape)
            input_shape = layer.compute_output_shape(input_shape)

    def call(self, inputs, training=None, mask=None):
        """
        Apply the managed layers to ``inputs`` one after another.

        Parameters
        ----------
        inputs :
            Input passed to the first layer.
        training : optional
            Forwarded to each layer whose ``call`` accepts a ``training``
            argument, but only when it is not None.
        mask : optional
            Forwarded unchanged to each layer whose ``call`` accepts a
            ``mask`` argument.

        Returns
        -------
        The output of the last layer, or ``inputs`` itself if there are no layers.
        """
        x = inputs
        for layer in self._layers:
            kwargs = self._make_kwargs_for_layer(layer, training, mask)
            x = layer(x, **kwargs)
        return x

    def compute_output_shape(self, input_shape):
        """
        Return the output shape obtained by chaining every layer's
        ``compute_output_shape`` starting from ``input_shape``.
        """
        for layer in self._layers:
            input_shape = layer.compute_output_shape(input_shape)
        return input_shape

    def get_config(self):
        """
        Return the serializable configuration of this layer.

        The base configuration is filtered through `layer_kwargs` and merged
        with a ``"layers"`` entry holding the serialized managed layers.
        """
        base_config = super().get_config()
        base_config = layer_kwargs(base_config)

        config = {
            "layers": [serialize(layer) for layer in self._layers],
        }

        return base_config | config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """
        Re-create an instance from a configuration produced by `get_config`.

        The deserialized configuration is passed to the constructor as keyword
        arguments; the constructor accepts the ``layers`` entry by keyword.
        """
        return cls(**deserialize(config, custom_objects=custom_objects))

    @property
    def layers(self):
        """The managed layers, in application order."""
        return self._layers

    @staticmethod
    def _make_kwargs_for_layer(layer, training, mask):
        """
        Collect the optional call arguments that ``layer`` supports.

        ``mask`` is included whenever the layer's call takes a mask argument;
        ``training`` is included only if the layer's call takes it and a value
        was actually provided (not None).
        """
        kwargs = {}
        if layer._call_has_mask_arg:
            kwargs["mask"] = mask
        if layer._call_has_training_arg and training is not None:
            kwargs["training"] = training
        return kwargs