Source code for bayesflow.simulators.model_comparison_simulator

from collections.abc import Callable, Sequence
import numpy as np

from bayesflow.types import Shape
from bayesflow.utils import tree_concatenate
from bayesflow.utils.decorators import allow_batch_size

from bayesflow.utils import numpy_utils as npu
from bayesflow.utils import logging

from types import FunctionType
from typing import Literal

from .simulator import Simulator
from .lambda_simulator import LambdaSimulator


[docs] class ModelComparisonSimulator(Simulator): """A multimodel simulator useful for model comparison tasks. This class wraps multiple :class:`~bayesflow.simulators.Simulator` instances and produces batched outputs that include a one-hot ``inference_variables`` vector indicating which simulator generated each sample. It supports two sampling modes: * **mixed batches** (default) - each element in the batch may originate from a different simulator; the number of draws per model is drawn from a multinomial with probabilities given by ``softmax(logits)``. * **single-model batches** - the entire batch is drawn from a single simulator chosen according to the model probabilities. A shared simulator may optionally provide additional data that is passed to every model's sampling call. Key-conflict policies control how incompatible outputs across simulators are handled (drop, fill, or error). Parameters ---------- simulators : Sequence[Simulator] A sequence of simulator instances, each representing a different model. p : Sequence[float], optional A sequence of probabilities associated with each simulator. Must sum to 1. Mutually exclusive with ``logits``. logits : Sequence[float], optional A sequence of logits corresponding to model probabilities. Mutually exclusive with ``p``. If neither ``p`` nor ``logits`` is provided, uniform logits are assumed. use_mixed_batches : bool, optional Whether to draw samples in a batch from different models. * If ``True`` (default), each sample in a batch may come from a different model. * If ``False``, the entire batch is drawn from a single model selected according to the model probabilities. key_conflicts : {"drop", "fill", "error"}, optional Policy for handling keys missing from some model outputs when mixing batches. * ``"drop"`` (default): drop conflicting keys from the batch output. * ``"fill"``: fill missing keys with ``fill_value``. * ``"error"``: raise an error on conflicts. fill_value : float, optional If ``key_conflicts=="fill"``, missing keys are filled with this value. shared_simulator : Simulator or Callable, optional A shared simulator providing outputs to every model. If a callable is passed it is wrapped in a :class:`~bayesflow.simulators.LambdaSimulator` with batching enabled. """ def __init__( self, simulators: Sequence[Simulator], p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, key_conflicts: Literal["drop", "fill", "error"] = "drop", fill_value: float = np.nan, shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None, ): # constructor body unchanged self.simulators = simulators if isinstance(shared_simulator, FunctionType): shared_simulator = LambdaSimulator(shared_simulator, is_batched=True) self.shared_simulator = shared_simulator match logits, p: case (None, None): logits = [0.0] * len(simulators) case (None, logits): logits = logits case (p, None): p = np.array(p) if not np.isclose(np.sum(p), 1.0): raise ValueError("Probabilities must sum to 1.") logits = np.log(p) - np.log(1 - p) case _: raise ValueError("Received conflicting arguments. At most one of `p` or `logits` must be provided.") if len(logits) != len(simulators): raise ValueError(f"Length of logits ({len(logits)}) must match number of simulators ({len(simulators)}).") self.logits = logits self.use_mixed_batches = use_mixed_batches self.key_conflicts = key_conflicts self.fill_value = fill_value self._key_conflicts_warning = True
[docs] @allow_batch_size def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: """ Sample from the model comparison simulator. Parameters ---------- batch_shape : Shape The shape of the batch to sample. Typically, a tuple indicating the number of samples, but the user can also supply an int. **kwargs Additional keyword arguments passed to each simulator. These may include outputs from the shared simulator. Returns ------- data : dict of str to np.ndarray A dictionary containing the sampled outputs. Includes: - outputs from the selected simulator(s) - optionally, outputs from the shared simulator - "inference_variables": an array indicating the model origin of each sample """ data = {} if self.shared_simulator: data |= self.shared_simulator.sample(batch_shape, **kwargs) softmax_logits = npu.softmax(self.logits) num_models = len(self.simulators) # generate data randomly from each model (slower) if self.use_mixed_batches: model_counts = np.random.multinomial(n=batch_shape[0], pvals=softmax_logits) sims = [ simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0 ] sims = self._handle_key_conflicts(sims, model_counts) sims = tree_concatenate(sims, numpy=True) data |= sims model_indices = np.repeat(np.eye(num_models, dtype="int32"), model_counts, axis=0) # draw one model index for the whole batch (faster) else: model_index = np.random.choice(num_models, p=softmax_logits) data = self.simulators[model_index].sample(batch_shape, **(kwargs | data)) model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models) return data | {"model_indices": model_indices}
def _handle_key_conflicts(self, sims, batch_sizes): batch_sizes = [b for b in batch_sizes if b > 0] keys, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims) # all sims have the same keys if all_keys == common_keys: return sims if self.key_conflicts == "drop": sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims] return sims elif self.key_conflicts == "fill": combined_sims = {} for sim in sims: combined_sims = combined_sims | sim for i, sim in enumerate(sims): for missing_key in missing_keys[i]: shape = combined_sims[missing_key].shape shape = list(shape) shape[0] = batch_sizes[i] sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value) return sims elif self.key_conflicts == "error": raise ValueError( "Different simulators provide outputs with different keys, cannot combine them into one batch." ) def _determine_key_conflicts(self, sims): keys = [set(sim.keys()) for sim in sims] all_keys = set.union(*keys) common_keys = set.intersection(*keys) missing_keys = [all_keys - k for k in keys] if all_keys == common_keys: return keys, all_keys, common_keys, missing_keys if self._key_conflicts_warning: # issue warning only once self._key_conflicts_warning = False if self.key_conflicts == "drop": logging.info( f"Incompatible simulator output. " f"The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}." ) elif self.key_conflicts == "fill": logging.info( f"Incompatible simulator output. " f"Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing " f"with value {self.fill_value}." ) return keys, all_keys, common_keys, missing_keys