Source code for bayesflow.datasets.offline_dataset

from collections.abc import Callable, Mapping, Sequence

import numpy as np
import keras

from bayesflow.adapters import Adapter
from bayesflow.utils import logging

from .helpers import apply_augmentations


class OfflineDataset(keras.utils.PyDataset):
    """A dataset that is pre-simulated and stored in memory.

    When storing and loading data from disk, it is recommended to save any
    pre-simulated data in raw form and create the ``OfflineDataset`` object
    only after loading in the raw data. See :class:`DiskDataset` for handling
    large datasets that are split into multiple smaller files.

    Parameters
    ----------
    data : Mapping[str, np.ndarray]
        Pre-simulated data stored in a dictionary, where each key maps to a
        NumPy array.
    batch_size : int
        Number of samples per batch.
    adapter : Adapter or None
        Optional adapter to transform each batch.
    num_samples : int, optional
        Number of samples in the dataset. If ``None``, it is inferred from
        the data.
    augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
        A single augmentation function, a dictionary of augmentation
        functions, or a sequence of augmentation functions to apply to each
        batch. If you provide a dictionary of functions, each function should
        accept one element of the batch and return the corresponding
        transformed element. Otherwise, each function should accept the
        entire batch dictionary and return a dictionary.

        Note: augmentations are applied before the adapter is called and are
        generally transforms that you only want to apply during training.
    shuffle : bool, optional
        Whether to shuffle the dataset at initialization and at the end of
        each epoch. Default is ``True``.
    **kwargs
        Additional keyword arguments passed to the base ``PyDataset``.
    """

    def __init__(
        self,
        data: Mapping[str, np.ndarray],
        batch_size: int,
        adapter: Adapter | None,
        num_samples: int | None = None,
        *,
        augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] | None = None,
        shuffle: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.batch_size = batch_size
        self.data = data
        self.adapter = adapter

        if num_samples is None:
            self.num_samples = self._get_num_samples_from_data(data)
            logging.debug(f"Automatically determined {self.num_samples} samples in data.")
        else:
            self.num_samples = num_samples

        self.indices = np.arange(self.num_samples, dtype="int64")
        self.augmentations = augmentations or []
        self._shuffle = shuffle

        if self._shuffle:
            self.shuffle()

    @property
    def num_batches(self) -> int:
        return int(np.ceil(self.num_samples / self.batch_size))

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """
        Load a batch of pre-simulated data from memory.

        Parameters
        ----------
        item : int
            Index of the batch to retrieve.

        Returns
        -------
        dict of str to np.ndarray
            A batch of loaded (and optionally augmented/adapted) data.

        Raises
        ------
        IndexError
            If the requested batch index is out of range.
        """
        if not 0 <= item < self.num_batches:
            raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")

        start = item * self.batch_size
        stop = min((item + 1) * self.batch_size, self.num_samples)
        idx = self.indices[start:stop]

        return self.get_batch_by_sample_indices(idx)
    def get_batch_by_sample_indices(self, indices: np.ndarray) -> dict[str, np.ndarray]:
        """
        Return a batch for explicit sample indices.

        This method is the index-based access primitive used by ensemble
        dataset wrappers. It selects samples from the underlying in-memory
        arrays, then applies augmentations and the adapter just like
        :meth:`__getitem__`.

        Parameters
        ----------
        indices : np.ndarray
            1D integer array of sample indices in the range
            ``[0, num_samples)``. The returned batch will have leading
            dimension ``len(indices)``.

        Returns
        -------
        dict of str to np.ndarray
            A batch dictionary where each NumPy array has shape
            ``(len(indices), ...)``. Non-array entries are passed through
            unchanged.
        """
        batch = {
            key: np.take(value, indices, axis=0) if isinstance(value, np.ndarray) else value
            for key, value in self.data.items()
        }

        batch = apply_augmentations(batch, self.augmentations)

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch
    def __len__(self) -> int:
        return self.num_batches
    def on_epoch_end(self) -> None:
        if self._shuffle:
            self.shuffle()
    def shuffle(self) -> None:
        """Shuffle the dataset in-place."""
        np.random.shuffle(self.indices)
    @staticmethod
    def _get_num_samples_from_data(data: Mapping) -> int:
        # The first array with a batch dimension (ndim > 1) determines the
        # sample count; plain 1D arrays are skipped since their leading axis
        # may not index samples.
        for key, value in data.items():
            if hasattr(value, "shape"):
                ndim = len(value.shape)
                if ndim > 1:
                    return value.shape[0]

        raise ValueError("Could not determine number of samples from data. Please pass it manually.")
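
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the library source). A
# minimal, hedged example of wrapping pre-simulated arrays in an
# ``OfflineDataset`` and applying a whole-batch augmentation before the
# adapter step. The key names ("theta", "x"), the toy Gaussian data, and the
# ``add_noise`` function are assumptions made for this sketch; passing
# ``adapter=None`` simply skips the adapter.

if __name__ == "__main__":
    rng = np.random.default_rng(seed=0)

    # Pre-simulated data: each array's leading axis indexes the 128 samples.
    data = {
        "theta": rng.normal(size=(128, 2)).astype("float32"),
        "x": rng.normal(size=(128, 10)).astype("float32"),
    }

    # A whole-batch augmentation: it receives the full batch dictionary and
    # must return a dictionary (see the class docstring). A Mapping of
    # per-key functions would instead transform individual batch entries.
    def add_noise(batch: dict) -> dict:
        batch = dict(batch)
        batch["x"] = batch["x"] + 0.1 * rng.standard_normal(batch["x"].shape).astype("float32")
        return batch

    dataset = OfflineDataset(
        data=data,
        batch_size=32,
        adapter=None,
        augmentations=add_noise,
    )

    first_batch = dataset[0]  # dict of arrays, each with leading dimension 32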
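
# ---------------------------------------------------------------------------
# Index-based access sketch (illustrative only). ``get_batch_by_sample_indices``
# accepts an explicit 1D integer index array, which is how ensemble dataset
# wrappers can select samples independently of the shuffled batch order; the
# bootstrap-resampling framing and all names below are assumptions made for
# this sketch.

if __name__ == "__main__":
    rng = np.random.default_rng(seed=1)
    dataset = OfflineDataset(
        data={"x": rng.normal(size=(100, 5)).astype("float32")},
        batch_size=16,
        adapter=None,
    )

    # Draw 16 sample indices with replacement and fetch the corresponding
    # batch directly, bypassing the shuffled batch order.
    bootstrap_idx = rng.integers(0, dataset.num_samples, size=16)
    bootstrap_batch = dataset.get_batch_by_sample_indices(bootstrap_idx)
    assert bootstrap_batch["x"].shape == (16, 5)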