Source code for bayesflow.approximators.approximator

from collections.abc import Mapping, Sequence
from typing import Any

import multiprocessing as mp

import numpy as np
import keras

from bayesflow.adapters import Adapter
from bayesflow.datasets import OnlineDataset
from bayesflow.simulators import Simulator
from bayesflow.types import Tensor
from bayesflow.utils import find_batch_size, filter_kwargs, concatenate_valid_shapes, logging
from bayesflow.utils.serialization import deserialize, serialize

from .backend_approximators import BackendApproximator


class Approximator(BackendApproximator):
    """Base class for all BayesFlow approximators.

    Provides the machinery shared by concrete approximators: a template
    :meth:`build`, a default adapter factory (:meth:`build_adapter`), dataset
    construction from a simulator (:meth:`build_dataset`), standardization and
    resolution of conditions, and (de)serialization of compile-time metrics.
    """

    # Mask routing: {data_key -> network_kwarg_name}.
    # Subclasses can narrow these to the masks they actually support.
    _SUMMARY_MASK_KEYS: dict[str, str] = {
        "summary_attention_mask": "attention_mask",
        "summary_mask": "mask",
    }
    _INFERENCE_MASK_KEYS: dict[str, str] = {
        "inference_attention_mask": "attention_mask",
        "inference_mask": "mask",
    }

    @staticmethod
    def _collect_mask_kwargs(mapping: dict[str, str], source: Mapping) -> dict:
        """Build a kwargs dict from a ``{source_key: target_kwarg_name}`` mapping.

        Looks up each *source_key* in *source*; when the value is present and not
        ``None``, it is added to the result under *target_kwarg_name*.
        """
        return {target: source[key] for key, target in mapping.items() if source.get(key) is not None}

    @property
    def standardize_layers(self):
        """Shortcut to the standardizer's per-variable layers.

        Returns ``None`` when the approximator has no ``standardizer`` attribute.
        """
        if hasattr(self, "standardizer"):
            return self.standardizer.standardize_layers
        return None

    def build(self, data_shapes: Mapping[str, tuple[int, ...] | Mapping[str, Mapping]]):
        """Template method for building all network components.

        The build proceeds as follows:

        1. Build the standardizer (if present and not yet built).
        2. Build the summary network (if present) and compute its output shape.
        3. Copy ``data_shapes`` into a new dict and, if a summary network exists,
           store its output shape under the ``"_summary_outputs"`` key.
        4. Call the :meth:`_build_inference_network` hook with the enriched shapes.
        5. Mark the approximator as built.

        Hooks therefore receive ``"_summary_outputs"`` directly and do not need
        to recompute it.

        Parameters
        ----------
        data_shapes : Mapping[str, tuple[int, ...] | Mapping]
            Shapes of the adapted data keyed by canonical name
            (e.g. ``"inference_variables"``, ``"summary_variables"``).
        """
        # Build standardizer first
        self._build_standardization_layers(data_shapes)

        # Build summary network once at template level and cache output shape
        summary_outputs_shape = self._build_summary_network(data_shapes)

        # ``Mapping`` does not guarantee a ``.copy()`` method; ``dict(...)`` makes a
        # shallow copy of any mapping so the caller's object is never mutated.
        enriched_shapes = dict(data_shapes)
        if summary_outputs_shape is not None:
            enriched_shapes["_summary_outputs"] = summary_outputs_shape

        self._build_inference_network(enriched_shapes)

        self.built = True

    def _build_standardization_layers(self, data_shapes: Mapping[str, tuple[int, ...] | Mapping]):
        """Helper: build the standardizer if present and not yet built."""
        if hasattr(self, "standardizer") and not self.standardizer.built:
            self.standardizer.build(data_shapes)

    def _build_summary_network(
        self, data_shapes: Mapping[str, tuple[int, ...] | Mapping]
    ) -> tuple[int, ...] | dict | None:
        """Helper: build the summary network (if present) and return its output shape.

        Parameters
        ----------
        data_shapes : Mapping[str, tuple[int, ...] | Mapping]
            Must contain ``"summary_variables"`` when a summary network is present.

        Returns
        -------
        output_shape : tuple or dict or None
            The output shape of the summary network, or ``None`` if the
            approximator has no summary network.
        """
        if not hasattr(self, "summary_network") or self.summary_network is None:
            return None

        if not self.summary_network.built:
            self.summary_network.build(data_shapes["summary_variables"])

        return self.summary_network.compute_output_shape(data_shapes["summary_variables"])

    def _build_inference_network(self, data_shapes: Mapping[str, tuple[int, ...] | Mapping]):
        """Hook: build the inference network from the enriched data shapes.

        The default implementation builds ``self.inference_network`` (if not yet
        built) using the ``"inference_variables"`` shape and a condition shape
        obtained by concatenating the ``"inference_conditions"`` and
        ``"_summary_outputs"`` shapes along the last axis; missing entries are
        passed as ``None`` to ``concatenate_valid_shapes``.

        :meth:`build` has already built the summary network and stored its output
        shape under ``"_summary_outputs"``, so overrides do not need to call
        :meth:`_build_summary_network` again.
        """
        if not self.inference_network.built:
            summary_outputs_shape = data_shapes.get("_summary_outputs")
            inference_conditions_shape = concatenate_valid_shapes(
                [data_shapes.get("inference_conditions"), summary_outputs_shape], axis=-1
            )
            self.inference_network.build(data_shapes["inference_variables"], inference_conditions_shape)

    def _prepare_conditions(
        self,
        data: Mapping[str, np.ndarray],
        *,
        stage: str = "inference",
        **adapter_kwargs,
    ) -> tuple[Tensor | None, dict[str, Tensor], Tensor | None]:
        """Adapt raw user data, tensorize, standardize conditions, and resolve.

        Standard inference-time pipeline shared across all approximators:

        1. Apply the adapter (``strict=False``)
        2. Convert all values to tensors
        3. Standardize ``inference_conditions`` and ``summary_variables``
        4. Resolve conditions via the summary network (if present)

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Raw user data dictionary.
        stage : str, optional
            Stage for standardization (default is ``"inference"``).
        **adapter_kwargs
            Extra keyword arguments forwarded to the adapter.

        Returns
        -------
        resolved_conditions : Tensor or None
            Standardized inference conditions concatenated with summary outputs.
        adapted : dict[str, Tensor]
            The full adapted and tensorized dictionary.
        summary_outputs : Tensor or None
            Raw summary network outputs, or ``None`` if no summary network.
        """
        adapted = self.adapter(data, strict=False, **adapter_kwargs)
        adapted = keras.tree.map_structure(keras.ops.convert_to_tensor, adapted)

        # Only masks that are actually present (non-None) are forwarded.
        summary_kwargs = self._collect_mask_kwargs(self._SUMMARY_MASK_KEYS, adapted)

        resolved_conditions, summary_outputs = self._standardize_and_resolve(
            adapted.get("inference_conditions"),
            adapted.get("summary_variables"),
            stage=stage,
            **summary_kwargs,
        )

        return resolved_conditions, adapted, summary_outputs

    def _standardize_and_resolve(
        self,
        inference_conditions: Tensor | None,
        summary_variables: Tensor | None,
        *,
        stage: str,
        purpose: str = "call",
        **summary_kwargs,
    ) -> tuple[Tensor | None, Any]:
        """Standardize condition tensors and resolve them via the summary network.

        Shared by inference-time methods (via :meth:`_prepare_conditions`) and
        training-time ``compute_metrics`` implementations.

        Parameters
        ----------
        inference_conditions : Tensor or None
            Inference conditions (pre-adapted tensors).
        summary_variables : Tensor or None
            Summary variables (pre-adapted tensors).
        stage : str
            Current stage (``"training"``, ``"validation"``, or ``"inference"``).
        purpose : str, optional
            Passed to :meth:`ConditionBuilder.resolve` — ``"call"`` for forward
            passes, ``"metrics"`` for training/validation (default is ``"call"``).
        **summary_kwargs
            Extra keyword arguments forwarded to :meth:`ConditionBuilder.resolve`
            and ultimately to the summary network's ``call`` /
            ``compute_metrics`` method (e.g. ``attention_mask``).

        Returns
        -------
        resolved_conditions : Tensor or None
            Standardized inference conditions concatenated with summary outputs.
        summary_output : Tensor, dict, or None
            For ``purpose="call"``: summary network output tensor or ``None``.
            For ``purpose="metrics"``: dict of summary metrics.
        """
        inference_conditions = self.standardizer.maybe_standardize(
            inference_conditions, key="inference_conditions", stage=stage
        )
        summary_variables = self.standardizer.maybe_standardize(summary_variables, key="summary_variables", stage=stage)

        resolved_conditions, summary_output = self.condition_builder.resolve(
            self.summary_network,
            inference_conditions,
            summary_variables,
            stage=stage,
            purpose=purpose,
            **summary_kwargs,
        )

        return resolved_conditions, summary_output

    @classmethod
    def build_adapter(
        cls,
        inference_variables: str | Sequence[str],
        inference_conditions: str | Sequence[str] | None = None,
        summary_variables: str | Sequence[str] | None = None,
        sample_weight: str | None = None,
        summary_attention_mask: str | None = None,
        summary_mask: str | None = None,
        inference_attention_mask: str | None = None,
        inference_mask: str | None = None,
    ) -> Adapter:
        """Create a default :py:class:`~bayesflow.adapters.Adapter` for the approximator.

        Handles the common pipeline shared by all approximators:
        ``to_array -> convert_dtype -> concatenate -> keep``.
        Subclasses can call ``super().build_adapter(...)`` and apply additional
        steps to the returned adapter.

        Parameters
        ----------
        inference_variables : str or Sequence[str]
            Names of the inference variables in the data dict.
        inference_conditions : str or Sequence[str], optional
            Names of the inference conditions in the data dict.
        summary_variables : str or Sequence[str], optional
            Names of the summary variables in the data dict.
        sample_weight : str, optional
            Name of the sample weight variable.
        summary_attention_mask : str, optional
            Name of the attention mask for the summary network.
            Forwarded as ``attention_mask`` to the summary network.
        summary_mask : str, optional
            Name of the padding/key mask for the summary network.
            Forwarded as ``mask`` to the summary network.
        inference_attention_mask : str, optional
            Name of the attention mask for the inference network.
            Forwarded as ``attention_mask`` to the inference network.
        inference_mask : str, optional
            Name of the padding/key mask for the inference network.
            Forwarded as ``mask`` to the inference network.

        Returns
        -------
        Adapter
            The configured adapter.
        """
        if isinstance(inference_variables, str):
            inference_variables = [inference_variables]

        if isinstance(inference_conditions, str):
            inference_conditions = [inference_conditions]

        if isinstance(summary_variables, str):
            summary_variables = [summary_variables]

        adapter = Adapter()
        adapter.to_array()
        adapter.convert_dtype("float64", "float32")
        adapter.concatenate(inference_variables, into="inference_variables")

        if inference_conditions is not None:
            adapter.concatenate(inference_conditions, into="inference_conditions")

        if summary_variables is not None:
            adapter.as_set(summary_variables)
            adapter.concatenate(summary_variables, into="summary_variables")

        if sample_weight is not None:
            adapter.rename(sample_weight, "sample_weight")

        keep = ["inference_variables", "inference_conditions", "summary_variables", "sample_weight"]

        # Rename user-provided mask keys to the canonical names the approximator routes on.
        for canonical_name, user_name in {
            "summary_attention_mask": summary_attention_mask,
            "summary_mask": summary_mask,
            "inference_attention_mask": inference_attention_mask,
            "inference_mask": inference_mask,
        }.items():
            if user_name is not None:
                adapter.rename(user_name, canonical_name)
                keep.append(canonical_name)

        adapter.keep(keep)

        return adapter

    def build_dataset(
        self,
        *,
        batch_size: int | str = "auto",
        num_batches: int,
        adapter: Adapter | str = "auto",
        memory_budget: str | int = "auto",
        simulator: Simulator,
        workers: int | str = "auto",
        use_multiprocessing: bool = False,
        max_queue_size: int = 32,
        **kwargs,
    ) -> OnlineDataset:
        """Build an :class:`OnlineDataset` that draws batches from ``simulator``.

        Parameters
        ----------
        batch_size : int or "auto", optional
            Batch size. If ``"auto"``, it is chosen by ``find_batch_size`` from
            ``memory_budget`` and a single simulated sample.
        num_batches : int
            Number of batches, forwarded to :class:`OnlineDataset`.
        adapter : Adapter or "auto", optional
            Adapter applied to simulated data. If ``"auto"``, it is created with
            :meth:`build_adapter` from the matching entries of ``kwargs``.
        memory_budget : str or int, optional
            Only used when ``batch_size="auto"``.
        simulator : Simulator
            Simulator that generates the training data.
        workers : int or "auto", optional
            Number of data loading workers. ``"auto"`` uses ``mp.cpu_count()``;
            a falsy value is replaced by ``1``.
        use_multiprocessing : bool, optional
            Forwarded to :class:`OnlineDataset`.
        max_queue_size : int, optional
            Forwarded to :class:`OnlineDataset`.
        **kwargs
            Filtered and forwarded to :meth:`build_adapter` when
            ``adapter="auto"``; ignored otherwise.

        Returns
        -------
        OnlineDataset
            The configured dataset.
        """
        if batch_size == "auto":
            batch_size = find_batch_size(memory_budget=memory_budget, sample=simulator.sample(1))
            logging.info(f"Using a batch size of {batch_size}.")

        if adapter == "auto":
            adapter = self.build_adapter(**filter_kwargs(kwargs, self.build_adapter))

        if workers == "auto":
            workers = mp.cpu_count()
            logging.info(f"Using {workers} data loading workers.")

        workers = workers or 1

        return OnlineDataset(
            simulator=simulator,
            batch_size=batch_size,
            num_batches=num_batches,
            adapter=adapter,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
        )

    def call(self, *args, **kwargs):
        """Forward pass; delegates to :meth:`compute_metrics` with the same arguments."""
        return self.compute_metrics(*args, **kwargs)

    def compile(
        self,
        *args,
        inference_metrics: Any = None,
        summary_metrics: Any = None,
        **kwargs,
    ):
        """Compile the approximator, setting metrics on the inference and summary networks.

        Parameters
        ----------
        inference_metrics : keras.Metric or Sequence[keras.Metric], optional
            Metric(s) to set on the inference_network.
        summary_metrics : keras.Metric or Sequence[keras.Metric], optional
            Metric(s) to set on the summary_network (if present). A warning is
            logged and the metrics are ignored when there is no summary network.
        *args, **kwargs
            Additional arguments passed to the parent compile method.
        """
        if inference_metrics:
            self.inference_network._metrics = inference_metrics

        if summary_metrics:
            if not hasattr(self, "summary_network") or self.summary_network is None:
                logging.warning("Ignoring summary metrics because there is no summary network.")
            else:
                self.summary_network._metrics = summary_metrics

        return super().compile(*args, **kwargs)

    def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
        """Train the approximator on a dataset or on data generated on demand by a simulator.

        If `dataset` is not provided, a dataset is built from the `simulator`.
        If the model has not been built, it will be built using a batch from the dataset.

        Parameters
        ----------
        dataset : keras.utils.PyDataset, optional
            A dataset containing simulations for training. If provided, `simulator` must be None.
        simulator : Simulator, optional
            A simulator used to generate a dataset. If provided, `dataset` must be None.
        **kwargs
            Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):

            batch_size : int or None, default='auto'
                Number of samples per gradient update. Do not specify if `dataset` is provided as a
                `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
            epochs : int, default=1
                Number of epochs to train the model.
            verbose : {"auto", 0, 1, 2}, default="auto"
                Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
            callbacks : list of keras.callbacks.Callback, optional
                List of callbacks to apply during training.
            validation_split : float, optional
                Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
                or tensors).
            validation_data : tuple or dataset, optional
                Data for validation, overriding `validation_split`.
            shuffle : bool, default=True
                Whether to shuffle the training data before each epoch (ignored for dataset generators).
            initial_epoch : int, default=0
                Epoch at which to start training (useful for resuming training).
            steps_per_epoch : int or None, optional
                Number of steps (batches) before declaring an epoch finished.
            validation_steps : int or None, optional
                Number of validation steps per validation epoch.
            validation_batch_size : int or None, optional
                Number of samples per validation batch (defaults to `batch_size`).
            validation_freq : int, default=1
                Specifies how many training epochs to run before performing validation.

        Returns
        -------
        keras.callbacks.History
            A history object containing the training loss and metrics values.

        Raises
        ------
        ValueError
            If both `dataset` and `simulator` are provided or neither is provided.
        """
        if dataset is None:
            if simulator is None:
                raise ValueError("Received no data to fit on. Please provide either a dataset or a simulator.")

            logging.info(f"Building dataset from simulator instance of {simulator.__class__.__name__}.")
            dataset = self.build_dataset(simulator=simulator, **filter_kwargs(kwargs, self.build_dataset))
        else:
            if simulator is not None:
                raise ValueError(
                    "Received conflicting arguments. Please provide either a dataset or a simulator, but not both."
                )

            logging.info(f"Fitting on dataset instance of {dataset.__class__.__name__}.")

        if not self.built:
            logging.info("Building on a test batch.")
            mock_data = dataset[0]
            mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
            self.build_from_data(mock_data)

        return super().fit(dataset=dataset, **kwargs)

    def build_from_data(self, adapted_data: Mapping[str, Any]):
        """Build the approximator from adapted data by extracting the shapes of its leaves."""
        self.build(keras.tree.map_structure(keras.ops.shape, adapted_data))

    def compile_from_config(self, config):
        """Compile the approximator from a saved configuration.

        When an optimizer exists and the model is already built, the optimizer is
        also built against the trainable variables.
        """
        self.compile(**deserialize(config))
        if hasattr(self, "optimizer") and self.built:
            self.optimizer.build(self.trainable_variables)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Deserialize and instantiate an approximator from configuration."""
        return cls(**deserialize(config, custom_objects=custom_objects))

    def get_compile_config(self):
        """Serialize the compile configuration, including all network metrics.

        Collects metrics from inference_network and summary_network (if present),
        serializes them, and merges them into the parent class config.

        Returns
        -------
        dict
            Configuration dictionary with serialized metrics.
        """
        base_config = super().get_compile_config() or {}

        config = {}
        if hasattr(self, "inference_network") and self.inference_network is not None:
            config["inference_metrics"] = self.inference_network._metrics

        if hasattr(self, "summary_network") and self.summary_network is not None:
            config["summary_metrics"] = self.summary_network._metrics

        return base_config | serialize(config)

    def summarize(self, conditions: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
        """Compute the learned summary statistics of the given summary variables.

        The `conditions` dictionary is preprocessed with the `adapter`, standardized,
        and passed through the summary network via :meth:`_prepare_conditions`.

        Parameters
        ----------
        conditions : Mapping[str, np.ndarray]
            Dictionary of simulated or real quantities as NumPy arrays.
        **kwargs : dict
            Additional keyword arguments forwarded to :meth:`_prepare_conditions`.
            ``stage`` selects the standardization stage; all others go to the adapter.

        Returns
        -------
        summaries : np.ndarray
            The learned summary statistics.

        Raises
        ------
        ValueError
            If the approximator has no summary network or no adapter, or if the
            adapted data produces no summary outputs.
        """
        if not hasattr(self, "summary_network") or self.summary_network is None:
            raise ValueError("Summary network is not available. This approximator does not support summarization.")

        if not hasattr(self, "adapter"):
            raise ValueError("Adapter is not available.")

        # Forward kwargs so that the adapter options documented above take effect.
        _, _, summary_outputs = self._prepare_conditions(conditions, **kwargs)

        # Fail with a clear message instead of converting ``None`` to an array.
        if summary_outputs is None:
            raise ValueError("Summary variables are required for summarization.")

        return keras.ops.convert_to_numpy(summary_outputs)

    def _batch_size_from_data(self, data: Mapping[str, Any]) -> int:
        """Return the batch size from a training data dict.

        Relies on the ``"inference_variables"`` key, which is present in every
        approximator's training data; the batch size is its leading dimension.
        """
        return keras.ops.shape(data["inference_variables"])[0]