Source code for bayesflow.utils.functional

from collections.abc import Callable, Mapping, Sequence
from typing import Any

import keras
import numpy as np

from bayesflow.types import Shape


def batched_call(
    f: Callable,
    batch_shape: Shape,
    args: Sequence[Any] = (),
    kwargs: Mapping[str, Any] = None,
    map_predicate: Callable[[Any], bool] = None,
    flatten: bool = False,
) -> list:
    """Map f over the given batch shape with a for loop, preserving randomness unlike the keras built-in map apis.

    ``f`` is called once per multi-index of ``batch_shape``, in the row-major (C) order produced by
    :func:`numpy.ndindex`.

    :param f: The function to call.
    :param batch_shape: The shape of the batch.
    :param args: Any number and type of positional arguments to f.
        Arguments indicated by `map_predicate` will be indexed over the first len(batch_shape) axes.
    :param kwargs: Any number and type of keyword arguments to f.
        Arguments indicated by `map_predicate` will be indexed over the first len(batch_shape) axes.
    :param map_predicate: A function that returns True if an argument should be indexed over the batch shape.
        It is evaluated once per argument, not once per batch element.
        By default, NumPy arrays and keras tensors with at least len(batch_shape) dimensions are mapped;
        every other argument is passed to f unchanged.
    :param flatten: Whether to flatten the output.
    :return: The outputs of f for each element in the batch, as nested lists mirroring ``batch_shape``,
        or as a single flat list in row-major order if ``flatten`` is True.
    """
    if kwargs is None:
        kwargs = {}

    if map_predicate is None:

        def map_predicate(arg: Any) -> bool:
            # Only arrays/tensors with enough axes can be indexed by a full batch index.
            if isinstance(arg, np.ndarray):
                return arg.ndim >= len(batch_shape)
            if keras.ops.is_tensor(arg):
                return keras.ops.ndim(arg) >= len(batch_shape)
            return False

    # Whether an argument is mapped depends only on the argument itself, so decide it once here
    # instead of re-running the predicate for every batch index inside the loop.
    arg_plan = [(arg, map_predicate(arg)) for arg in args]
    kwarg_plan = [(key, value, map_predicate(value)) for key, value in kwargs.items()]

    # Object dtype so that each cell can hold an arbitrary return value of f.
    outputs = np.empty(batch_shape, dtype="object")
    for index in np.ndindex(batch_shape):
        call_args = [arg[index] if mapped else arg for arg, mapped in arg_plan]
        call_kwargs = {key: value[index] if mapped else value for key, value, mapped in kwarg_plan}
        outputs[index] = f(*call_args, **call_kwargs)

    if flatten:
        # flatten() defaults to C order, matching the iteration order of np.ndindex above.
        outputs = outputs.flatten()

    # NOTE(review): for batch_shape == () and flatten=False, tolist() on the 0-d array presumably
    # returns f's single output rather than a list — confirm whether callers rely on this.
    return outputs.tolist()