# Source code for bayesflow.simulators.hierarchical_simulator
from collections.abc import Sequence
import keras
import numpy as np
from bayesflow.types import Shape
from bayesflow.utils.decorators import allow_batch_size
from .simulator import Simulator
# [docs]
class HierarchicalSimulator(Simulator):
    """Simulator that queries a sequence of simulators level by level.

    The outputs of every level are merged into the keyword arguments of all
    subsequent levels, after being repeated along the first axis so that they
    line up with the larger flat batch size of the next level.
    """

    def __init__(self, hierarchy: Sequence[Simulator]):
        """
        Initialize the hierarchical simulator with a sequence of simulators.

        Parameters
        ----------
        hierarchy : Sequence[Simulator]
            A sequence of simulator instances representing each level of the hierarchy.
            Each level's output is used as input for the next, with increasing batch dimensions.
        """
        self.hierarchy = hierarchy

    @allow_batch_size
    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        """
        Sample from a hierarchy of simulators.

        Parameters
        ----------
        batch_shape : Shape
            A sequence where each element specifies the number of samples at the corresponding
            level of the hierarchy. Level ``i`` is queried with the flat batch shape
            ``(prod(batch_shape[:i + 1]),)``, so the total batch size increases multiplicatively
            through the levels.
        **kwargs
            Additional keyword arguments passed unchanged to every simulator. They are combined
            with the (repeated) outputs of previous levels; on a name collision the output of
            the previous level takes precedence.

        Returns
        -------
        output_data : dict of str to np.ndarray
            A dictionary containing the outputs from the entire hierarchy. The outputs of level
            ``i`` are reshaped from their flat batch axis to ``batch_shape[:i + 1] + x.shape[1:]``.
            If several levels emit the same key, the value of the last such level is kept.

        Raises
        ------
        IndexError
            If `batch_shape` has fewer entries than there are levels in the hierarchy.

        Notes
        -----
        Assumes every simulator returns arrays whose first axis has the flat batch size it was
        queried with; the repeat and reshape steps fail otherwise.
        """
        # Normalize once: the slices below are concatenated with ``x.shape[1:]`` (a tuple),
        # which would fail for a list and silently do arithmetic for an array.
        batch_shape = tuple(batch_shape)

        input_data = {}
        output_data = {}

        for level, simulator in enumerate(self.hierarchy):
            repeats = batch_shape[level]
            level_shape = batch_shape[: level + 1]

            # repeat the outputs of the previous levels so they match this level's flat batch size
            input_data = keras.tree.map_structure(lambda x: np.repeat(x, repeats, axis=0), input_data)

            # query the simulator with a flat batch shape; cast to a builtin int so that no
            # NumPy scalar leaks into the shape handed to the level's simulator
            flat_batch_size = int(np.prod(level_shape))
            data = simulator.sample((flat_batch_size,), **(kwargs | input_data))

            # inputs for the next level keep the flat batch shape
            input_data |= data

            # outputs get the hierarchical batch shape of this level restored
            output_data |= keras.tree.map_structure(lambda x: np.reshape(x, level_shape + x.shape[1:]), data)

        return output_data