Source code for bayesflow.networks.fusion_network.fusion_network

from collections.abc import Mapping
from ..summary_network import SummaryNetwork
from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.types import Tensor, Shape
import keras
from keras import ops


@serializable("bayesflow.networks")
class FusionNetwork(SummaryNetwork):
    def __init__(
        self,
        backbones: Mapping[str, keras.Layer],
        head: keras.Layer | None = None,
        **kwargs,
    ):
        """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data.

        Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed
        by the correct summary network. This means the "summary_variables" entry to the approximator has to be
        a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method.

        This network implements *late* fusion. The output of the individual summary networks is concatenated
        along the last axis, and can be further processed by another neural network (`head`).

        Parameters
        ----------
        backbones : dict
            A dictionary with names of inputs as keys and corresponding summary networks as values.
        head : keras.Layer, optional
            A network to further process the concatenated outputs of the summary networks. By default,
            the concatenated outputs are returned without further processing.
        **kwargs
            Additional keyword arguments that are passed to the :py:class:`~bayesflow.networks.SummaryNetwork`
            base class.
        """
        super().__init__(**kwargs)
        self.backbones = backbones
        self.head = head
        # Sorted once so that every method visits the backbones in the same order,
        # independent of the insertion order of the dictionaries passed in.
        self._ordered_keys = sorted(self.backbones.keys())

    @staticmethod
    def _fused_shape(output_shapes: list[Shape]) -> Shape:
        """Return the shape obtained by concatenating `output_shapes` along the last axis.

        The leading dimensions are taken from the first entry; the last dimension is the
        sum of the last dimensions of all entries.
        """
        return (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))

    def build(self, inputs_shape: Mapping[str, Shape]) -> None:
        """Build all backbones and, if present, the head.

        Parameters
        ----------
        inputs_shape : dict[str, Shape]
            The input shape for each backbone, stored under the same key as the backbone.
            Every backbone key must be present.
        """
        if self.built:
            return

        output_shapes = []
        # Iterate over the backbone keys (not the keys of `inputs_shape`) so that this method
        # visits exactly the backbones that `call` uses, in the same order.
        for key in self._ordered_keys:
            backbone = self.backbones[key]
            shape = inputs_shape[key]
            if not backbone.built:
                backbone.build(shape)
            output_shapes.append(backbone.compute_output_shape(shape))

        # Explicit `None` check, matching `call` and `compute_metrics`.
        if self.head is not None and not self.head.built:
            self.head.build(self._fused_shape(output_shapes))

        self.built = True

    def compute_output_shape(self, inputs_shape: Mapping[str, Shape]) -> Shape:
        """Compute the shape of the fused output.

        Parameters
        ----------
        inputs_shape : dict[str, Shape]
            The input shape for each backbone, stored under the same key as the backbone.

        Returns
        -------
        Shape
            The shape of the concatenated backbone outputs, or the output shape of `head`
            for that concatenated shape if a head is set.
        """
        output_shapes = [self.backbones[key].compute_output_shape(inputs_shape[key]) for key in self._ordered_keys]
        output_shape = self._fused_shape(output_shapes)
        if self.head is not None:
            output_shape = self.head.compute_output_shape(output_shape)
        return output_shape

    def call(self, inputs: Mapping[str, Tensor], training: bool = False) -> Tensor:
        """Run every backbone on its input, concatenate the results and apply the head, if any.

        Parameters
        ----------
        inputs : dict[str, Tensor]
            Each value in the dictionary is the input to the summary network with the corresponding key.
        training : bool, optional
            Forwarded as the ``training`` argument to every backbone and to the head.
            Default is False.

        Returns
        -------
        Tensor
            The backbone outputs concatenated along the last axis (in sorted key order),
            passed through `head` if one is set.
        """
        outputs = [self.backbones[key](inputs[key], training=training) for key in self._ordered_keys]
        fused = ops.concatenate(outputs, axis=-1)
        if self.head is None:
            return fused
        return self.head(fused, training=training)

    def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", **kwargs) -> dict[str, Tensor]:
        """Compute the fused outputs and collect the losses reported by the backbones.

        Parameters
        ----------
        inputs : dict[str, Tensor]
            Each value in the dictionary is the input to the summary network with the corresponding key.
        stage : str, optional
            The current stage. Backbones that are a :py:class:`~bayesflow.networks.SummaryNetwork` receive
            it unchanged; all other backbones and the head are called with ``training=True`` only if
            `stage` equals ``"training"``. Default is ``"training"``.
        **kwargs
            Additional keyword arguments, forwarded to ``compute_metrics`` of each backbone that is a
            :py:class:`~bayesflow.networks.SummaryNetwork`.

        Returns
        -------
        dict[str, Tensor]
            ``"outputs"`` holds the fused outputs (after the head, if set). ``"loss"`` is only present
            if at least one backbone reported a loss, and holds the sum of the reported losses.
        """
        if not self.built:
            self.build(keras.tree.map_structure(keras.ops.shape, inputs))

        training = stage == "training"
        losses = []
        outputs = []
        for key in self._ordered_keys:
            backbone = self.backbones[key]
            if isinstance(backbone, SummaryNetwork):
                backbone_metrics = backbone.compute_metrics(inputs[key], stage=stage, **kwargs)
                outputs.append(backbone_metrics["outputs"])
                if "loss" in backbone_metrics:
                    losses.append(backbone_metrics["loss"])
            else:
                # Plain layers have no `compute_metrics`; only their output is used.
                outputs.append(backbone(inputs[key], training=training))

        metrics = {}
        if losses:
            metrics["loss"] = ops.sum(losses)

        fused = ops.concatenate(outputs, axis=-1)
        if self.head is not None:
            fused = self.head(fused, training=training)
        metrics["outputs"] = fused

        return metrics

    def get_config(self) -> dict:
        """Return the base class config extended by the serialized `backbones` and `head`."""
        base_config = super().get_config()
        config = {
            "backbones": self.backbones,
            "head": self.head,
        }
        return base_config | serialize(config)

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        """Deserialize `config` and construct an instance from the result."""
        config = deserialize(config, custom_objects=custom_objects)
        return cls(**config)