Source code for bayesflow.datasets.ensemble_dataset
from collections.abc import Sequence
from typing import Any
import keras
from .ensemble_online_dataset import EnsembleOnlineDataset
from .ensemble_indexed_dataset import EnsembleIndexedDataset
[docs]
class EnsembleDataset(keras.utils.PyDataset):
    """
    Wrap a BayesFlow dataset to provide per-ensemble-member batches.

    This dataset class is the recommended entry point for training ensembles.
    The wrapped dataset should meet the requirements of any single approximator in
    the :class:`~bayesflow.approximators.EnsembleApproximator`. `EnsembleDataset` supports
    :class:`~bayesflow.datasets.OnlineDataset`, :class:`~bayesflow.datasets.OfflineDataset`,
    and :class:`~bayesflow.datasets.DiskDataset` and returns a key-value pair for each
    ensemble member, containing output of the same structure as the wrapped dataset.

    The wrapper controls how much data is shared between ensemble members through the
    ``data_reuse`` parameter:

    - ``data_reuse = 1.0``: all ensemble members receive identical data.
    - ``data_reuse = 0.0``: each member receives maximally different data.
    - intermediate values: the total amount of data used per step / per epoch interpolates
      linearly between these extremes.

    Notes
    -----
    Implementation details differ by dataset type:

    **OnlineDataset**
        A larger "pool" of simulations is generated per training step and split into
        overlapping member batches (sharing is enforced per batch).
        This is implemented by :class:`~bayesflow.datasets.EnsembleOnlineDataset`.

    **OfflineDataset / DiskDataset**
        A member-specific subdataset (window into the full index set) is constructed once
        on initialization. Batches are drawn from these subdatasets and reshuffled on
        ``on_epoch_end`` (sharing is enforced at the subdataset level).
        This is implemented by :class:`~bayesflow.datasets.EnsembleIndexedDataset`.

    Parameters
    ----------
    dataset : keras.utils.PyDataset
        A BayesFlow dataset (OnlineDataset, OfflineDataset, DiskDataset).
    member_names : Sequence[str]
        Names of ensemble members, used as dictionary keys.
    data_reuse : float, default=1.0
        Degree of data sharing between ensemble members in ``[0, 1]``:
        ``1.0`` gives every member identical data, ``0.0`` gives maximally
        different data. See Notes for how it is applied for different dataset types.
    **kwargs
        Forwarded unchanged to ``keras.utils.PyDataset.__init__``.

    Raises
    ------
    TypeError
        If ``dataset`` is neither OnlineDataset-like (has ``simulator`` and
        ``num_batches``) nor Offline/Disk-like (has ``num_samples`` and
        ``get_batch_by_sample_indices``).
    """

    def __init__(
        self,
        dataset: keras.utils.PyDataset,
        member_names: Sequence[str],
        data_reuse: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Dispatch based on capabilities (duck typing): the hasattr checks accept any
        # object exposing the required attributes, not only the BayesFlow classes.
        # The online check comes first, so an object matching both kinds is treated
        # as online.
        if hasattr(dataset, "simulator") and hasattr(dataset, "num_batches"):
            wrapper_cls = EnsembleOnlineDataset
        elif hasattr(dataset, "get_batch_by_sample_indices") and hasattr(dataset, "num_samples"):
            wrapper_cls = EnsembleIndexedDataset
        else:
            raise TypeError(
                "EnsembleDataset: dataset must be OnlineDataset-like (has `simulator` and `num_batches`) "
                "or Offline/Disk-like (has `num_samples` and `get_batch_by_sample_indices`), "
                f"got {type(dataset).__name__}."
            )

        self._wrapped = wrapper_cls(
            dataset,
            member_names=member_names,
            data_reuse=data_reuse,
        )

    @property
    def num_batches(self) -> int:
        """Number of batches, taken from the wrapped ensemble dataset.

        Uses the wrapped object's ``num_batches`` attribute when available and
        falls back to ``len()`` otherwise.
        """
        # Resolve the fallback lazily: ``getattr(obj, "num_batches", len(obj))`` would
        # evaluate ``len(obj)`` even when ``num_batches`` exists (and fail if the
        # wrapped object has no ``__len__``). Catching AttributeError mirrors getattr's
        # default-handling semantics.
        try:
            n = self._wrapped.num_batches
        except AttributeError:
            n = len(self._wrapped)
        return int(n)

    @property
    def batch_size(self) -> int:
        """Batch size reported by the wrapped ensemble dataset."""
        return self._wrapped.batch_size

    def __len__(self) -> int:
        """Number of batches, as reported by ``len()`` of the wrapped ensemble dataset."""
        return len(self._wrapped)

    def __getitem__(self, item: int) -> dict[str, dict[str, Any]]:
        """Return batch ``item`` from the wrapped ensemble dataset, keyed by member name."""
        return self._wrapped[item]

    def on_epoch_end(self):
        """Forward the epoch-end hook to the wrapped dataset if it defines one."""
        if hasattr(self._wrapped, "on_epoch_end"):
            self._wrapped.on_epoch_end()