ModelComparisonSimulator
- class bayesflow.simulators.ModelComparisonSimulator(simulators: Sequence[Simulator], p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, key_conflicts: Literal['drop', 'fill', 'error'] = 'drop', fill_value: float = nan, shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None)
Bases: Simulator
A multimodel simulator useful for model comparison tasks.
This class wraps multiple Simulator instances and produces batched outputs that include a one-hot inference_variables vector indicating which simulator generated each sample. It supports two sampling modes:
- mixed batches (default): each element in the batch may originate from a different simulator; the number of draws per model is drawn from a multinomial distribution with probabilities given by softmax(logits).
- single-model batches: the entire batch is drawn from a single simulator chosen according to the model probabilities.
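The two sampling modes can be illustrated with a small NumPy sketch. This is not BayesFlow's implementation: the names softmax and draw_model_assignment are hypothetical, the sketch only decides which model generates each sample (the actual simulator calls are left out), and its rows come out grouped by model rather than in any particular order the library may use.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Shift by the maximum for numerical stability before exponentiating.
    z = np.exp(logits - np.max(logits))
    return z / z.sum()


def draw_model_assignment(
    batch_size: int,
    logits,
    use_mixed_batches: bool,
    rng: np.random.Generator,
) -> np.ndarray:
    """Return a (batch_size, num_models) one-hot matrix of model indicators.

    Hypothetical sketch: only decides which model generates each sample.
    """
    logits = np.asarray(logits, dtype=float)
    num_models = logits.shape[0]
    probs = softmax(logits)

    if use_mixed_batches:
        # Mixed batches: draws per model ~ Multinomial(batch_size, softmax(logits)).
        counts = rng.multinomial(batch_size, probs)
    else:
        # Single-model batch: one model, chosen with probability probs, fills the batch.
        model_idx = rng.choice(num_models, p=probs)
        counts = np.zeros(num_models, dtype=int)
        counts[model_idx] = batch_size

    # Expand counts into one model index per sample (grouped by model),
    # then turn each index into a one-hot row.
    model_indices = np.repeat(np.arange(num_models), counts)
    return np.eye(num_models)[model_indices]


rng = np.random.default_rng(seed=0)
one_hot = draw_model_assignment(8, [0.0, 0.0, 0.0], use_mixed_batches=True, rng=rng)
print(one_hot.sum(axis=0))  # draws per model; the entries sum to 8
```

In the mixed mode the per-model counts vary from batch to batch, while in the single-model mode exactly one column of the indicator matrix is nonzero.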
A shared simulator may optionally provide additional data that is passed to every model’s sampling call. Key-conflict policies (drop, fill, or error) control how keys that appear in some simulators' outputs but not in others are handled.
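The three key-conflict policies can be sketched as a merge of per-model output dictionaries, assuming each model returns a dict of arrays whose first axis is the batch axis. merge_outputs is a hypothetical helper, not part of the BayesFlow API, and the library may differ in details such as dtype handling or error messages.

```python
import numpy as np


def merge_outputs(
    outputs: list[dict[str, np.ndarray]],
    key_conflicts: str = "drop",
    fill_value: float = np.nan,
) -> dict[str, np.ndarray]:
    """Concatenate per-model output dicts along axis 0, resolving key conflicts.

    Hypothetical helper; assumes every dict has at least one key and that
    axis 0 of each array is the batch axis.
    """
    if key_conflicts not in ("drop", "fill", "error"):
        raise ValueError(f"Unknown key_conflicts policy: {key_conflicts!r}")

    key_sets = [set(out) for out in outputs]
    all_keys = set.union(*key_sets)
    common_keys = set.intersection(*key_sets)
    conflicting = all_keys - common_keys

    if conflicting and key_conflicts == "error":
        raise ValueError(f"Keys missing from some outputs: {sorted(conflicting)}")

    # "drop" keeps only keys every model produced; "fill" keeps all of them.
    keys = common_keys if key_conflicts == "drop" else all_keys
    merged = {}
    for key in sorted(keys):
        # Any model that produced this key tells us the per-sample shape.
        template = next(np.asarray(out[key]) for out in outputs if key in out)
        parts = []
        for out in outputs:
            if key in out:
                parts.append(np.asarray(out[key]))
            else:
                # Batch size of this model's share, read from one of its other keys.
                n = len(next(iter(out.values())))
                parts.append(np.full((n, *template.shape[1:]), fill_value))
        merged[key] = np.concatenate(parts, axis=0)
    return merged
```

With "fill", a model that never produced a key contributes a block of fill_value entries of matching shape, so every key still has one row per sample.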
- Parameters:
- simulators : Sequence[Simulator]
A sequence of simulator instances, each representing a different model.
- p : Sequence[float], optional
A sequence of probabilities associated with each simulator. Must sum to 1. Mutually exclusive with logits.
- logits : Sequence[float], optional
A sequence of logits corresponding to model probabilities. Mutually exclusive with p. If neither p nor logits is provided, uniform logits are assumed.
- use_mixed_batches : bool, optional
Whether to draw samples in a batch from different models.
If True (default), each sample in a batch may come from a different model.
If False, the entire batch is drawn from a single model selected according to the model probabilities.
- key_conflicts : {"drop", "fill", "error"}, optional
Policy for handling keys missing from some model outputs when mixing batches:
"drop" (default): drop conflicting keys from the batch output.
"fill": fill missing keys with fill_value.
"error": raise an error on conflicts.
- fill_value : float, optional
If key_conflicts == "fill", missing keys are filled with this value.
- shared_simulator : Simulator or Callable, optional
A shared simulator providing outputs to every model. If a callable is passed, it is wrapped in a LambdaSimulator with batching enabled.
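One plausible way to reduce p, logits, and the uniform default to a single logits vector is sketched below. resolve_logits is a hypothetical helper; using log(p) as logits is an assumption, justified by the identity softmax(log p) = p whenever p sums to 1, and is not taken from the BayesFlow source.

```python
import numpy as np


def resolve_logits(num_models: int, p=None, logits=None) -> np.ndarray:
    """Hypothetical helper turning (p, logits) into one logits vector."""
    if p is not None and logits is not None:
        raise ValueError("p and logits are mutually exclusive.")
    if p is not None:
        p = np.asarray(p, dtype=float)
        if not np.isclose(p.sum(), 1.0):
            raise ValueError("p must sum to 1.")
        # softmax(log(p)) == p, so log-probabilities are valid logits.
        out = np.log(p)
    elif logits is not None:
        out = np.asarray(logits, dtype=float)
    else:
        # Equal logits give equal probabilities under softmax.
        out = np.zeros(num_models)
    if out.shape != (num_models,):
        raise ValueError("Need exactly one probability or logit per simulator.")
    return out


logits = resolve_logits(3, p=[0.2, 0.3, 0.5])
print(np.exp(logits) / np.exp(logits).sum())  # recovers p up to rounding
```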
- sample(batch_shape: tuple[int, ...], **kwargs) → dict[str, ndarray]
Sample from the model comparison simulator.
- Parameters:
- batch_shape : Shape
The shape of the batch to sample. Typically a tuple indicating the number of samples, although a plain int is also accepted.
- **kwargs
Additional keyword arguments passed to each simulator. These may include outputs from the shared simulator.
- Returns:
- data : dict of str to np.ndarray
A dictionary containing the sampled outputs. It includes:
outputs from the selected simulator(s);
optionally, outputs from the shared simulator;
"inference_variables": a one-hot array indicating which model generated each sample.
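Because inference_variables is one-hot, the generating model of each sample can be recovered with an argmax over the last axis. The sketch below assumes a plain NumPy array of shape (batch_size, num_models) taken from the returned dictionary; decode_models is a hypothetical helper.

```python
import numpy as np


def decode_models(inference_variables: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (model index per sample, number of samples per model).

    Hypothetical helper; assumes shape (batch_size, num_models), one-hot rows.
    """
    indices = np.argmax(inference_variables, axis=-1)
    counts = np.bincount(indices, minlength=inference_variables.shape[-1])
    return indices, counts


# Example: four samples from three models.
iv = np.array([[1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]], dtype=float)
indices, counts = decode_models(iv)
print(indices)  # [0 2 2 1]
print(counts)   # [1 1 2]
```

Passing minlength keeps a zero entry for any model that happened to receive no draws in a mixed batch.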
- rejection_sample(batch_shape: tuple[int, ...], predicate: Callable[[dict[str, ndarray]], ndarray], *, axis: int = 0, sample_size: int = None, **kwargs) → dict[str, ndarray]
- sample_batched(batch_shape: tuple[int, ...], *, sample_size: int, **kwargs)
Sample the desired number of simulations in smaller batches.
Limited resources, especially memory, can make it necessary to run simulations in smaller batches. The number of samples per simulated batch is specified by sample_size.
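The chunking idea can be sketched as a loop that calls a sampler with at most sample_size samples at a time and concatenates the results along the first axis. sample_in_chunks and fake_sample are hypothetical; this sketch trims the last chunk so the total matches exactly, and whether sample_batched instead rounds up to complete batches is not stated in this reference.

```python
from typing import Callable

import numpy as np


def sample_in_chunks(
    sample_fn: Callable[[tuple[int, ...]], dict[str, np.ndarray]],
    num_samples: int,
    sample_size: int,
) -> dict[str, np.ndarray]:
    """Call sample_fn with at most sample_size samples at a time and stitch results.

    Hypothetical helper; assumes axis 0 of every output array is the batch axis.
    """
    if num_samples <= 0 or sample_size <= 0:
        raise ValueError("num_samples and sample_size must be positive.")
    chunks = []
    remaining = num_samples
    while remaining > 0:
        n = min(sample_size, remaining)
        chunks.append(sample_fn((n,)))
        remaining -= n
    # Concatenate each key across chunks along the batch axis.
    return {key: np.concatenate([c[key] for c in chunks], axis=0) for key in chunks[0]}


def fake_sample(batch_shape):
    # Stand-in for a simulator's sample method.
    return {"x": np.random.default_rng().normal(size=(*batch_shape, 2))}


data = sample_in_chunks(fake_sample, num_samples=10, sample_size=4)
print(data["x"].shape)  # (10, 2)
```

Peak memory is then governed by sample_size rather than by the total number of simulations, at the cost of a final concatenation.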