ModelComparisonSimulator

class bayesflow.simulators.ModelComparisonSimulator(simulators: Sequence[Simulator], p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None)

Bases: Simulator

Wraps a sequence of simulators for use with a model comparison approximator.

Initialize a multi-model simulator that can generate data for mixture or model comparison problems.

Parameters:
simulators : Sequence[Simulator]

A sequence of simulator instances, each representing a different model.

p : Sequence[float], optional

A sequence of probabilities associated with each simulator. Must sum to 1. Mutually exclusive with logits.

logits : Sequence[float], optional

A sequence of logits corresponding to model probabilities. Mutually exclusive with p. If neither p nor logits is provided, uniform logits are used, so every model is equally likely.

use_mixed_batches : bool, optional

If True, samples in a batch are drawn from different models. If False, the entire batch is drawn from a single model chosen according to the model probabilities. Default is True.

shared_simulator : Simulator or Callable, optional

A shared simulator whose outputs are passed to all model simulators. If a function is provided, it is wrapped in a LambdaSimulator with batching enabled.
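The following NumPy sketch shows how `p`, `logits`, and `use_mixed_batches` work together. It is not BayesFlow's implementation. The helper names `resolve_logits` and `draw_model_indices` are hypothetical. The sketch also assumes two things: probabilities are converted to logits via `log(p)`, and model indices are drawn from the softmax of the logits.

```python
import numpy as np


def resolve_logits(n_models, p=None, logits=None):
    """Hypothetical helper: return the logits used for model selection."""
    if p is not None and logits is not None:
        raise ValueError("p and logits are mutually exclusive.")
    if p is not None:
        p = np.asarray(p, dtype=float)
        if not np.isclose(p.sum(), 1.0):
            raise ValueError("p must sum to 1.")
        # Assumption: log-probabilities serve as logits.
        return np.log(p)
    if logits is not None:
        return np.asarray(logits, dtype=float)
    # Neither given: uniform (all-zero) logits imply equal model probabilities.
    return np.zeros(n_models)


def draw_model_indices(batch_size, logits, use_mixed_batches, rng):
    """Hypothetical helper: draw one model index per sample in the batch."""
    # Softmax converts logits into probabilities (shifted for numerical stability).
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if use_mixed_batches:
        # Every sample gets its own independently drawn model index.
        return rng.choice(len(probs), size=batch_size, p=probs)
    # One model is chosen and shared by the whole batch.
    return np.full(batch_size, rng.choice(len(probs), p=probs))


rng = np.random.default_rng(0)
logits = resolve_logits(3, p=[0.5, 0.3, 0.2])
mixed = draw_model_indices(8, logits, use_mixed_batches=True, rng=rng)
single = draw_model_indices(8, logits, use_mixed_batches=False, rng=rng)
```

With mixed batches, `mixed` usually contains several different indices. `single` repeats one index eight times, which corresponds to `use_mixed_batches=False`.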

sample(batch_shape: tuple[int, ...], **kwargs) -> dict[str, ndarray]

Sample from the model comparison simulator.

Parameters:
batch_shape : Shape

The shape of the batch to sample, typically a tuple containing the number of samples. An int is also accepted.

**kwargs

Additional keyword arguments passed to each simulator. These may include outputs from the shared simulator.

Returns:
data : dict of str to np.ndarray
A dictionary containing the sampled outputs. Includes:
  • outputs from the selected simulator(s)
  • outputs from the shared simulator, if one is provided
  • "model_indices": a one-hot encoded array indicating which model produced each sample
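The sampling procedure and its return value can be illustrated with a minimal NumPy sketch. This is not the library's code, and it rests on several assumptions:

- The model simulators and the shared simulator are hypothetical plain functions that take a batch size and return a dict of arrays.
- All model simulators produce the same keys.
- The shared simulator's outputs are stored in the result and also forwarded to every model simulator as keyword arguments.
- Only the mixed-batch case is covered.

```python
import numpy as np


def sample_model_comparison(batch_size, simulators, probs, shared_simulator=None, rng=None, **kwargs):
    """Hypothetical sketch of mixed-batch sampling with one-hot model indices."""
    rng = np.random.default_rng() if rng is None else rng
    data = {}
    if shared_simulator is not None:
        # Shared outputs are kept in the result and forwarded to every model simulator.
        data = shared_simulator(batch_size)
        kwargs = {**kwargs, **data}

    # Split the batch among the models according to the model probabilities.
    counts = rng.multinomial(batch_size, probs)
    outputs = [sim(int(n), **kwargs) for sim, n in zip(simulators, counts) if n > 0]

    # Concatenate each output key along the batch axis (samples end up grouped by model).
    for key in outputs[0]:
        data[key] = np.concatenate([out[key] for out in outputs], axis=0)

    # Row i is the one-hot code of the model that produced sample i.
    data["model_indices"] = np.repeat(np.eye(len(simulators), dtype="int32"), counts, axis=0)
    return data


rng = np.random.default_rng(1)


def shared(batch_size):
    # Hypothetical shared setting: one observation count for the whole batch.
    return {"n_obs": 4}


def model_a(n, n_obs):
    return {"x": rng.normal(0.0, 1.0, size=(n, n_obs))}


def model_b(n, n_obs):
    return {"x": rng.standard_t(df=3, size=(n, n_obs))}


batch = sample_model_comparison(16, [model_a, model_b], [0.5, 0.5], shared_simulator=shared, rng=rng)
```

Drawing per-model counts with a multinomial lets each simulator run once on its share of the batch, rather than once per sample. The samples come out grouped by model, so shuffle them if their order matters downstream.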

rejection_sample(batch_shape: tuple[int, ...], predicate: Callable[[dict[str, ndarray]], ndarray], *, axis: int = 0, sample_size: int = None, **kwargs) -> dict[str, ndarray]
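Only the signature of rejection_sample is listed here. The sketch below shows generic rejection sampling over a dict of arrays, and it is not BayesFlow's implementation. It assumes that:

- `predicate` returns a boolean mask over the batch axis.
- `sample_size` defaults to the batch size.
- The batch axis is fixed to 0.

The function name and the `sample_fn` and `max_rounds` parameters are hypothetical.

```python
import numpy as np


def rejection_sample(sample_fn, batch_size, predicate, sample_size=None, max_rounds=1000):
    """Hypothetical sketch: draw batches until enough samples satisfy `predicate`."""
    sample_size = batch_size if sample_size is None else sample_size
    accepted = {}
    n_accepted = 0
    for _ in range(max_rounds):
        data = sample_fn(batch_size)
        # Boolean mask over the batch axis (fixed to axis 0 in this sketch).
        mask = np.asarray(predicate(data), dtype=bool)
        for key, value in data.items():
            accepted.setdefault(key, []).append(value[mask])
        n_accepted += int(mask.sum())
        if n_accepted >= sample_size:
            break
    else:
        raise RuntimeError("Too few samples accepted within max_rounds.")
    # Concatenate the accepted chunks and trim to exactly sample_size samples.
    return {key: np.concatenate(chunks, axis=0)[:sample_size] for key, chunks in accepted.items()}


rng = np.random.default_rng(2)


def draw(n):
    return {"theta": rng.normal(size=(n, 1))}


kept = rejection_sample(draw, 32, lambda d: d["theta"][:, 0] > 0.0, sample_size=50)
```

Here only draws with a positive `theta` are kept, and batches of 32 are drawn until 50 such samples have been collected.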