# Source code for bayesflow.datasets.online_dataset
from collections.abc import Callable, Mapping, Sequence
import keras
import numpy as np
from bayesflow.adapters import Adapter
from bayesflow.simulators.simulator import Simulator
# [docs]
class OnlineDataset(keras.utils.PyDataset):
    """
    A dataset that generates simulations on-the-fly.

    Every call to ``__getitem__`` draws a new batch from the simulator (the batch
    index is ignored), optionally augments it, and optionally passes it through an
    adapter. Nothing is stored between calls.
    """

    def __init__(
        self,
        simulator: Simulator,
        batch_size: int,
        num_batches: int,
        adapter: Adapter | None,
        *,
        stage: str = "training",
        augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] | None = None,
        **kwargs,
    ):
        """
        Initialize an OnlineDataset that simulates a fresh batch on every request.

        Parameters
        ----------
        simulator : Simulator
            A simulator object with a `.sample(batch_shape)` method to generate data.
        batch_size : int
            Number of samples per batch; passed to the simulator as ``(batch_size,)``.
        num_batches : int
            Number of batches reported by the ``num_batches`` property. Because batches
            are simulated on demand, this sets how many batches make up one pass over
            the dataset rather than the size of a stored pool of data.
        adapter : Adapter or None
            Optional adapter to transform the simulated (and augmented) batch.
        stage : str, default="training"
            Current stage (e.g., "training", "validation", etc.), forwarded to the
            adapter as its ``stage`` keyword argument.
        augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
            Transformations applied to each batch before the adapter:

            - a single callable receives the whole batch dict and returns a dict;
            - a mapping of ``key -> function`` replaces ``batch[key]`` with
              ``function(batch[key])`` for each key;
            - a sequence of callables is applied in order, each receiving the dict
              returned by the previous one.

            ``None`` (or any falsy value) means no augmentation. Augmentations are
            generally transforms you only want to apply during training.
        **kwargs
            Additional keyword arguments passed to the base `PyDataset`.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self._num_batches = num_batches
        self.adapter = adapter
        self.simulator = simulator
        self.stage = stage
        # Normalize "no augmentations" to an empty list so the Sequence branch is a no-op.
        self.augmentations = augmentations or []

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """
        Generate one batch of data.

        Parameters
        ----------
        item : int
            Index of the batch. Required by the `PyDataset` interface, but not used:
            every call simulates a new batch.

        Returns
        -------
        dict of str to np.ndarray
            A batch of simulated (and optionally augmented/adapted) data.

        Raises
        ------
        RuntimeError
            If ``self.augmentations`` is not None, a Mapping, a Sequence, or a Callable.
        """
        batch = self.simulator.sample((self.batch_size,))
        batch = self._apply_augmentations(batch)
        if self.adapter is not None:
            batch = self.adapter(batch, stage=self.stage)
        return batch

    def _apply_augmentations(self, batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Apply ``self.augmentations`` to ``batch`` and return the (possibly new) batch."""
        augmentations = self.augmentations

        if augmentations is None:
            # Only reachable if the attribute was reassigned after construction;
            # __init__ turns a falsy value into an empty list.
            return batch

        # Mapping must be checked before Sequence/Callable so per-key dicts are
        # not mistaken for another form.
        if isinstance(augmentations, Mapping):
            # Per-key augmentations: the batch dict is updated in place.
            for key, fn in augmentations.items():
                batch[key] = fn(batch[key])
            return batch

        if isinstance(augmentations, Sequence):
            # Chained augmentations: each function receives and returns the whole batch.
            for fn in augmentations:
                batch = fn(batch)
            return batch

        if isinstance(augmentations, Callable):
            return augmentations(batch)

        raise RuntimeError(f"Could not apply augmentations of type {type(augmentations)}.")

    @property
    def num_batches(self) -> int:
        """Number of batches this dataset reports, as passed to the constructor."""
        return self._num_batches