Source code for bayesflow.simulators.lambda_simulator
from collections.abc import Callable, Sequence, Mapping
from typing import Any

import numpy as np

from bayesflow.utils import batched_call, filter_kwargs, tree_stack
from bayesflow.utils.decorators import allow_batch_size

from .simulator import Simulator
from ..types import Shape
[docs]
class LambdaSimulator(Simulator):
    """Implements a simulator based on a sampling function."""

    def __init__(self, sample_fn: Callable[[Sequence[int]], Mapping[str, Any]], *, is_batched: bool = False):
        """
        Initialize a simulator based on a simple callable function.

        Parameters
        ----------
        sample_fn : Callable[[Sequence[int]], Mapping[str, Any]]
            A function that generates samples. It should accept `batch_shape` as its first argument
            (if `is_batched=True`), followed by keyword arguments.
        is_batched : bool, optional
            Whether the `sample_fn` is implemented to handle batched sampling directly.
            If False, `sample_fn` will be called once per sample and results will be stacked.
            Default is False.
        """
        self.sample_fn = sample_fn
        self.is_batched = is_batched

    @allow_batch_size
    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        """
        Sample using the wrapped sampling function.

        Parameters
        ----------
        batch_shape : Shape
            The shape of the batch to sample. Typically, a tuple indicating the number of samples,
            but an int can also be passed.
        **kwargs
            Additional keyword arguments passed to the sampling function. Only valid arguments
            (as determined by the function's signature) are used.

        Returns
        -------
        data : dict of str to np.ndarray
            A dictionary of sampled outputs. Keys are output names and values are numpy arrays.
            If `is_batched` is False, individual outputs are stacked along the first axis.
            If `is_batched` is True, the return value of `sample_fn` is passed through unchanged,
            so its exact type is whatever `sample_fn` produces.
        """
        # try to use only valid keyword-arguments
        kwargs = filter_kwargs(kwargs, self.sample_fn)

        if self.is_batched:
            # the wrapped function handles the whole batch itself; no stacking is applied here
            return self.sample_fn(batch_shape, **kwargs)

        # delegate the per-sample calls to batched_call, then stack the collected
        # outputs along axis 0 as numpy arrays
        data = batched_call(self.sample_fn, batch_shape, kwargs=kwargs, flatten=True)
        data = tree_stack(data, axis=0, numpy=True)

        return data