Source code for bayesflow.datasets.rounds_dataset

import keras
import numpy as np

from bayesflow.adapters import Adapter
from bayesflow.simulators.simulator import Simulator
from bayesflow.utils import logging


class RoundsDataset(keras.utils.PyDataset):
    """
    A dataset that is re-simulated on-the-fly at the beginning of every n-th epoch
    ("rounds" training): a fixed set of batches is drawn from the simulator, reused
    for `epochs_per_round` epochs, then regenerated.
    """

    def __init__(
        self,
        simulator: Simulator,
        batch_size: int,
        num_batches: int,
        epochs_per_round: int,
        adapter: Adapter | None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        simulator : Simulator
            Provides `sample((batch_size,))` to draw one batch of data.
        batch_size : int
            Number of simulations per batch.
        num_batches : int
            Number of batches simulated per round (and per epoch).
        epochs_per_round : int
            How many epochs each simulated round is reused for before
            regenerating. A value of 1 degenerates to fully online training.
        adapter : Adapter | None
            Optional transform applied to each batch in `__getitem__`.
        **kwargs
            Forwarded to `keras.utils.PyDataset` (e.g. `workers`).
        """
        super().__init__(**kwargs)

        self.batches = None  # populated by regenerate() below
        self._num_batches = num_batches
        self.batch_size = batch_size
        self.adapter = adapter
        self.epoch = 0

        # With a round length of 1, data is regenerated every epoch anyway,
        # so OnlineDataset is the better (faster) choice — warn, don't fail.
        if epochs_per_round == 1:
            logging.warning(
                "Using `RoundsDataset` with `epochs_per_round=1` is equivalent to fully online training. "
                "Use an `OnlineDataset` instead for best performance."
            )

        self.epochs_per_round = epochs_per_round
        self.simulator = simulator

        # Simulate the first round eagerly so the dataset is usable immediately.
        self.regenerate()

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Get a batch of pre-simulated data, passed through the adapter if one is set."""
        batch = self.batches[item]

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch

    @property
    def num_batches(self) -> int:
        # Fixed at construction; PyDataset uses this as the epoch length.
        return self._num_batches

    def on_epoch_end(self) -> None:
        """Advance the epoch counter and re-simulate once a round completes."""
        self.epoch += 1
        if self.epoch % self.epochs_per_round == 0:
            self.regenerate()

    def regenerate(self) -> None:
        """Sample new batches of data from the joint distribution unconditionally."""
        # BUG FIX: the original read `self.batches_per_epoch`, an attribute that
        # is never assigned anywhere in this class — construction would raise
        # AttributeError. The backing count is exposed as `self.num_batches`.
        self.batches = [self.simulator.sample((self.batch_size,)) for _ in range(self.num_batches)]