# Source code for bayesflow.datasets.ensemble_online_dataset

from collections.abc import Mapping, Sequence
from typing import Any

import math
import numpy as np
import keras

from bayesflow.utils import logging

from .helpers import apply_augmentations
from .helpers import ring_starts, ring_window_indices


class EnsembleOnlineDataset(keras.utils.PyDataset):
    """Wrap an online-simulation dataset so each ensemble member gets its own batch view.

    Per training step, one pooled batch is drawn from the wrapped dataset's
    simulator and split into (possibly overlapping) per-member windows, so that
    ensemble members share a controllable fraction of the simulated data
    instead of each triggering an independent simulation.

    Parameters
    ----------
    dataset : keras.utils.PyDataset
        Wrapped online dataset; must expose ``simulator``, ``batch_size``,
        ``num_batches``, ``augmentations`` and ``adapter`` (checked at init).
    member_names : Sequence[str]
        Names of the ensemble members; at least two are required.
    data_reuse : float
        Fraction of data shared between members, in [0, 1].
        1.0 -> all members see the identical batch (no extra simulation);
        0.0 -> fully disjoint windows (pool is ensemble_size * batch_size).
    **kwargs
        Forwarded to ``keras.utils.PyDataset``.

    Raises
    ------
    ValueError
        If fewer than two member names are given or ``data_reuse`` is outside [0, 1].
    TypeError
        If the wrapped dataset lacks one of the required attributes.
    """

    def __init__(
        self,
        dataset: keras.utils.PyDataset,
        member_names: Sequence[str],
        data_reuse: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if len(member_names) < 2:
            raise ValueError("EnsembleOnlineDataset: len(member_names) must be >= 2.")
        if not (0.0 <= data_reuse <= 1.0):
            raise ValueError("EnsembleOnlineDataset: data_reuse must be in [0, 1].")
        # Duck-type check: fail early with a clear message rather than deep
        # inside __getitem__ when an attribute is first touched.
        for attr in ("simulator", "batch_size", "num_batches", "augmentations", "adapter"):
            if not hasattr(dataset, attr):
                raise TypeError(f"EnsembleOnlineDataset: wrapped dataset must expose `{attr}`.")
        self.dataset = dataset
        self.member_names = list(member_names)
        self.ensemble_size = len(self.member_names)
        self.data_reuse = data_reuse
        self.batch_size = int(dataset.batch_size)
        self._num_batches = dataset.num_batches
        # reduction_factor = batch_size / pool_size:
        #   data_reuse=1 -> 1 (single shared batch);
        #   data_reuse=0 -> 1/ensemble_size (fully disjoint windows).
        self.reduction_factor = 1 / (data_reuse + (1 - data_reuse) * self.ensemble_size)
        # Pool drawn per step; ceil so every member window of batch_size fits.
        self.pool_size = int(math.ceil(self.batch_size / self.reduction_factor))
        logging.info(
            f"EnsembleOnlineDataset: ensemble_size={self.ensemble_size}, "
            f"batch_size={self.batch_size}, data_reuse={self.data_reuse} -> "
            f"reduction_factor={self.reduction_factor:.1f}, "
            f"pool_size={self.pool_size} (≈{1 / self.reduction_factor:.1f}*batch_size).\n"
            "Overlap is enforced per training step by splitting a pooled simulated batch into member windows."
        )

    @property
    def num_batches(self):
        # Snapshot taken from the wrapped dataset at construction time.
        return self._num_batches

    def __len__(self) -> int:
        return self.num_batches

    def __getitem__(self, item: int) -> dict[str, dict[str, Any]]:
        """Simulate one pooled batch and return ``{data_key: {member_name: array}}``.

        ``item`` is ignored (online simulation draws fresh data every call).
        """
        if self.data_reuse == 1.0:
            # Full reuse: one batch, every member receives the same arrays.
            batch = self.dataset.simulator.sample(self.batch_size)
            batch = self._postprocess(batch)
            return self._replicate(batch)
        pool = self.dataset.simulator.sample(self.pool_size)
        pool = self._postprocess(pool)
        # Helper-computed overlapping windows on a ring over the pool;
        # presumably starts are evenly spaced so member overlap matches
        # data_reuse — semantics live in .helpers (verify there).
        starts = ring_starts(self.pool_size, self.ensemble_size)
        idx2d = ring_window_indices(self.pool_size, self.batch_size, starts)
        out = {}
        for k, name in enumerate(self.member_names):
            out[name] = self._take(pool, idx2d[k])
        # Flip {member: {key: val}} -> {key: {member: val}} for the trainer.
        return self._flip_nested_dict(out)

    def _flip_nested_dict(self, d: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
        """Swap the two nesting levels: ``d[a][b]`` becomes ``flipped[b][a]``."""
        flipped = {}
        for key, val in d.items():
            for subkey, subval in val.items():
                flipped.setdefault(subkey, {})
                flipped[subkey][key] = subval
        return flipped

    def _postprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Apply the wrapped dataset's augmentations, then its adapter (if any)."""
        batch = apply_augmentations(batch, self.dataset.augmentations)
        if self.dataset.adapter is not None:
            batch = self.dataset.adapter(batch)
        return batch

    @staticmethod
    def _take(batch: Mapping[str, Any], idx: np.ndarray) -> dict[str, Any]:
        """Select rows ``idx`` (axis 0) from every ndarray value; pass non-arrays through."""
        out = {}
        for k, v in batch.items():
            out[k] = np.take(v, idx, axis=0) if isinstance(v, np.ndarray) else v
        return out

    def _replicate(self, batch: Mapping[str, Any]) -> dict[str, dict[str, Any]]:
        """Map every member name to the *same* value objects (no copies) per data key."""
        out = {}
        for key, value in batch.items():
            out[key] = {}
            for name in self.member_names:
                # Intentionally shared references: full reuse means identical data.
                out[key][name] = value
        return out