batched_call

bayesflow.utils.batched_call(f: callable, batch_shape: tuple[int, ...], args: Sequence[any] = (), kwargs: Mapping[str, any] = None, map_predicate: Callable[[any], bool] = None, flatten: bool = False) → list

Map f over the given batch shape with a for loop. Unlike the Keras built-in map APIs, this preserves randomness.

Parameters:
  • f – The function to call.

  • batch_shape – The shape of the batch.

  • args – Any number and type of positional arguments to f. Arguments indicated by map_predicate will be indexed over the first len(batch_shape) axes.

  • kwargs – Any number and type of keyword arguments to f. Arguments indicated by map_predicate will be indexed over the first len(batch_shape) axes.

  • map_predicate – A function that returns True if an argument should be indexed over the batch shape. By default, all array-like arguments are mapped.

  • flatten – Whether to flatten the output.

Returns:

A list of outputs of f for each element in the batch.
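
The looping behaviour described above can be illustrated with a minimal standard-library sketch. This is not BayesFlow's implementation: `batched_call_sketch`, `noisy_scale`, and the choice of nested lists as the "array-like" type are assumptions made for illustration, and the sketch always returns a flat list rather than modelling the `flatten` option.

```python
import itertools
import random


def batched_call_sketch(f, batch_shape, args=(), kwargs=None, map_predicate=None):
    """Illustrative stand-in for the behaviour described above (not BayesFlow's code)."""
    kwargs = {} if kwargs is None else kwargs
    if map_predicate is None:
        # Assumption for this sketch: nested lists play the role of arrays.
        map_predicate = lambda arg: isinstance(arg, list)

    def index(arg, batch_index):
        # Index a mapped argument along the first len(batch_shape) axes.
        for i in batch_index:
            arg = arg[i]
        return arg

    outputs = []
    # Visit every index of the batch shape in row-major order.
    for batch_index in itertools.product(*(range(n) for n in batch_shape)):
        call_args = [index(a, batch_index) if map_predicate(a) else a for a in args]
        call_kwargs = {
            k: index(v, batch_index) if map_predicate(v) else v
            for k, v in kwargs.items()
        }
        # f runs once per batch element, so a random draw inside f is fresh each time.
        outputs.append(f(*call_args, **call_kwargs))
    return outputs


def noisy_scale(x, scale, rng):
    # Hypothetical function: scales x and adds a random offset in [0, 1).
    return scale * x + rng.random()


rng = random.Random(0)
data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # batch shape (2, 3)
results = batched_call_sketch(
    noisy_scale, (2, 3), args=(data,), kwargs={"scale": 10.0, "rng": rng}
)
```

Here `data` is indexed over both batch axes, while `scale` and `rng` are passed to every call unchanged because the predicate does not select them. Each of the six outputs carries its own random offset, which is the property the for loop is meant to keep.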