Source code for bayesflow.datasets.disk_dataset

from collections.abc import Callable, Mapping, Sequence

import os
import pathlib as pl

import numpy as np
import keras

from bayesflow.adapters import Adapter
from bayesflow.utils import tree_stack, pickle_load

from .helpers import apply_augmentations


class DiskDataset(keras.utils.PyDataset):
    """A dataset that loads pre-simulated files from disk for offline training.

    By default, the expected file structure is as follows::

        root
        ├── ...
        ├── sample_1.[ext]
        ├── ...
        └── sample_n.[ext]

    where each file contains a complete sample (e.g., a dictionary of NumPy arrays) or is converted
    into a complete sample using a custom loader function.

    Parameters
    ----------
    root : os.PathLike
        Root directory containing the sample files.
    pattern : str, default="*.pkl"
        Glob pattern to match sample files.
    batch_size : int
        Number of samples per batch.
    load_fn : Callable, optional
        Function to load a single file into a sample. Defaults to ``pickle_load``.
    adapter : Adapter or None
        Optional adapter to transform the loaded batch.
    augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
        A single augmentation function, a dictionary of augmentation functions, or a sequence of
        augmentation functions to apply to each batch.

        If you provide a dictionary of functions, each function should accept one element of your
        output batch and return the corresponding transformed element. Otherwise, your function
        should accept the entire dictionary output and return a dictionary.

        Note: augmentations are applied before the adapter is called and are generally transforms
        that you only want to apply during training.
    shuffle : bool, optional
        Whether to shuffle the dataset at initialization and at the end of each epoch.
        Default is ``True``.
    **kwargs
        Additional keyword arguments passed to the base ``PyDataset``.
    """

    def __init__(
        self,
        root: os.PathLike,
        *,
        pattern: str = "*.pkl",
        batch_size: int,
        load_fn: Callable | None = None,
        adapter: Adapter | None,
        augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] | None = None,
        shuffle: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.root = pl.Path(root)
        self.load_fn = load_fn or pickle_load
        self.adapter = adapter
        self.files = list(map(str, self.root.glob(pattern)))
        self.num_samples = len(self.files)
        self.indices = np.arange(self.num_samples, dtype="int64")
        self.augmentations = augmentations or []

        self._shuffle = shuffle
        if self._shuffle:
            self.shuffle()

    def __getitem__(self, item) -> dict[str, np.ndarray]:
        if not 0 <= item < self.num_batches:
            raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")

        start = item * self.batch_size
        stop = min((item + 1) * self.batch_size, self.num_samples)
        idx = self.indices[start:stop]

        return self.get_batch_by_sample_indices(idx)
    def get_batch_by_sample_indices(self, indices: np.ndarray) -> dict[str, np.ndarray]:
        """Load, stack, augment, and adapt the samples at the given indices into one batch."""
        selected_files = [self.files[int(i)] for i in indices]
        batch = [self.load_fn(file) for file in selected_files]

        # Stack the per-sample dictionaries into a single dictionary of batched arrays.
        batch = tree_stack(batch)

        batch = apply_augmentations(batch, self.augmentations)

        if self.adapter is not None:
            batch = self.adapter(batch)

        return batch
    def on_epoch_end(self):
        if self._shuffle:
            self.shuffle()
    @property
    def num_batches(self):
        return int(np.ceil(self.num_samples / self.batch_size))

    def __len__(self) -> int:
        return self.num_batches
    def shuffle(self):
        np.random.shuffle(self.indices)
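

# Usage sketch (not part of the module source above): one way to drive DiskDataset with
# pickled samples written ahead of time. The directory names, sample contents, loop count,
# and the helper name ``load_npz`` below are illustrative assumptions; only the constructor
# arguments are taken from the class above. The default loader is assumed to read standard
# pickle files via ``pickle_load``.
import pickle
import pathlib as pl

import numpy as np

from bayesflow.datasets.disk_dataset import DiskDataset

root = pl.Path("simulations")
root.mkdir(exist_ok=True)

# Write one file per complete sample, matching the default "*.pkl" glob pattern.
for i in range(8):
    sample = {
        "theta": np.random.normal(size=(2,)).astype("float32"),
        "x": np.random.normal(size=(10, 1)).astype("float32"),
    }
    with open(root / f"sample_{i}.pkl", "wb") as f:
        pickle.dump(sample, f)

# adapter is a required keyword argument; passing None skips the adaptation step.
dataset = DiskDataset(root=root, batch_size=4, adapter=None)

print(len(dataset))   # ceil(8 / 4) == 2 batches per epoch
batch = dataset[0]    # dict of stacked arrays, e.g. batch["theta"] with shape (4, 2)

# A custom loader can map other file formats onto complete samples, e.g. .npz archives
# (assuming a directory of .npz files exists at the given root):
def load_npz(path):
    with np.load(path) as data:
        return {key: data[key] for key in data.files}

npz_dataset = DiskDataset(
    root="simulations_npz", pattern="*.npz", batch_size=4, load_fn=load_npz, adapter=None
)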