# Source code for bayesflow.workflows.ensemble_workflow

from collections.abc import Mapping, Sequence, Callable

import os

import numpy as np

import keras

from bayesflow.datasets import OnlineDataset, OfflineDataset, DiskDataset
from bayesflow.datasets.ensemble_dataset import EnsembleDataset
from bayesflow.networks import InferenceNetwork, ScoringRuleNetwork, SummaryNetwork
from bayesflow.simulators import Simulator
from bayesflow.adapters import Adapter
from bayesflow.approximators import EnsembleApproximator, ContinuousApproximator, ScoringRuleApproximator
from bayesflow.utils import find_inference_network, find_summary_network, logging, filter_kwargs

from .basic_workflow import BasicWorkflow


class EnsembleWorkflow(BasicWorkflow):
    """Workflow that trains an ensemble of approximators on shared data.

    Builds one approximator per ensemble member (either from an explicit
    dictionary of inference networks or by cloning a single network
    ``ensemble_size`` times), wraps them in an
    :py:class:`~bayesflow.approximators.EnsembleApproximator`, and reuses the
    optimizer/checkpointing machinery of :py:class:`BasicWorkflow`.
    """

    def __init__(
        self,
        simulator: Simulator | None = None,
        adapter: Adapter | None = None,
        inference_networks: dict[str, InferenceNetwork | str] | InferenceNetwork | str = "coupling_flow",
        summary_networks: dict[str, SummaryNetwork | str] | SummaryNetwork | str | None = None,
        ensemble_size: int | None = None,
        share_inference_network: bool = False,
        initial_learning_rate: float = 5e-4,
        optimizer: keras.optimizers.Optimizer | type | None = None,
        checkpoint_filepath: str | None = None,
        checkpoint_name: str = "model",
        save_weights_only: bool = False,
        save_best_only: bool = False,
        inference_variables: Sequence[str] | str | None = None,
        inference_conditions: Sequence[str] | str | None = None,
        summary_variables: Sequence[str] | str | None = None,
        standardize: Sequence[str] | str | None = "inference_variables",
        **kwargs,
    ):
        """Construct the ensemble workflow.

        Parameters
        ----------
        simulator : Simulator or None, optional
            Data-generating process used for online training.
        adapter : Adapter or None, optional
            Data adapter shared by all members. If None, a default adapter is
            built from ``inference_variables``, ``inference_conditions`` and
            ``summary_variables``.
        inference_networks : dict[str, InferenceNetwork | str] or InferenceNetwork or str, optional
            Either one named inference network per member, or a single
            network/name replicated ``ensemble_size`` times
            (default ``"coupling_flow"``).
        summary_networks : dict[str, SummaryNetwork | str] or SummaryNetwork or str or None, optional
            Per-member summary networks (keys must match
            ``inference_networks``), or a single summary network shared by all
            members, or None for no summary network.
        ensemble_size : int or None, optional
            Number of ensemble members; must be > 1. Ignored (with a warning)
            when ``inference_networks`` is a dictionary.
        share_inference_network : bool, optional
            If True, all members share the same inference network object
            instead of independent clones. Ignored (with a warning) when
            ``inference_networks`` is a dictionary.
        initial_learning_rate : float, optional
            Initial learning rate for the optimizer (default 5e-4).
        optimizer : keras.optimizers.Optimizer or type or None, optional
            Optimizer instance or class; resolved by ``_init_optimizer``.
        checkpoint_filepath : str or None, optional
            Directory for model checkpoints; None disables checkpointing.
        checkpoint_name : str, optional
            Base filename for checkpoints (default ``"model"``).
        save_weights_only : bool, optional
            If True, checkpoints store only weights (default False).
        save_best_only : bool, optional
            If True, only the best model (by monitored metric) is saved.
        inference_variables, inference_conditions, summary_variables : Sequence[str] or str or None, optional
            Variable names used to build the default adapter when ``adapter``
            is None.
        standardize : Sequence[str] or str or None, optional
            Which variable groups to standardize
            (default ``"inference_variables"``).
        **kwargs
            May contain ``inference_kwargs``, ``summary_kwargs`` and
            ``optimizer_kwargs`` dictionaries, plus any ``keras.Model``
            constructor arguments (extracted via ``filter_kwargs``).

        Raises
        ------
        ValueError
            If ``ensemble_size`` <= 1, if neither a network dictionary nor an
            ``ensemble_size`` is given, or if a summary network is named for a
            member without a matching inference network.
        """
        # Resolve the per-member inference networks.
        _inference_networks = {}
        if isinstance(inference_networks, dict):
            # Explicit dictionary: member names come from the dict keys, so
            # ensemble_size / share_inference_network are meaningless here.
            if ensemble_size is not None:
                logging.warning(
                    "Ignoring argument ensemble_size={ensemble_size}, "
                    "because a dictionary was passed for `inference_networks`.",
                    ensemble_size=ensemble_size,
                )
            if share_inference_network:
                logging.warning(
                    "Ignoring argument share_inference_network={share_inference_network}, "
                    "because a dictionary was passed for `inference_networks`.",
                    share_inference_network=share_inference_network,
                )
            for k, v in inference_networks.items():
                # Per-member constructor kwargs may be nested under the member key.
                _inference_networks[k] = find_inference_network(v, **kwargs.get("inference_kwargs", {}).get(k, {}))
        else:
            if ensemble_size and ensemble_size > 1:
                inference_network = find_inference_network(inference_networks, **kwargs.get("inference_kwargs", {}))
                for member_idx in range(ensemble_size):
                    # Members are named "0", "1", ... when built from a count.
                    member_key = f"{member_idx}"
                    if share_inference_network:
                        _inference_networks[member_key] = inference_network
                    else:
                        # Independent clone per member so weights diverge during training.
                        _inference_networks[member_key] = keras.models.clone_model(inference_network)
            elif isinstance(ensemble_size, int) and ensemble_size <= 1:
                raise ValueError("`ensemble_size` should be an integer greater than 1.")
            else:
                raise ValueError(
                    "Either `inference_networks` is a dictionary of `InferenceNetwork`s "
                    "or `ensemble_size` must be specified."
                )

        # Resolve the per-member summary networks (optional).
        _summary_networks = {}
        if isinstance(summary_networks, dict):
            for k, v in summary_networks.items():
                # Every summary network must correspond to an inference network.
                if k not in _inference_networks.keys():
                    raise ValueError(f"A summary network was specified for {k}, but no inference network.")
                if v is not None:
                    _summary_networks[k] = find_summary_network(v, **kwargs.get("summary_kwargs", {}).get(k, {}))
        elif summary_networks is not None:
            # A single (non-dict) summary network object is shared by all members.
            summary_network = find_summary_network(summary_networks, **kwargs.get("summary_kwargs", {}))
            for k in _inference_networks.keys():
                _summary_networks[k] = summary_network

        self.simulator = simulator
        # NOTE(review): the adapter is only passed into the member approximators;
        # `self.adapter` (read by the fit_* methods) is presumably exposed by
        # BasicWorkflow or the approximator — TODO confirm.
        adapter = adapter or BasicWorkflow.default_adapter(inference_variables, inference_conditions, summary_variables)

        # Build one approximator per member; scoring-rule networks get their
        # dedicated approximator type, everything else a ContinuousApproximator.
        approximators = {}
        for k, v in _inference_networks.items():
            if isinstance(v, ScoringRuleNetwork):
                constructor = ScoringRuleApproximator
            else:
                constructor = ContinuousApproximator
            approximators[k] = constructor(
                inference_network=v,
                summary_network=_summary_networks.get(k, None),
                adapter=adapter,
                standardize=standardize,
                **filter_kwargs(kwargs, keras.Model.__init__),
            )

        self.approximator = EnsembleApproximator(
            approximators=approximators, **filter_kwargs(kwargs, keras.Model.__init__)
        )
        # Member names drive per-member batching in EnsembleDataset.
        self.member_names = tuple(self.approximator.approximators.keys())

        self._init_optimizer(initial_learning_rate, optimizer, **kwargs.get("optimizer_kwargs", {}))
        self._init_checkpointing(checkpoint_filepath, checkpoint_name, save_weights_only, save_best_only)
        self.history = None
[docs] def fit_offline( self, data: Mapping[str, np.ndarray], epochs: int = 100, batch_size: int = 32, data_reuse: float = 1.0, keep_optimizer: bool = False, validation_data: Mapping[str, np.ndarray] | int = None, augmentations: Mapping[str, Callable] | Callable = None, **kwargs, ) -> keras.callbacks.History: """ Train the ensemble of approximators offline using a fixed dataset. This approach will be faster than online training, since no computation time is spent in generating new data for each batch, but it assumes that simulations can fit in memory. Parameters ---------- data : Mapping[str, np.ndarray] A dictionary containing training data where keys represent variable names and values are corresponding realizations. epochs : int, optional The number of training epochs, by default 100. Consider increasing this number for free-form inference networks, such as FlowMatching or ConsistencyModel, which generally need more epochs than CouplingFlows. batch_size : int, optional The batch size used for training, by default 32. data_reuse : float, optional Similarity of training data for ensemble members in ``[0, 1]``, by default 1.0. See also :py:class`bayesflow.datasets.EnsembleDataset`. keep_optimizer : bool, optional Whether to retain the current state of the optimizer after training, by default False. validation_data : Mapping[str, np.ndarray] or int, optional A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used. augmentations : dict of str to Callable or Callable, optional Dictionary of augmentation functions to apply to each corresponding key in the batch or a function to apply to the entire batch (possibly adding new keys). If you provide a dictionary of functions, each function should accept one element of your output batch and return the corresponding transformed element. 
Otherwise, your function should accept the entire dictionary output and return a dictionary. Note - augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training. **kwargs : dict, optional Additional keyword arguments passed to the underlying `_fit` method. Returns ------- history : keras.callbacks.History A history object containing training history, where keys correspond to logged metrics (e.g., loss values) and values are arrays tracking metric evolution over epochs. """ dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter, augmentations=augmentations) dataset = EnsembleDataset(dataset, member_names=self.member_names, data_reuse=data_reuse) return self._fit( dataset, epochs, strategy="offline", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs, )
[docs] def fit_online( self, epochs: int = 100, num_batches_per_epoch: int = 100, batch_size: int = 32, data_reuse: float = 1.0, keep_optimizer: bool = False, validation_data: Mapping[str, np.ndarray] | int = None, augmentations: Mapping[str, Callable] | Callable = None, **kwargs, ) -> keras.callbacks.History: """ Train the ensemble of approximators using an online data-generating process. The dataset is dynamically generated during training, making this approach suitable for scenarios where generating new simulations is computationally cheap. Parameters ---------- epochs : int, optional The number of training epochs, by default 100. num_batches_per_epoch : int, optional The number of batches generated per epoch, by default 100. batch_size : int, optional The batch size used for training, by default 32. data_reuse : float, optional Similarity of training data for ensemble members in ``[0, 1]``, by default 1.0. See also :py:class`bayesflow.datasets.EnsembleDataset`. keep_optimizer : bool, optional Whether to retain the current state of the optimizer after training, by default False. validation_data : Mapping[str, np.ndarray] or int, optional A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used. augmentations : dict of str to Callable or Callable, optional Dictionary of augmentation functions to apply to each corresponding key in the batch or a function to apply to the entire batch (possibly adding new keys). If you provide a dictionary of functions, each function should accept one element of your output batch and return the corresponding transformed element. Otherwise, your function should accept the entire dictionary output and return a dictionary. Note - augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training. 
**kwargs : dict, optional Additional keyword arguments passed to the underlying `_fit` method. Returns ------- history : keras.callbacks.History A history object containing training history, where keys correspond to logged metrics (e.g., loss values) and values are arrays tracking metric evolution over epochs. """ dataset = OnlineDataset( simulator=self.simulator, batch_size=batch_size, num_batches=num_batches_per_epoch, adapter=self.adapter, augmentations=augmentations, ) dataset = EnsembleDataset(dataset, member_names=self.member_names, data_reuse=data_reuse) return self._fit( dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs )
[docs] def fit_disk( self, root: os.PathLike, pattern: str = "*.pkl", batch_size: int = 32, data_reuse: float = 1.0, load_fn: callable = None, epochs: int = 100, keep_optimizer: bool = False, validation_data: Mapping[str, np.ndarray] | int = None, augmentations: Mapping[str, Callable] | Callable = None, **kwargs, ) -> keras.callbacks.History: """ Train the ensemble of approximators using data stored on disk. This approach is suitable for large sets of simulations that don't fit in memory. Parameters ---------- root : os.PathLike The root directory containing the dataset files. pattern : str, optional A filename pattern to match dataset files, by default ``"*.pkl"``. batch_size : int, optional The batch size used for training, by default 32. data_reuse : float, optional Similarity of training data for ensemble members in ``[0, 1]``, by default 1.0. See also :py:class`bayesflow.datasets.EnsembleDataset`. load_fn : callable, optional A function to load dataset files. If None, a default loading function is used. epochs : int, optional The number of training epochs, by default 100. Consider increasing this number for free-form inference networks, such as FlowMatching or ConsistencyModel, which generally need more epochs than CouplingFlows. keep_optimizer : bool, optional Whether to retain the current state of the optimizer after training, by default False. validation_data : Mapping[str, np.ndarray] or int, optional A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used. augmentations : dict of str to Callable or Callable, optional Dictionary of augmentation functions to apply to each corresponding key in the batch or a function to apply to the entire batch (possibly adding new keys). If you provide a dictionary of functions, each function should accept one element of your output batch and return the corresponding transformed element. 
Otherwise, your function should accept the entire dictionary output and return a dictionary. Note - augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training. **kwargs : dict, optional Additional keyword arguments passed to the underlying `_fit` method. Returns ------- history : keras.callbacks.History A history object containing training history, where keys correspond to logged metrics (e.g., loss values) and values are arrays tracking metric evolution over epochs. """ dataset = DiskDataset( root=root, pattern=pattern, batch_size=batch_size, load_fn=load_fn, adapter=self.adapter, augmentations=augmentations, ) dataset = EnsembleDataset(dataset, member_names=self.member_names, data_reuse=data_reuse) return self._fit( dataset, epochs, strategy="offline", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs, )