"""Model comparison approximator (``bayesflow.approximators.model_comparison_approximator``)."""

from collections.abc import Mapping, Sequence

import numpy as np

import keras

from bayesflow.adapters import Adapter
from bayesflow.datasets import OnlineDataset
from bayesflow.networks import ScoringRuleNetwork, SummaryNetwork
from bayesflow.scoring_rules import CrossEntropyScore
from bayesflow.simulators import ModelComparisonSimulator, Simulator
from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs, logging
from bayesflow.utils.serialization import serialize, serializable

from .approximator import Approximator
from .helpers import ConditionBuilder

from ..networks.helpers import Standardization


@serializable("bayesflow.approximators")
class ModelComparisonApproximator(Approximator):
    """
    Approximator for comparing models (simulators) by learning their discrete posterior
    model probabilities with a classifier.

    The user-supplied ``classifier_network`` is wrapped in a
    :class:`~bayesflow.networks.ScoringRuleNetwork` equipped with a
    :class:`~bayesflow.scoring_rules.CrossEntropyScore`. That network maps the
    summary/condition inputs to class logits and is trained with categorical cross-entropy.

    Parameters
    ----------
    num_models : int
        How many models (simulators) the approximator compares.
    classifier_network : keras.Layer
        Backbone (e.g., an MLP) used for model classification. Internally wrapped in a
        :class:`~bayesflow.networks.ScoringRuleNetwork` with a
        :class:`~bayesflow.scoring_rules.CrossEntropyScore`. Its input is the
        concatenation of ``inference_conditions`` and the (optional) output of the
        ``summary_network``.
    adapter : bf.adapters.Adapter, optional
        Adapter for data pre-processing. If ``None`` (default), a default-constructed
        ``Adapter()`` is used, described as an identity adapter that makes a shallow copy
        and passes data through unchanged.
    summary_network : bf.networks.SummaryNetwork, optional
        Network used for data summarization (default is None). It receives
        ``summary_variables`` as input.
    standardize : str | Sequence[str] | None
        Variables to standardize before they are passed to the networks. Any subset of
        ["inference_conditions", "summary_variables"]. Default is None, because model
        indices are one-hot encoded and should not be standardized.
    """

    def __init__(
        self,
        *,
        num_models: int,
        classifier_network: keras.Layer,
        adapter: Adapter = None,
        summary_network: SummaryNetwork = None,
        standardize: str | Sequence[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_models = num_models

        if adapter is None:
            # Fall back to a default-constructed adapter when none is supplied.
            adapter = Adapter()
        self.adapter = adapter

        # The "cross_entropy" key is the name under which `predict` later reads the logits.
        classifier = ScoringRuleNetwork(
            scoring_rules={"cross_entropy": CrossEntropyScore()},
            subnet=classifier_network,
        )
        # Network attributes are assigned in the same order as before, so layer tracking is unchanged.
        self.inference_network = classifier
        self.summary_network = summary_network
        self.condition_builder = ConditionBuilder()
        self.standardizer = Standardization(standardize)
[docs] def build_dataset( self, *, dataset: keras.utils.PyDataset = None, simulator: ModelComparisonSimulator = None, simulators: Sequence[Simulator] = None, **kwargs, ) -> OnlineDataset: if sum(arg is not None for arg in (dataset, simulator, simulators)) != 1: raise ValueError("Exactly one of dataset, simulator, or simulators must be provided.") if simulators is not None: simulator = ModelComparisonSimulator(simulators) return super().build_dataset(dataset=dataset, simulator=simulator, **kwargs)
[docs] def compute_metrics( self, inference_variables: Tensor, inference_conditions: Tensor = None, summary_variables: Tensor = None, sample_weight: Tensor = None, summary_attention_mask: Tensor = None, summary_mask: Tensor = None, inference_attention_mask: Tensor = None, inference_mask: Tensor = None, stage: str = "training", ) -> dict[str, Tensor]: """ Computes loss and tracks metrics for the classifier and summary networks. This method coordinates summary metric computation (if present), combines summary outputs with inference conditions, computes classifier logits and cross-entropy loss via the :class:`~bayesflow.scoring_rules.CrossEntropyScore`, and aggregates all tracked metrics into a single dictionary. Parameters ---------- inference_variables : Tensor One-hot encoded model indices (targets for classification). inference_conditions : Tensor, optional Conditioning variables for the classifier network (default is None). May be combined with summary network outputs if present. summary_variables : Tensor, optional Input tensor(s) for the summary network (default is None). Required if a summary network is present. sample_weight : Tensor, optional Weighting tensor for metric computation (default is None). summary_attention_mask : Tensor, optional Attention mask forwarded to the summary network (default is None). summary_mask : Tensor, optional Padding / key mask forwarded to the summary network (default is None). inference_attention_mask : Tensor, optional Accepted for API consistency but unused (model comparison uses an MLP classifier). inference_mask : Tensor, optional Padding / key mask forwarded to the classifier network (default is None). stage : str, optional Current training stage (e.g., "training", "validation", "inference"). Controls certain metric computations (default is "training"). 
Returns ------- metrics : dict[str, Tensor] Dictionary containing the total loss under the key "loss", as well as all tracked metrics for the classifier and summary networks. Each metric key is prefixed to indicate its source. """ masks = dict( summary_attention_mask=summary_attention_mask, summary_mask=summary_mask, inference_attention_mask=inference_attention_mask, inference_mask=inference_mask, ) summary_kwargs = self._collect_mask_kwargs(self._SUMMARY_MASK_KEYS, masks) inference_kwargs = self._collect_mask_kwargs(self._INFERENCE_MASK_KEYS, masks) resolved_conditions, summary_metrics = self._standardize_and_resolve( inference_conditions, summary_variables, stage=stage, purpose="metrics", **summary_kwargs ) inference_metrics = self.inference_network.compute_metrics( inference_variables, conditions=resolved_conditions, sample_weight=sample_weight, stage=stage, **inference_kwargs, ) if "loss" in summary_metrics: loss = inference_metrics["loss"] + summary_metrics["loss"] else: loss = inference_metrics.pop("loss") inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()} summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()} metrics = {"loss": loss} | inference_metrics | summary_metrics return metrics
[docs] def fit( self, *, adapter: Adapter = "auto", dataset: keras.utils.PyDataset = None, simulator: ModelComparisonSimulator = None, simulators: Sequence[Simulator] = None, **kwargs, ): """ Trains the approximator on the provided dataset or on-demand generated from the given (multi-model) simulator. If `dataset` is not provided, a dataset is built from the `simulator`. If `simulator` is not provided, it will be built from a list of `simulators`. If the model has not been built, it will be built using a batch from the dataset. Parameters ---------- adapter : Adapter or 'auto', optional The data adapter that will make the simulated / real outputs neural-network friendly. dataset : keras.utils.PyDataset, optional A dataset containing simulations for training. If provided, `simulator` must be None. simulator : ModelComparisonSimulator, optional A simulator used to generate a dataset. If provided, `dataset` must be None. simulators: Sequence[Simulator], optional A list of simulators (one simulator per model). If provided, `dataset` must be None. **kwargs Additional keyword arguments passed to `keras.Model.fit()`, as described in: https://github.com/keras-team/keras/blob/v3.13.2/keras/src/backend/tensorflow/trainer.py#L314 Returns ------- keras.callbacks.History A history object containing the training loss and metrics values. Raises ------ ValueError If both `dataset` and `simulator` or `simulators` are provided or neither is provided. """ if dataset is not None: if simulator is not None or simulators is not None: raise ValueError( "Received conflicting arguments. Please provide either a dataset or a simulator, but not both." 
) return super().fit(dataset=dataset, **kwargs) if adapter == "auto": logging.info("Building automatic data adapter.") adapter = self.build_adapter(num_models=self.num_models, **filter_kwargs(kwargs, self.build_adapter)) if simulator is not None: return super().fit(simulator=simulator, adapter=adapter, **kwargs) logging.info(f"Building model comparison simulator from {len(simulators)} simulators.") simulator = ModelComparisonSimulator(simulators=simulators) return super().fit(simulator=simulator, adapter=adapter, **kwargs)
[docs] def get_config(self): base_config = super().get_config() config = { "num_models": self.num_models, "adapter": self.adapter, "classifier_network": self.inference_network.subnet, "summary_network": self.summary_network, "standardize": self.standardizer.standardize, } return base_config | serialize(config)
[docs] def predict( self, *, conditions: Mapping[str, np.ndarray], probs: bool = True, **kwargs, ) -> np.ndarray: """ Predicts posterior model probabilities given input conditions. The `conditions` dictionary is preprocessed using the `adapter`. The output is converted to NumPy array after inference. Parameters ---------- conditions : Mapping[str, np.ndarray] Dictionary of conditioning variables as NumPy arrays. probs: bool, optional A flag indicating whether model probabilities (True) or logits (False) are returned. Default is True. **kwargs : dict Additional keyword arguments for the adapter and classifier. Returns ------- outputs: np.ndarray Predicted posterior model probabilities given `conditions`. """ resolved_conditions, adapted, _ = self._prepare_conditions(conditions, **kwargs) inference_kwargs = self._collect_mask_kwargs(self._INFERENCE_MASK_KEYS, adapted) output = self.inference_network(xz=None, conditions=resolved_conditions, **inference_kwargs) logits = output["cross_entropy"]["logits"] if probs: logits = keras.ops.softmax(logits) return keras.ops.convert_to_numpy(logits)