Source code for bayesflow.datasets.offline_dataset

import keras
import numpy as np

from bayesflow.adapters import Adapter
from bayesflow.utils import logging


class OfflineDataset(keras.utils.PyDataset):
    """
    A dataset that is pre-simulated and stored in memory.

    When storing and loading data from disk, it is recommended to save any
    pre-simulated data in raw form and create the `OfflineDataset` object only
    after loading in the raw data. See the `DiskDataset` class for handling
    large datasets that are split into multiple smaller files.
    """

    def __init__(
        self,
        data: dict[str, np.ndarray],
        batch_size: int,
        adapter: Adapter | None,
        num_samples: int | None = None,
        **kwargs,
    ):
        """
        Create an in-memory dataset and shuffle its sample order once.

        Parameters
        ----------
        data : dict[str, np.ndarray]
            Pre-simulated data. Entries that are `np.ndarray` instances are
            indexed along axis 0 when a batch is built; all other entries are
            copied into every batch unchanged.
        batch_size : int
            Number of samples per batch. Must be positive.
        adapter : Adapter or None
            If not None, called on every batch dict before it is returned.
        num_samples : int, optional
            Total number of samples. If None, it is inferred from `data`.
            An explicitly passed value is used as-is and is not checked
            against the arrays in `data`.
        **kwargs
            Forwarded to `keras.utils.PyDataset.__init__`.

        Raises
        ------
        ValueError
            If `batch_size` is not positive, or if `num_samples` is None and
            cannot be inferred from `data`.
        """
        super().__init__(**kwargs)

        # Fail fast: a zero batch size would otherwise only surface later as a
        # ZeroDivisionError in `num_batches`, and a negative one as a dataset
        # in which every batch index is out of bounds.
        if batch_size <= 0:
            raise ValueError(f"Expected a positive batch size, but got batch_size={batch_size}.")

        self.batch_size = batch_size
        self.data = data
        self.adapter = adapter

        if num_samples is None:
            self.num_samples = self._get_num_samples_from_data(data)
            logging.debug(f"Automatically determined {self.num_samples} samples in data.")
        else:
            self.num_samples = num_samples

        # Batches are drawn through this permutation, so the arrays in
        # `self.data` themselves are never reordered.
        self.indices = np.arange(self.num_samples, dtype="int64")

        self.shuffle()

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """
        Get a batch of pre-simulated data.

        Parameters
        ----------
        item : int
            Index of the batch, with `0 <= item < num_batches`.

        Returns
        -------
        dict[str, np.ndarray]
            The batch, after passing through the adapter if one was given.
            The last batch is shorter than `batch_size` when `num_samples`
            is not a multiple of `batch_size`.

        Raises
        ------
        IndexError
            If `item` is outside `[0, num_batches)`.
        """
        if not 0 <= item < self.num_batches:
            raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")

        start = item * self.batch_size
        stop = start + self.batch_size

        # Slicing clips at the end of `indices`, which yields the short final batch.
        sample_indices = self.indices[start:stop]

        batch = {
            key: np.take(value, sample_indices, axis=0) if isinstance(value, np.ndarray) else value
            for key, value in self.data.items()
        }

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch

    @property
    def num_batches(self) -> int:
        """Number of batches per pass over the data; the last one may be partial."""
        return int(np.ceil(self.num_samples / self.batch_size))

    def on_epoch_end(self) -> None:
        """Reshuffle the sample order."""
        self.shuffle()

    def shuffle(self) -> None:
        """Shuffle the dataset in-place."""
        # Only the index permutation is shuffled (using NumPy's global random
        # state); the stored data arrays are left untouched.
        np.random.shuffle(self.indices)

    @staticmethod
    def _get_num_samples_from_data(data: dict) -> int:
        """
        Infer the number of samples from the first suitable entry of `data`.

        Returns the length of axis 0 of the first value that has a `shape`
        attribute with more than one dimension. Values without a `shape`, and
        values with fewer than two dimensions, are skipped.

        Raises
        ------
        ValueError
            If no value in `data` qualifies.
        """
        for value in data.values():
            if hasattr(value, "shape") and len(value.shape) > 1:
                return value.shape[0]

        raise ValueError("Could not determine number of samples from data. Please pass it manually.")