Source code for bayesflow.datasets.online_dataset

import keras
import numpy as np

from bayesflow.adapters import Adapter
from bayesflow.simulators.simulator import Simulator


class OnlineDataset(keras.utils.PyDataset):
    """
    A dataset that is generated on-the-fly.

    No data is stored: every call to ``__getitem__`` asks ``simulator`` for a
    new batch (and optionally runs it through ``adapter``), so the requested
    batch index is ignored.

    Parameters
    ----------
    simulator : Simulator
        Object whose ``sample(batch_shape)`` method is called with
        ``(batch_size,)`` to produce each batch.
    batch_size : int
        Number of samples drawn per batch. Must be positive.
    num_batches : int or None
        Number of batches reported through the ``num_batches`` property.
        Must be positive when given. ``None`` is passed through unchanged
        (Keras' ``PyDataset`` documents ``None`` as an infinite dataset —
        confirm for the installed Keras version).
    adapter : Adapter or None, optional
        If not ``None``, called on every simulated batch before it is
        returned. Defaults to ``None`` (batches are returned as simulated).
    **kwargs
        Forwarded unchanged to ``keras.utils.PyDataset.__init__``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive, or ``num_batches`` is neither
        ``None`` nor positive.
    """

    def __init__(
        self,
        simulator: Simulator,
        batch_size: int,
        num_batches: int | None,
        adapter: Adapter | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fail fast: a non-positive batch size would otherwise only surface
        # later as an empty/invalid batch deep inside simulation or training.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}.")
        if num_batches is not None and num_batches < 1:
            raise ValueError(f"num_batches must be a positive integer or None, got {num_batches!r}.")

        self.batch_size = batch_size
        self._num_batches = num_batches
        self.adapter = adapter
        self.simulator = simulator

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """
        Simulate and return a new batch.

        Parameters
        ----------
        item : int
            Batch index requested by Keras. Unused: the simulator is invoked
            anew on every call, so there is no stored batch to look up.

        Returns
        -------
        dict[str, np.ndarray]
            The simulated batch, passed through ``adapter`` if one was given.
        """
        batch = self.simulator.sample((self.batch_size,))

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch

    @property
    def num_batches(self) -> int | None:
        """Number of batches per epoch, as given at construction (``None`` if unbounded)."""
        return self._num_batches