# Source code for bayesflow.datasets.rounds_dataset
import keras
import numpy as np
from bayesflow.adapters import Adapter
from bayesflow.simulators.simulator import Simulator
from bayesflow.utils import logging
# [docs]  (Sphinx viewcode artifact)
class RoundsDataset(keras.utils.PyDataset):
    """
    A dataset that is regenerated on-the-fly at the beginning of every n-th epoch.

    Batches are pre-simulated in rounds: all batches for a round are drawn up
    front via :meth:`regenerate`, served for ``epochs_per_round`` epochs, and
    then replaced with freshly simulated data.
    """

    def __init__(
        self,
        simulator: Simulator,
        batch_size: int,
        num_batches: int,
        epochs_per_round: int,
        adapter: Adapter | None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        simulator : Simulator
            Source of simulated data; ``simulator.sample((batch_size,))`` is
            called to produce each batch.
        batch_size : int
            Number of simulations per batch.
        num_batches : int
            Number of batches simulated per round (i.e., per epoch).
        epochs_per_round : int
            Number of epochs for which a round's batches are reused before
            the data is regenerated.
        adapter : Adapter or None
            Optional transform applied to each batch in ``__getitem__``.
        **kwargs
            Forwarded to ``keras.utils.PyDataset``.
        """
        super().__init__(**kwargs)
        self.batches = None
        self._num_batches = num_batches
        self.batch_size = batch_size
        self.adapter = adapter
        self.epoch = 0

        if epochs_per_round == 1:
            logging.warning(
                "Using `RoundsDataset` with `epochs_per_round=1` is equivalent to fully online training. "
                "Use an `OnlineDataset` instead for best performance."
            )

        self.epochs_per_round = epochs_per_round
        self.simulator = simulator

        # Populate the first round of batches immediately so the dataset is
        # usable right after construction.
        self.regenerate()

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Get a batch of pre-simulated data, adapted if an adapter is set."""
        batch = self.batches[item]

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch

    @property
    def num_batches(self) -> int:
        # Required by keras.utils.PyDataset: number of batches per epoch.
        return self._num_batches

    def on_epoch_end(self) -> None:
        """Advance the epoch counter and regenerate data when a round ends."""
        self.epoch += 1
        if self.epoch % self.epochs_per_round == 0:
            self.regenerate()

    def regenerate(self) -> None:
        """Sample new batches of data from the joint distribution unconditionally."""
        # BUG FIX: the original read `self.batches_per_epoch`, an attribute that
        # is never assigned (the count is stored as `self._num_batches`), so
        # __init__'s call to regenerate() raised AttributeError on construction.
        self.batches = [self.simulator.sample((self.batch_size,)) for _ in range(self._num_batches)]