"""Source code for bayesflow.datasets.ensemble_indexed_dataset."""
from collections.abc import Sequence
from typing import Any
import math
import numpy as np
import keras
from bayesflow.utils import logging
from .helpers import ring_starts, ring_window_indices
# [docs]
class EnsembleIndexedDataset(keras.utils.PyDataset):
    """Serve per-member batches for an ensemble from a single wrapped dataset.

    Each ensemble member is assigned its own window of sample indices into the
    wrapped dataset's global index pool. ``data_reuse`` controls how much the
    member windows overlap: ``1.0`` means every member sees the full dataset
    (identical samples), ``0.0`` means the pool is split into (nearly) disjoint
    windows, one per member; intermediate values interpolate between the two.

    Parameters
    ----------
    dataset : keras.utils.PyDataset
        Wrapped dataset. Must expose ``batch_size``, ``num_samples`` and
        ``get_batch_by_sample_indices``.
    member_names : Sequence[str]
        Unique names of the ensemble members; at least two are required.
    data_reuse : float, optional
        Fraction of data shared between members, in ``[0, 1]``. Default: 1.0.
    **kwargs
        Forwarded to ``keras.utils.PyDataset``.

    Raises
    ------
    ValueError
        If fewer than two member names are given, if names are not unique,
        or if ``data_reuse`` is outside ``[0, 1]``.
    TypeError
        If the wrapped dataset does not expose the required attributes.
    """

    def __init__(
        self,
        dataset: keras.utils.PyDataset,
        member_names: Sequence[str],
        data_reuse: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if len(member_names) < 2:
            raise ValueError("EnsembleIndexedDataset: len(member_names) must be >= 2.")
        # Duplicate names would silently collapse entries in `member_indices`,
        # leaving fewer windows than `ensemble_size` claims — reject early.
        if len(set(member_names)) != len(member_names):
            raise ValueError("EnsembleIndexedDataset: member_names must be unique.")
        if not (0.0 <= data_reuse <= 1.0):
            raise ValueError("EnsembleIndexedDataset: data_reuse must be in [0, 1].")
        for attr in ("batch_size", "num_samples", "get_batch_by_sample_indices"):
            if not hasattr(dataset, attr):
                raise TypeError(f"EnsembleIndexedDataset: wrapped dataset must expose `{attr}`.")

        self.dataset = dataset
        self.member_names = list(member_names)
        self.ensemble_size = len(member_names)
        self.data_reuse = float(data_reuse)
        self.batch_size = int(dataset.batch_size)
        self.num_samples = int(dataset.num_samples)

        # Interpolates between full sharing (factor 1 at data_reuse=1) and a
        # disjoint split (factor 1/ensemble_size at data_reuse=0) of the pool.
        self.reduction_factor = 1 / (data_reuse + (1 - data_reuse) * self.ensemble_size)
        self.window_size = int(math.ceil(self.num_samples * self.reduction_factor))
        self.steps_per_epoch = int(math.ceil(self.window_size / self.batch_size))

        # Member-specific ring windows into the global index pool; each member
        # gets an independent copy so per-member shuffles don't interfere.
        pool = np.arange(self.num_samples, dtype="int64")
        starts = ring_starts(self.num_samples, self.ensemble_size)
        idx2d = ring_window_indices(self.num_samples, self.window_size, starts)  # (E, W)
        self.member_indices = {name: pool[idx2d[k]].copy() for k, name in enumerate(self.member_names)}

        self.on_epoch_end()

        logging.info(
            f"EnsembleIndexedDataset: ensemble_size={self.ensemble_size}, "
            f"batch_size={self.batch_size}, num_samples={self.num_samples}, "
            f"data_reuse={self.data_reuse} -> "
            f"reduction_factor={self.reduction_factor:.2f}, window_size={self.window_size}, "
            f"steps_per_epoch={self.steps_per_epoch}. "
            "Overlap is enforced at the subdataset level (member-specific windows into the global index pool)."
        )

    def __len__(self) -> int:
        """Return the number of batches per epoch."""
        return self.steps_per_epoch

    def on_epoch_end(self):
        """Reshuffle each member's index window at epoch boundaries."""
        if self.data_reuse == 1.0:
            # Full reuse: every member must see identical samples in identical
            # order, so shuffle once and alias that array to all members.
            np.random.shuffle(self.member_indices[self.member_names[0]])
            for name in self.member_names[1:]:
                self.member_indices[name] = self.member_indices[self.member_names[0]]
            return
        # Partial reuse: shuffle each member's window independently.
        for name in self.member_names:
            np.random.shuffle(self.member_indices[name])

    def __getitem__(self, step: int) -> dict[str, dict[str, Any]]:
        """Return one batch per member, flipped to {variable: {member: value}}.

        Raises
        ------
        IndexError
            If `step` is outside ``[0, steps_per_epoch)``.
        """
        if not 0 <= step < self.steps_per_epoch:
            raise IndexError(f"Index {step} is out of bounds for dataset with {self.steps_per_epoch} steps.")
        start = step * self.batch_size
        stop = min((step + 1) * self.batch_size, self.window_size)  # allow shorter last batch
        out: dict[str, dict[str, Any]] = {}
        for name in self.member_names:
            idx = self.member_indices[name][start:stop]
            out[name] = self.dataset.get_batch_by_sample_indices(idx)
        return self._flip_nested_dict(out)

    def _flip_nested_dict(self, d: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
        """Swap nesting levels: {member: {var: value}} -> {var: {member: value}}."""
        flipped: dict[str, dict[str, Any]] = {}
        for member, variables in d.items():
            for var, value in variables.items():
                flipped.setdefault(var, {})[member] = value
        return flipped