# Source code for bayesflow.datasets.offline_dataset
import keras
import numpy as np
from bayesflow.adapters import Adapter
from bayesflow.utils import logging
[docs]
class OfflineDataset(keras.utils.PyDataset):
    """
    A dataset that is pre-simulated and stored in memory. When storing and loading data from disk, it is recommended to
    save any pre-simulated data in raw form and create the `OfflineDataset` object only after loading in the raw data.
    See the `DiskDataset` class for handling large datasets that are split into multiple smaller files.

    Batches are drawn through a permutation of the sample indices. The permutation is shuffled once on
    construction and again at the end of every epoch.

    Parameters
    ----------
    data : dict[str, np.ndarray]
        The pre-simulated data. NumPy arrays are indexed along their first axis when a batch is assembled;
        any value that is not a NumPy array is placed into every batch unchanged.
    batch_size : int
        Number of samples per batch. Must be at least 1. The last batch of an epoch is smaller whenever
        `num_samples` is not a multiple of `batch_size`.
    adapter : Adapter or None
        Callable applied to every assembled batch before it is returned. Pass `None` to return raw batches.
    num_samples : int, optional
        Total number of samples in `data`. If `None` (default), it is taken from the first axis of the first
        entry in `data` that has a `shape` with more than one dimension.
    **kwargs
        Forwarded to `keras.utils.PyDataset.__init__`.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1, or if `num_samples` is `None` and cannot be determined from `data`.
    """

    def __init__(
        self,
        data: dict[str, np.ndarray],
        batch_size: int,
        adapter: Adapter | None,
        num_samples: int | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fail early: a zero batch size would otherwise only surface as a ZeroDivisionError in `num_batches`,
        # and a negative one would yield a non-positive batch count that makes every index out of bounds.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, but got {batch_size}.")

        self.batch_size = batch_size
        self.data = data
        self.adapter = adapter

        if num_samples is None:
            self.num_samples = self._get_num_samples_from_data(data)
            logging.debug(f"Automatically determined {self.num_samples} samples in data.")
        else:
            self.num_samples = num_samples

        # Permutation of sample positions; batches are contiguous slices of this array.
        self.indices = np.arange(self.num_samples, dtype="int64")

        self.shuffle()

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """
        Get a batch of pre-simulated data.

        Parameters
        ----------
        item : int
            Index of the batch, with `0 <= item < num_batches`.

        Returns
        -------
        dict[str, np.ndarray]
            The batch, after passing through the adapter if one was given.

        Raises
        ------
        IndexError
            If `item` is outside the valid range of batch indices.
        """
        if not 0 <= item < self.num_batches:
            raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")

        # Slicing clips at the end of `indices`, so the final batch may hold fewer than `batch_size` samples.
        start = item * self.batch_size
        sample_indices = self.indices[start : start + self.batch_size]

        batch = {
            key: np.take(value, sample_indices, axis=0) if isinstance(value, np.ndarray) else value
            for key, value in self.data.items()
        }

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch

    @property
    def num_batches(self) -> int:
        """Number of batches per epoch, counting a trailing partial batch."""
        return int(np.ceil(self.num_samples / self.batch_size))

    def on_epoch_end(self) -> None:
        """Reshuffle the sample order so the next epoch sees differently composed batches."""
        self.shuffle()

    def shuffle(self) -> None:
        """Shuffle the dataset in-place."""
        # Only the index permutation is shuffled; the arrays in `self.data` are left untouched.
        np.random.shuffle(self.indices)

    @staticmethod
    def _get_num_samples_from_data(data: dict) -> int:
        """
        Infer the number of samples from the first suitable entry in `data`.

        An entry is suitable if it has a `shape` attribute with more than one dimension; the size of its
        first axis is returned. Entries without a `shape`, and entries with fewer than two dimensions,
        are skipped.

        Raises
        ------
        ValueError
            If no entry in `data` is suitable.
        """
        for value in data.values():
            if hasattr(value, "shape") and len(value.shape) > 1:
                return value.shape[0]

        raise ValueError("Could not determine number of samples from data. Please pass it manually.")