from collections.abc import Callable, Mapping, Sequence
from typing import Any
import numpy as np
from scipy.special import logsumexp
import keras
from bayesflow.adapters import Adapter
from bayesflow.simulators import Simulator
from bayesflow.types import Tensor
from bayesflow.utils import logging, filter_kwargs
from bayesflow.utils.serialization import serializable, serialize
from bayesflow.datasets import EnsembleDataset
from .approximator import Approximator
@serializable("bayesflow.approximators")
class EnsembleApproximator(Approximator):
    """Combines multiple approximators into a single ensemble.

    An ``EnsembleApproximator`` wraps a named collection of
    :class:`~bayesflow.approximators.Approximator` instances and trains them
    jointly. At inference time it can produce *per-member* results or *merged*
    results (weighted mixture for :meth:`sample` and :meth:`log_prob`).

    The adapter is inherited from the first member approximator and is assumed
    to be the same across all members (this is **not** enforced).

    Parameters
    ----------
    approximators : Mapping[str, Approximator]
        A mapping from member names to approximator instances. Each member
        is trained on its own slice of the data during :meth:`fit` and is
        addressed by name in :meth:`sample`, :meth:`log_prob`, and
        :meth:`estimate`.
    **kwargs : dict, optional
        Additional arguments forwarded to the
        :class:`~bayesflow.approximators.Approximator` base class.
    """

    def __init__(self, approximators: Mapping[str, Approximator], **kwargs):
        super().__init__(**kwargs)
        self._warn_if_shared_approximator_components(approximators)
        self.approximators = approximators
        # Member names, cached by capability for quick resolution at inference time.
        self.members = tuple(self.approximators.keys())
        self.distribution_members = tuple(
            k for k, a in self.approximators.items() if getattr(a, "has_distribution", False)
        )
        self.estimate_members = tuple(k for k, a in self.approximators.items() if hasattr(a, "estimate"))
        self.has_distribution = bool(self.distribution_members)

    @classmethod
    def _warn_if_shared_approximator_components(cls, approximators):
        """Warn if approximators share component instances (not safely serializable yet)."""
        tracked = ("inference_network", "summary_network")
        # Map component attribute -> {id(component): [member names referencing that instance]}.
        seen = {name: {} for name in tracked}
        for member_name, approximator in approximators.items():
            for attr in tracked:
                obj = getattr(approximator, attr, None)
                if obj is None:
                    continue
                seen[attr].setdefault(id(obj), []).append(member_name)
        # Emit one warning per shared object instance. The message is pre-formatted
        # here so it is correct regardless of how the logging wrapper treats kwargs.
        for attr, by_id in seen.items():
            for members in by_id.values():
                if len(members) > 1:
                    logging.warning(
                        f"EnsembleApproximator contains shared component '{attr}' across members {members}. "
                        "Deserialization of weights of shared components is not supported yet and may fail. "
                        "Use separate component instances (e.g., clone networks) to be able to serialize "
                        "the whole EnsembleApproximator object or serialize the approximators in the ensemble "
                        "separately."
                    )

    @classmethod
    def _warn_ignored_member_weights(cls, member_weights: Mapping[str, float] | None, merge_members: bool):
        """Warn when `member_weights` is supplied but cannot take effect."""
        if member_weights is not None and not merge_members:
            logging.warning(
                "`member_weights` is ignored when `merge_members=False`. "
                "Set `merge_members=True` to use a weighted mixture."
            )

    @property
    def adapter(self) -> Adapter:
        # Defer to any adapter of the approximators,
        # assuming all are the same, which is not enforced at the moment.
        # self.adapter will only be used when super().fit calls build_dataset(..., adapter=self.adapter).
        return next(iter(self.approximators.values())).adapter

    def build_dataset(
        self,
        *,
        batch_size: int = "auto",
        num_batches: int,
        adapter: Adapter = "auto",
        memory_budget: str | int = "auto",
        simulator: Simulator,
        workers: int = "auto",
        use_multiprocessing: bool = False,
        max_queue_size: int = 32,
        **kwargs,
    ) -> EnsembleDataset:
        """Build the base dataset and wrap it into an :class:`EnsembleDataset`.

        The wrapper splits each batch across ensemble members by name so that
        every member trains on its own slice of the data.
        """
        base_dataset = super().build_dataset(
            batch_size=batch_size,
            num_batches=num_batches,
            adapter=adapter,
            memory_budget=memory_budget,
            simulator=simulator,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            **kwargs,
        )
        return EnsembleDataset(
            base_dataset=base_dataset,
            member_names=self.members,
            # Forward only the kwargs PyDataset understands (workers, queue size, ...).
            **filter_kwargs(kwargs, keras.utils.PyDataset.__init__),
        )

    def build(self, data_shapes: dict) -> None:
        """Build every member approximator from (possibly member-nested) data shapes."""
        for approx_name, approximator in self.approximators.items():
            _data_shape = {}
            for var_name, variable in data_shapes.items():  # variable type
                # If data_shapes has a nested ensemble level, select shapes with approx_name.
                # Note, summary_variables might be dict, if a FusionNetwork is used.
                # Thus, we further check whether the approx_name is in the keys.
                if isinstance(variable, dict) and approx_name in variable.keys():
                    _data_shape[var_name] = variable[approx_name]
                else:
                    _data_shape[var_name] = variable
            approximator.build(_data_shape)

    def fit(self, *args, **kwargs) -> keras.callbacks.History:
        """
        Trains the ensemble of approximators on the provided dataset or on-demand data generated
        from the given simulator.

        If `dataset` is not provided, a dataset is built from the `simulator`.
        If the model has not been built, it will be built using a batch from the dataset.
        If `dataset` is `OnlineDataset`, `OfflineDataset` or `DiskDataset`,
        it will be wrapped into an `EnsembleDataset`.

        Parameters
        ----------
        dataset : keras.utils.PyDataset, optional
            A dataset containing simulations for training. If provided, `simulator` must be None.
        simulator : Simulator, optional
            A simulator used to generate a dataset. If provided, `dataset` must be None.
        **kwargs
            Additional keyword arguments passed to `keras.Model.fit()` and to the dataset constructor
            if `dataset` is not provided.

        Returns
        -------
        keras.callbacks.History
            A history object containing the training loss and metrics values.

        Raises
        ------
        ValueError
            If both `dataset` and `simulator` are provided or neither is provided.
        """
        # Pass the shared adapter explicitly so build_dataset uses it for all members.
        return super().fit(*args, **kwargs, adapter=self.adapter)

    def compute_metrics(
        self,
        inference_variables: dict[str, Tensor] | Tensor,
        inference_conditions: dict[str, Tensor] | Tensor | None = None,
        summary_variables: dict[str, Tensor] | Tensor | None = None,
        sample_weight: dict[str, Tensor] | Tensor | None = None,
        summary_attention_mask: dict[str, Tensor] | Tensor | None = None,
        summary_mask: dict[str, Tensor] | Tensor | None = None,
        inference_attention_mask: dict[str, Tensor] | Tensor | None = None,
        inference_mask: dict[str, Tensor] | Tensor | None = None,
        stage: str = "training",
    ) -> dict[str, dict[str, Tensor]]:
        """Compute per-member metrics and aggregate them under prefixed keys.

        During training each input is a dict nested by member name and the
        member-specific slice is selected; in other stages the inputs are
        passed through unchanged. The total ``"loss"`` is the sum of all
        member metrics whose key contains ``"loss"``.
        """
        metrics = {}

        def select(value, name):
            # Training batches come from EnsembleDataset and are nested per member;
            # outside training the value is shared across members.
            if value is None:
                return None
            return value[name] if stage == "training" else value

        for name, approximator in self.approximators.items():
            metrics[name] = approximator.compute_metrics(
                inference_variables=select(inference_variables, name),
                inference_conditions=select(inference_conditions, name),
                summary_variables=select(summary_variables, name),
                sample_weight=select(sample_weight, name),
                summary_attention_mask=select(summary_attention_mask, name),
                summary_mask=select(summary_mask, name),
                inference_attention_mask=select(inference_attention_mask, name),
                inference_mask=select(inference_mask, name),
                stage=stage,
            )
        # Flatten to "<member>/<metric>" keys so Keras can log them side by side.
        metrics = {
            f"{approx_name}/{metric_key}": value
            for approx_name, approx_metrics in metrics.items()
            for metric_key, value in approx_metrics.items()
        }
        losses = [v for k, v in metrics.items() if "loss" in k]
        metrics["loss"] = keras.ops.sum(losses)
        return metrics

    def sample(
        self,
        *,
        num_samples: int,
        conditions: Mapping[str, np.ndarray],
        split: bool = False,
        member_weights: Mapping[str, float] | None = None,
        merge_members: bool = True,
        **kwargs,
    ) -> dict[str, np.ndarray]:
        """
        Draw samples from the marginalized distribution over ensemble members.

        Samples are allocated to approximators via multinomial sampling using member_weights,
        then concatenated and shuffled to produce the marginal distribution.

        Parameters
        ----------
        num_samples : int
            Total number of samples to draw.
        conditions : Mapping[str, np.ndarray]
            Conditions for sampling.
        split : bool, optional
            Whether to split output arrays, by default False.
        member_weights : Mapping[str, float] or None, optional
            Probability weights for each approximator. If None, uses uniform weights.
            Must be nonnegative, will be normalized to sum to 1.
        merge_members : bool, optional
            Whether to merge samples from all approximators into a single (weighted) marginal sample.
        **kwargs
            Additional arguments passed to approximator.sample().

        Returns
        -------
        dict[str, np.ndarray]
            Samples with shape (batch_size, num_samples, ...) for each variable.
        """
        self._warn_ignored_member_weights(member_weights, merge_members)
        if not merge_members:
            # Per-member mode: every distribution member draws the full num_samples.
            return self._map_members(
                None,
                capability="distribution",
                fn=lambda name, a: a.sample(num_samples=num_samples, conditions=conditions, split=split, **kwargs),
            )
        weights = self._resolve_member_weights(member_weights)
        names = tuple(weights.keys())
        probs = np.fromiter(weights.values(), dtype=float, count=len(weights))
        # Allocate the total sample budget across members by multinomial draw;
        # members that receive zero samples are skipped entirely.
        counts = np.random.multinomial(num_samples, probs)
        alloc = {name: int(count) for name, count in zip(names, counts) if count > 0}
        per_member = self._map_members(
            list(alloc.keys()),
            capability="distribution",
            fn=lambda name, a: a.sample(num_samples=alloc[name], conditions=conditions, split=split, **kwargs),
        )
        # Concatenate along the sample axis, then shuffle so member identity is not
        # recoverable from sample order.
        merged = keras.tree.map_structure(lambda *xs: np.concatenate(xs, axis=1), *list(per_member.values()))
        idx = np.random.permutation(num_samples)
        return keras.tree.map_structure(lambda a: np.take(a, idx, axis=1), merged)

    def log_prob(
        self,
        data: Mapping[str, np.ndarray],
        member_weights: Mapping[str, float] | None = None,
        merge_members: bool = True,
        **kwargs,
    ) -> np.ndarray:
        """
        Compute the marginalized log probability over ensemble members.

        Uses log-sum-exp trick to compute log p(x) = log(sum_i w_i * p_i(x)).

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Data containing inference variables and conditions.
        member_weights : Mapping[str, float] or None, default None
            Probability weights for each approximator. If None, uses uniform weights.
            Must be nonnegative, will be normalized to sum to 1.
        merge_members : bool, optional
            Whether to merge log probabilities from all approximators into a single marginal log probability.
        **kwargs
            Additional arguments passed to approximator.log_prob().

        Returns
        -------
        np.ndarray
            Marginalized log probabilities with shape (batch_size,).
        """
        self._warn_ignored_member_weights(member_weights, merge_members)
        if not merge_members:
            return self._map_members(
                None,
                capability="distribution",
                fn=lambda name, a: a.log_prob(data=data, **kwargs),
            )
        weights = self._resolve_member_weights(member_weights)
        members = list(weights.keys())
        log_probs = self._map_members(
            members,
            capability="distribution",
            fn=lambda name, a: a.log_prob(data=data, **kwargs),
        )
        # Stack member log-densities on a trailing axis and mix in log space:
        # log p(x) = logsumexp_i(log w_i + log p_i(x)).
        stacked = np.stack([log_probs[m] for m in members], axis=-1)
        log_w = np.log(np.fromiter((weights[m] for m in members), dtype=float, count=len(members)))
        return logsumexp(stacked + log_w, axis=-1)

    def estimate(
        self,
        conditions: Mapping[str, np.ndarray],
        *,
        members: Sequence[str] | None = None,
        split: bool = False,
        groupby: str = "member",
        **kwargs,
    ) -> dict:
        """
        Compute point estimates and distribution parameters from each approximator separately.

        Parameters
        ----------
        conditions : Mapping[str, np.ndarray]
            Conditions for estimation.
        members : Sequence[str] or None, default None
            Ensemble members to estimate with.
            If None, will estimate with all members that have an `estimate` method.
        split : bool, optional
            Whether to split output arrays, by default False.
        groupby : {"member", "variable"}, default "member"
            Controls the top-level nesting of the returned dictionary.
            - "member": return estimates as ``member -> variable -> score (-> head) -> array``.
            - "variable": return estimates as ``variable -> score (-> head) -> member -> array``.
            See also :py:meth:`~bayesflow.ScoringRuleApproximator.estimate`.
        **kwargs
            Additional arguments passed to approximator.estimate().

        Returns
        -------
        dict[str, dict[str, dict[str, np.ndarray]]]
            Estimates keyed by approximator name, then by variable and score names.

        Raises
        ------
        ValueError
            If members disagree on whether a score has a head level.
        NotImplementedError
            If `groupby` is neither "member" nor "variable".
        """
        estimates = self._map_members(
            members,
            capability="estimate",
            fn=lambda name, a: a.estimate(conditions=conditions, split=split, **kwargs),
        )
        if groupby == "member":
            return estimates
        elif groupby == "variable":
            # Re-nest from member -> variable -> score (-> head) to
            # variable -> score (-> head) -> member, validating that all members
            # agree on whether a score carries a head level.
            out = {}
            for member_key, member_est in estimates.items():
                for var_key, var_est in member_est.items():
                    out.setdefault(var_key, {})
                    for score_key, score_val in var_est.items():
                        # score has heads -> dict[head] = array
                        if isinstance(score_val, dict):
                            node = out[var_key].setdefault(score_key, {})
                            if not isinstance(node, dict):
                                raise ValueError(
                                    f"Inconsistent estimate structure for variable={var_key!r}, score={score_key!r}: "
                                    "some members return a dict of heads, others return an array."
                                )
                            for head_key, arr in score_val.items():
                                node.setdefault(head_key, {})
                                node[head_key][member_key] = arr
                        # score is already an array (no head level / squeezed)
                        else:
                            node = out[var_key].setdefault(score_key, {})
                            if isinstance(node, dict) and node and any(isinstance(v, dict) for v in node.values()):
                                raise ValueError(
                                    f"Inconsistent estimate structure for variable={var_key!r}, score={score_key!r}: "
                                    "some members return an array, others return a dict of heads."
                                )
                            # keep head level absent; attach member at score level
                            if not isinstance(node, dict):
                                # should not happen, but keep it safe
                                out[var_key][score_key] = {}
                                node = out[var_key][score_key]
                            node[member_key] = score_val
            return out
        else:
            raise NotImplementedError(
                f"`groupby={groupby!r}` is not supported for EnsembleApproximator. Use 'member' or 'variable'."
            )

    def _map_members(
        self,
        members: Sequence[str] | None,
        *,
        capability: str,
        fn: Callable,
    ) -> dict[str, Any]:
        """Apply `fn(name, approximator)` to each resolved member, keyed by name."""
        resolved = self._resolve_members(members, capability=capability)
        return {name: fn(name, self.approximators[name]) for name in resolved}

    def _resolve_members(self, members: Sequence[str] | None, *, capability: str) -> tuple[str, ...]:
        """Resolve a member selection against the members supporting `capability`.

        Returns all capable members when `members` is None; otherwise validates
        the requested names and returns them in the requested order.
        """
        if capability == "any":
            base = self.members
        elif capability == "distribution":
            base = self.distribution_members
        elif capability == "estimate":
            base = self.estimate_members
        else:
            raise ValueError(f"Unknown capability {capability!r}")
        if members is None:
            return base
        base_set = set(base)
        members_t = tuple(members)
        unknown = [m for m in members_t if m not in base_set]
        if unknown:
            raise ValueError(f"Unknown/unsupported members for capability={capability!r}: {unknown}")
        return members_t

    def _resolve_member_weights(self, member_weights: Mapping[str, float] | None) -> Mapping[str, float]:
        """Validate `member_weights` and normalize them to sum to 1.

        Defaults to uniform weights over all distribution members. Weights must
        be nonnegative with a positive sum, and keys must be a subset of
        `self.distribution_members`.
        """
        if member_weights is None:
            member_weights = {k: 1.0 for k in self.distribution_members}
        unknown = set(member_weights) - set(self.distribution_members)
        if unknown:
            raise ValueError(
                f"Member weights must be subset of self.distribution_members. Unknown keys: {unknown}"
            )
        for key, weight in member_weights.items():
            if weight < 0:
                raise ValueError(f"All member_weights must be nonnegative. Received {key}: {weight}.")
        # Normalize weights to 1; reject an all-zero weighting explicitly instead
        # of failing with a ZeroDivisionError below.
        summed = np.sum(list(member_weights.values()))
        if summed <= 0:
            raise ValueError("member_weights must sum to a positive value.")
        member_weights = {k: v / summed for k, v in member_weights.items()}
        return member_weights

    def _batch_size_from_data(self, data: Mapping[str, Mapping[str, Any]]) -> int:
        """
        Fetches the current batch size from an input dictionary. Can only be used during training when
        inference variables as present.
        """
        if isinstance(data["inference_variables"], dict):
            # Nested per-member structure: read the batch dim from the first member.
            first_member = next(iter(self.approximators))
            return keras.ops.shape(data["inference_variables"][first_member])[0]
        return keras.ops.shape(data["inference_variables"])[0]

    def get_config(self):
        """Serialize the ensemble configuration (member approximators) for Keras."""
        base_config = super().get_config()
        config = {"approximators": self.approximators}
        return base_config | serialize(config)

    def build_from_config(self, config):
        # the approximators are already built
        pass