Source code for bayesflow.simulators.simulator

from collections.abc import Callable
import numpy as np

from bayesflow.types import Shape
from bayesflow.utils import tree_concatenate
from bayesflow.utils.decorators import allow_batch_size


[docs] class Simulator:
[docs] @allow_batch_size def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: raise NotImplementedError
[docs] @allow_batch_size def rejection_sample( self, batch_shape: Shape, predicate: Callable[[dict[str, np.ndarray]], np.ndarray], *, axis: int = 0, sample_size: int = None, **kwargs, ) -> dict[str, np.ndarray]: if sample_size is None: sample_shape = batch_shape else: sample_shape = list(batch_shape) sample_shape[axis] = sample_size sample_shape = tuple(sample_shape) result = {} while not result or next(iter(result.values())).shape[axis] < batch_shape[axis]: # get a batch of samples samples = self.sample(sample_shape, **kwargs) # get acceptance mask and turn into indices accept = predicate(samples) if not isinstance(accept, np.ndarray): raise RuntimeError("Predicate must return a numpy array.") if accept.shape != (sample_shape[axis],): raise RuntimeError( f"Predicate return array must have shape {(sample_shape[axis],)}. Received: {accept.shape}." ) if not accept.dtype == "bool": # we could cast, but this tends to hide mistakes in the predicate raise RuntimeError(f"Predicate must return a boolean type array. Got dtype={accept.dtype}") (accept,) = np.nonzero(accept) if not np.any(accept): # no samples accepted, skip continue # apply acceptance mask samples = {key: np.take(value, accept, axis=axis) for key, value in samples.items()} # concatenate with previous samples if not result: result = samples else: result = tree_concatenate([result, samples], axis=axis, numpy=True) return result