Source code for bayesflow.simulators.simulator

from collections.abc import Callable
import numpy as np

from bayesflow.types import Shape
from bayesflow.utils import tree_concatenate
from bayesflow.utils.decorators import allow_batch_size


class Simulator:
    """Base class for simulators that return batches of named NumPy arrays.

    Subclasses implement :py:meth:`sample`. The methods
    :py:meth:`rejection_sample` and :py:meth:`sample_batched` are built on
    top of it and do not need to be overridden.
    """

    @allow_batch_size
    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        """Draw one batch of simulations.

        Parameters
        ----------
        batch_shape : Shape
            The desired shape of the batch of simulations.
        **kwargs
            Additional keyword arguments defined by the concrete simulator.

        Returns
        -------
        dict[str, np.ndarray]
            The simulated variables, keyed by name.

        Raises
        ------
        NotImplementedError
            Always in this base class; subclasses must override this method.
        """
        raise NotImplementedError

    @allow_batch_size
    def rejection_sample(
        self,
        batch_shape: Shape,
        predicate: Callable[[dict[str, np.ndarray]], np.ndarray],
        *,
        axis: int = 0,
        sample_size: int = None,
        **kwargs,
    ) -> dict[str, np.ndarray]:
        """Draw samples repeatedly, keeping only those accepted by `predicate`.

        Batches are drawn with :py:meth:`sample` until at least
        ``batch_shape[axis]`` samples have been accepted. All accepted samples
        of every round are kept, so the result may contain more than
        ``batch_shape[axis]`` entries along `axis`.

        Note that this method keeps drawing for as long as too few samples
        have been accepted; a predicate that never accepts will not terminate.

        Parameters
        ----------
        batch_shape : Shape
            The desired output shape, as in :py:meth:`sample`.
        predicate : Callable[[dict[str, np.ndarray]], np.ndarray]
            Called with the dictionary returned by :py:meth:`sample`. Must
            return a one-dimensional boolean NumPy array with one entry per
            sample along `axis`; ``True`` marks an accepted sample.
        axis : int, optional
            The axis along which samples are filtered and concatenated.
            Default is 0.
        sample_size : int, optional
            The number of samples to draw along `axis` in each round. If None
            (default), ``batch_shape[axis]`` samples are drawn per round.
        **kwargs
            Additional keyword arguments passed to :py:meth:`sample`.

        Returns
        -------
        dict[str, np.ndarray]
            The accepted samples, concatenated along `axis`.

        Raises
        ------
        ValueError
            If `sample_size` is given but smaller than 1.
        RuntimeError
            If `predicate` does not return a NumPy array, or the array has the
            wrong shape or a non-boolean dtype.
        """
        if sample_size is None:
            sample_shape = batch_shape
        else:
            if sample_size < 1:
                # drawing zero samples per round can never make progress
                raise ValueError(f"sample_size must be a positive integer. Received: {sample_size}.")

            sample_shape = list(batch_shape)
            sample_shape[axis] = sample_size
            sample_shape = tuple(sample_shape)

        num_requested = batch_shape[axis]
        expected_mask_shape = (sample_shape[axis],)

        result = {}
        num_accepted = 0

        # always draw at least once, so that a request for zero samples still
        # returns a dictionary with the correct keys
        while True:
            # get a batch of samples
            samples = self.sample(sample_shape, **kwargs)

            # get the acceptance mask and validate it
            accept = predicate(samples)

            if not isinstance(accept, np.ndarray):
                raise RuntimeError("Predicate must return a numpy array.")

            if accept.shape != expected_mask_shape:
                raise RuntimeError(
                    f"Predicate return array must have shape {expected_mask_shape}. Received: {accept.shape}."
                )

            if accept.dtype != "bool":
                # we could cast, but this tends to hide mistakes in the predicate
                raise RuntimeError(f"Predicate must return a boolean type array. Got dtype={accept.dtype}")

            # turn the mask into indices
            (indices,) = np.nonzero(accept)

            if indices.size == 0 and num_requested > 0:
                # no samples accepted in this round, draw again
                continue

            # keep only the accepted samples
            accepted = {key: np.take(value, indices, axis=axis) for key, value in samples.items()}

            # concatenate with the samples accepted in previous rounds
            if not result:
                result = accepted
            else:
                result = tree_concatenate([result, accepted], axis=axis, numpy=True)

            num_accepted += indices.size

            if num_accepted >= num_requested:
                return result

    @allow_batch_size
    def sample_batched(
        self,
        batch_shape: Shape,
        *,
        sample_size: int,
        **kwargs,
    ) -> dict[str, np.ndarray]:
        """Sample the desired number of simulations in smaller batches.

        Limited resources, especially memory, can make it necessary to run
        simulations in smaller batches. The number of samples per simulated
        batch is specified by `sample_size`.

        Parameters
        ----------
        batch_shape : Shape
            The desired output shape, as in :py:meth:`sample`. Will be rounded
            up to the next complete batch.
        sample_size : int
            The number of samples in each simulated batch.
        **kwargs
            Additional keyword arguments passed to :py:meth:`rejection_sample`
            and from there to :py:meth:`sample`.

        Returns
        -------
        dict[str, np.ndarray]
            The simulated variables, concatenated over all simulated batches.
        """

        def accept_all_predicate(x):
            # accept every sample, so rejection sampling reduces to batching
            return np.full((sample_size,), True)

        return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs)