Source code for bayesflow.simulators.make_simulator

import inspect
from collections.abc import Callable, Mapping, Sequence
from functools import singledispatch
from types import FunctionType

import numpy as np

from .simulator import Simulator


@singledispatch
def make_simulator(arg, *_, **__):
    """
    Build a simulator from ``arg`` by dispatching on its type.

    Handlers registered in this module accept an existing ``Simulator``
    (returned unchanged), a plain Python function (wrapped in a
    ``LambdaSimulator``), and a sequence or mapping of such objects (chained
    into a ``SequentialSimulator``). In a chain, each simulator returns a
    dictionary of outputs, and those outputs are handed on to later simulators
    that accept keyword arguments named after the earlier output keys.

    This fallback runs only when no registered handler matches ``type(arg)``.

    Raises
    ------
    TypeError
        Always, naming the unsupported ``arg``.
    """
    message = f"Cannot infer simulator from {arg!r}."
    raise TypeError(message)
@make_simulator.register
def _(simulator: Simulator):
    """Return an existing ``Simulator`` unchanged."""
    return simulator


@make_simulator.register(FunctionType)
def _(fn: Callable, **kwargs):
    """
    Wrap a plain Python function in a ``LambdaSimulator``.

    Extra keyword arguments are forwarded to the ``LambdaSimulator`` constructor.
    """
    from bayesflow.simulators import LambdaSimulator

    return LambdaSimulator(fn, **kwargs)


def _make_meta_simulator(meta_fn):
    """Wrap ``meta_fn`` in a batched ``LambdaSimulator`` meant to run first in a chain."""
    from bayesflow.simulators import LambdaSimulator

    if inspect.signature(meta_fn).parameters:
        sample_fn = meta_fn
    else:
        # A zero-argument meta function is adapted to ignore whatever arguments
        # the wrapping LambdaSimulator supplies when it calls it.
        def sample_fn(*_, **__):
            return meta_fn()

    return LambdaSimulator(sample_fn, is_batched=True)


@make_simulator.register(Sequence)
def _(
    objs: Sequence[FunctionType],
    obj_kwargs: Mapping[str, dict[str, any]] = None,
    meta_fn: Callable[[], dict[str, np.ndarray]] = None,
    **kwargs,
):
    """
    Chain a sequence of objects into a ``SequentialSimulator``.

    Parameters
    ----------
    objs : Sequence
        Objects converted one by one with ``make_simulator``, in execution order.
    obj_kwargs : Mapping[str, dict], optional
        Per-object keyword arguments for ``make_simulator``, keyed by the
        object's ``__name__``. Objects without a ``__name__`` get no extra
        arguments.
    meta_fn : Callable, optional
        If given, wrapped in a batched ``LambdaSimulator`` and placed first in
        the chain. It may take no arguments.
    **kwargs
        Forwarded to the ``SequentialSimulator`` constructor.

    Raises
    ------
    TypeError
        If ``objs`` is a string.
    ValueError
        If ``obj_kwargs`` contains a key that matches no object's ``__name__``.
    """
    from bayesflow.simulators import SequentialSimulator

    if isinstance(objs, str):
        # str is a collections.abc.Sequence; without this guard every character
        # would be dispatched back into this handler and recurse indefinitely.
        raise TypeError(f"Cannot infer simulator from {objs!r}.")

    if obj_kwargs is None:
        obj_kwargs = {}

    # sanity check: every key in obj_kwargs must name one of the objects
    detected_names = {obj.__name__ for obj in objs if hasattr(obj, "__name__")}
    given_names = set(obj_kwargs.keys())
    if not given_names.issubset(detected_names):
        unmatched_names = given_names - detected_names
        msg = (
            f"Found at least one key in obj_kwargs that does not have a match in the object sequence:\n"
            f"{list(unmatched_names)!r}"
        )
        if not all(hasattr(obj, "__name__") for obj in objs):
            msg += (
                "\nThis can happen if the matching objects in the sequence do not have a __name__ attribute. "
                "Pass a dictionary instead to specify names explicitly."
            )
        raise ValueError(msg)

    simulators = []
    for obj in objs:
        # Use a separate name so the obj_kwargs mapping stays intact for every object.
        if hasattr(obj, "__name__"):
            this_obj_kwargs = obj_kwargs.get(obj.__name__, {})
        else:
            this_obj_kwargs = {}
        simulators.append(make_simulator(obj, **this_obj_kwargs))

    if meta_fn is not None:
        simulators = [_make_meta_simulator(meta_fn), *simulators]

    return SequentialSimulator(simulators, **kwargs)


@make_simulator.register(Mapping)
def _(
    objs: Mapping[str, FunctionType],
    obj_kwargs: Mapping[str, dict[str, any]] = None,
    meta_fn: Callable[[], dict[str, np.ndarray]] = None,
    **kwargs,
):
    """
    Chain a name -> object mapping into a ``SequentialSimulator``.

    Parameters
    ----------
    objs : Mapping[str, object]
        Objects converted one by one with ``make_simulator``, in the mapping's
        iteration order. The keys are the names used by ``obj_kwargs``.
    obj_kwargs : Mapping[str, dict], optional
        Per-object keyword arguments for ``make_simulator``, keyed like ``objs``.
    meta_fn : Callable, optional
        If given, wrapped in a batched ``LambdaSimulator`` and placed first in
        the chain. It may take no arguments.
    **kwargs
        Forwarded to the ``SequentialSimulator`` constructor.

    Raises
    ------
    ValueError
        If ``obj_kwargs`` contains a key that is not a key of ``objs``.
    """
    from bayesflow.simulators import SequentialSimulator

    if obj_kwargs is None:
        obj_kwargs = {}

    # sanity check: every key in obj_kwargs must name one of the objects
    detected_names = set(objs.keys())
    given_names = set(obj_kwargs.keys())
    if not given_names.issubset(detected_names):
        unmatched_names = given_names - detected_names
        raise ValueError(
            f"Found at least one key in obj_kwargs that does not have a match in the object mapping:\n"
            f"{list(unmatched_names)!r}"
        )

    # Use a separate name so the obj_kwargs mapping stays intact for every object.
    simulators = [make_simulator(obj, **obj_kwargs.get(name, {})) for name, obj in objs.items()]

    if meta_fn is not None:
        simulators = [_make_meta_simulator(meta_fn), *simulators]

    return SequentialSimulator(simulators, **kwargs)