from collections.abc import Mapping, Sequence
import numpy as np
from scipy.special import logsumexp
import keras
from bayesflow.adapters import Adapter
from bayesflow.networks import ScoringRuleNetwork, SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import (
logging,
split_arrays,
squeeze_inner_estimates_dict,
)
from bayesflow.utils.serialization import serializable
from .continuous_approximator import ContinuousApproximator
@serializable("bayesflow.approximators")
class ScoringRuleApproximator(ContinuousApproximator):
    """
    A workflow for fast amortized Bayes risk minimization for arbitrary scoring rules.

    Inherits from :class:`ContinuousApproximator` and adapts the sample, log_prob, and estimate
    interfaces for the nested output structure of :class:`~bayesflow.networks.ScoringRuleNetwork`.

    Parameters
    ----------
    inference_network : InferenceNetwork
        The inference network used for point estimation.
    adapter : bayesflow.adapters.Adapter, optional
        Adapter for data processing. You can use :py:meth:`build_adapter`
        to create it. If ``None`` (default), an identity adapter is used
        that makes a shallow copy and passes data through unchanged.
    summary_network : SummaryNetwork, optional
        The summary network used for data summarization (default is None).
    standardize : str | Sequence[str] | None
        The variables to standardize before passing to the networks. Can be either
        "all" or any subset of ["inference_variables", "summary_variables", "inference_conditions"].
        (default is "inference_variables").
    **kwargs : dict, optional
        Additional arguments passed to the :py:class:`ContinuousApproximator` class.
    """

    def __init__(
        self,
        *,
        inference_network: ScoringRuleNetwork,
        adapter: Adapter = None,
        summary_network: SummaryNetwork | None = None,
        standardize: str | Sequence[str] | None = "inference_variables",
        **kwargs,
    ):
        super().__init__(
            adapter=adapter,
            inference_network=inference_network,
            summary_network=summary_network,
            standardize=standardize,
            **kwargs,
        )
        # A scoring rule induces a parametric distribution only if it exposes
        # both `sample` and `log_prob` as callables.
        self.distribution_keys = [
            score_key
            for score_key, rule in self.inference_network.scoring_rules.items()
            if callable(getattr(rule, "sample", None)) and callable(getattr(rule, "log_prob", None))
        ]
        self.has_distribution = bool(self.distribution_keys)
def estimate(
self,
conditions: Mapping[str, np.ndarray] | None = None,
split: bool = False,
groupby: str = "variable",
**kwargs,
) -> dict[str, dict[str, np.ndarray | dict[str, np.ndarray]]]:
"""
Estimate point summaries and distributional parameters of inference variables induced by
scoring rules.
Parameters
----------
conditions : Mapping[str, np.ndarray]
A batch of conditions to estimate. None for unconditional distributions.
split : bool, optional
If True, split estimated arrays along the last axis, by default False.
groupby : {"variable", "score"}, default "variable"
Controls the top-level nesting of the returned dictionary.
- "variable": return estimates as ``variable -> score -> head -> array``.
If a score has just one head, which is called value, squeeze the unnecessary nesting.
- "score": return estimates as ``score -> head -> variable -> array``.
**kwargs
Additional keyword arguments passed to underlying processing functions.
Returns
-------
dict
Estimates in the requested layout. Leafs have shape
``(num_conditions, point_estimate_size, variable_block_size)`` or
``(point_estimate_size, variable_block_size)`` if ``conditions=None``.
"""
estimates = self._estimate_byscore(conditions=conditions, split=split, **kwargs)
if groupby == "score":
return estimates
# Reorder the nested dictionary so that original variable names are at the top.
estimates = self._reorder_estimates(estimates)
estimates = self._squeeze_estimates(estimates)
return estimates
def _estimate_byscore(
self,
conditions: Mapping[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
resolved_conditions, adapted, _ = self._prepare_conditions(conditions)
inference_kwargs = kwargs | self._collect_mask_kwargs(self._INFERENCE_MASK_KEYS, adapted)
estimates = self.inference_network(
conditions=resolved_conditions,
**inference_kwargs,
)
# Unstandardize the network outputs
for score_key, score in self.inference_network.scoring_rules.items():
for head_key in estimates[score_key].keys():
transformation_type = score.TRANSFORMATION_TYPE.get(head_key, "location_scale")
estimates[score_key][head_key] = self.standardizer.maybe_standardize(
estimates[score_key][head_key],
key="inference_variables",
stage="inference",
forward=False,
transformation_type=transformation_type,
)
estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
if split:
estimates = split_arrays(estimates, axis=-1)
return estimates
    def sample(
        self,
        *,
        num_samples: int,
        conditions: Mapping[str, np.ndarray] | None = None,
        split: bool = False,
        score_weights: Mapping[str, float] | None = None,
        merge_scores: bool = True,
        **kwargs,
    ) -> dict[str, np.ndarray] | dict[str, dict[str, np.ndarray]]:
        """
        Draw samples from the parametric distributions induced by the configured scoring rules.

        This method supports two modes:

        1) ``merge_scores=True`` (default):
           Samples are drawn from the *mixture* over scoring rules.
           The requested ``num_samples`` are allocated across scores according to ``score_weights``
           (uniform if None), drawn for each score, and merged into a single shuffled set.
        2) ``merge_scores=False``:
           Samples are drawn *separately* for each scoring rule, returning a nested dictionary
           keyed by score name. In this mode, ``num_samples`` samples are generated per score.

        Parameters
        ----------
        num_samples : int
            Number of samples to draw. If ``merge_scores=True``, this is the total number of
            mixture samples returned. If ``merge_scores=False``, this is the number of samples
            generated per scoring rule.
        conditions : Mapping[str, np.ndarray] or None,
            A batch of conditions to sample for. None for unconditional distributions.
        split : bool, optional
            Whether to split the output arrays along the last axis. Delegated to :meth:`sample_separate`.
        score_weights : Mapping[str, float] or None, default None
            Probability weights for each scoring rule. If ``None``, uniform weights are assumed.
            Must be positive, will be normalized to sum to 1. Only used when ``merge_scores=True``.
        merge_scores : bool, default True
            If True, return samples aggregated across scoring rules as a mixture.
            If False, return samples separately for each scoring rule.
        **kwargs
            Additional keyword arguments such as ``batch_size``.

        Returns
        -------
        samples : dict[str, np.ndarray] or dict[str, dict[str, np.ndarray]]
            If ``merge_scores=True``:
                A dictionary keyed by inference variable name. Entries have shape
                ``(num_datasets, num_samples, ...)``.
            If ``merge_scores=False``:
                A nested dictionary where the first-level key is the score name and the second-level
                key is the inference variable name. Each leaf array has shape
                ``(num_datasets, num_samples, ...)``.

        Raises
        ------
        NotImplementedError
            If ``split=True`` is not supported by the approximator implementation.

        Notes
        -----
        The mixture path uses the *global* NumPy RNG (``np.random.multinomial`` and
        ``np.random.permutation``), so results depend on the global seed.
        """
        self._check_has_distribution()
        if not merge_scores:
            if score_weights is not None:
                logging.warning(
                    "`score_weights` is ignored when `merge_scores=False`. "
                    "Set `merge_scores=True` to sample from the weighted mixture."
                )
            return self._sample_separate(num_samples=num_samples, conditions=conditions, split=split, **kwargs)
        # Validates keys/positivity and normalizes the weights to sum to 1.
        score_weights = self._resolve_score_weights(score_weights)
        # Single score: _sample_separate already squeezed to a plain result,
        # and mixing with uniform weight is an identity operation.
        if len(self.distribution_keys) == 1:
            return self._sample_separate(num_samples=num_samples, conditions=conditions, split=split, **kwargs)
        # Allocate samples per score and draw only as many as needed (max over scores).
        num_samples_per_score = np.random.multinomial(num_samples, list(score_weights.values()))
        max_k = int(np.max(num_samples_per_score))
        # Re-key the multinomial counts by score name (order follows score_weights).
        num_samples_per_score = {k: num_samples_per_score[i] for i, k in enumerate(score_weights.keys())}
        # One network pass drawing max_k samples per score; cropped per score below.
        samples_by_score = self._sample_separate(num_samples=max_k, conditions=conditions, split=split, **kwargs)
        # Crop each score's samples down to its allocated k
        cropped_list = []
        for score_key, k in num_samples_per_score.items():
            if k == 0:
                continue
            # Axis 1 is the sample axis; map_structure applies the crop eagerly,
            # so closing over the loop variable `k` here is safe.
            cropped = keras.tree.map_structure(lambda arr: arr[:, :k], samples_by_score[score_key])
            cropped_list.append(cropped)
        # Concatenate across scores along the sample axis.
        concatenated = keras.tree.map_structure(
            lambda *arrays: np.concatenate(arrays, axis=1),
            *cropped_list,
        )
        # Shuffle along the sample axis (1) to form the mixture samples.
        shuffle_idx = np.random.permutation(num_samples)
        shuffled = keras.tree.map_structure(lambda arr: np.take(arr, shuffle_idx, axis=1), concatenated)
        return shuffled
def _sample_separate(
self,
*,
num_samples: int,
conditions: Mapping[str, np.ndarray] | None = None,
split: bool = False,
batch_size: int | None = None,
**kwargs,
) -> dict[str, np.ndarray | dict[str, np.ndarray]]:
"""
Draws samples from a parametric distribution based on point estimates.
Uses :meth:`ContinuousApproximator.sample` for condition resolution, sampling,
unstandardization, and inverse adapter, then squeezes the nested score-major
output structure.
Parameters
----------
num_samples : int
The number of samples to generate.
conditions : Mapping[str, np.ndarray]
A dictionary mapping variable names to arrays representing the conditions.
split : bool, optional
If True, the sampled arrays are split along the last axis, by default False.
Currently not supported for :py:class:`ScoringRuleApproximator`.
batch_size : int or None, optional
If provided, the conditions are split into batches of size `batch_size`,
for which samples are generated sequentially.
**kwargs
Additional keyword arguments passed to underlying processing functions.
Returns
-------
samples : dict[str, dict[str, np.ndarray]]
Samples for all inference variables and all parametric scoring rules in a nested dictionary.
Shape: (num_datasets, num_samples, variable_block_size).
"""
if split:
raise NotImplementedError("split=True is currently not supported for `ScoringRuleApproximator`.")
self._check_has_distribution()
# Delegate to parent for condition resolution, sampling, unstandardization, and inverse adapter
samples = super().sample(
num_samples=num_samples,
conditions=conditions,
split=False,
batch_size=batch_size,
**kwargs,
)
return self._squeeze_parametric_score_major_dict(samples)
def log_prob(
self,
data: Mapping[str, np.ndarray],
score_weights: Mapping[str, float] | None = None,
merge_scores: bool = True,
**kwargs,
) -> np.ndarray | dict[str, np.ndarray]:
"""
Compute log-probabilities under the parametric distribution(s) induced by the scoring rules.
This method supports two modes:
1) ``merge_scores=True`` (default):
Return the marginalized (mixture) log-probability across scoring rules using ``score_weights``:
``log p(x) = log sum_s w_s p_s(x)``.
2) ``merge_scores=False``:
Return log-probabilities separately for each scoring rule as a dictionary keyed by score name.
In this mode, ``score_weights`` is ignored.
Parameters
----------
data : Mapping[str, np.ndarray]
Dictionary containing inference variables and conditions.
score_weights : Mapping[str, float] or None, default None
Probability weights for each scoring rule. If ``None``, uniform weights are assumed.
Must be positive, will be normalized to sum to 1. Only used when ``merge_scores=True``.
merge_scores : bool, default True
If True, return marginalized log-probabilities across scoring rules.
If False, return per-score log-probabilities.
**kwargs
Additional keyword arguments.
Returns
-------
np.ndarray or dict[str, np.ndarray]
If ``merge_scores=True``: array of shape ``(num_datasets,)`` with marginalized log-probabilities.
If ``merge_scores=False``: dictionary mapping score name to an array of shape ``(num_datasets,)``.
"""
log_probs = self._log_prob_separate(data=data, **kwargs)
if not merge_scores:
if score_weights is not None:
logging.warning(
"`score_weights` is ignored when `merge_scores=False`. "
"Set `merge_scores=True` to compute the weighted mixture log-probability."
)
return log_probs
# Single score: _log_prob_separate already squeezed to a plain array,
# and merging with uniform weight is an identity operation.
if len(self.distribution_keys) == 1:
return np.asarray(log_probs)
score_weights = self._resolve_score_weights(score_weights)
stacked = np.stack([np.asarray(log_probs[score_key]) for score_key in self.distribution_keys], axis=-1)
log_weights = np.log(list(score_weights.values()))
# stacked: (num_datasets, num_scores), log_weights: (num_scores,)
z = stacked + log_weights # broadcasted to (num_datasets, num_scores)
# stable logsumexp over last axis
return logsumexp(z, axis=-1)
def _log_prob_separate(self, data: Mapping[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
"""
Computes the log-probability of given data under the parametric distribution(s).
Parameters
----------
data : dict[str, np.ndarray]
A dictionary mapping variable names to arrays representing the data.
**kwargs
Additional keyword arguments passed to underlying processing functions.
Returns
-------
log_prob : np.ndarray or dict[str, np.ndarray]
Log-probabilities of the distribution for all parametric scoring rules.
If only one parametric score is available, returns an array.
Otherwise, returns a dictionary with score names as keys.
Shape: (num_datasets,)
"""
log_prob = super().log_prob(data, **kwargs)
return self._squeeze_parametric_score_major_dict(log_prob)
def _check_has_distribution(self):
if not self.has_distribution:
raise ValueError("No parametric distribution scores available for sample/log_prob.")
def _resolve_score_weights(
self,
score_weights: Mapping[str, float] | None,
) -> np.ndarray:
if score_weights is None:
score_weights = {k: 1.0 for k in self.distribution_keys}
for key, weight in score_weights.items():
if key not in self.distribution_keys:
raise ValueError(
"Score weights must be subset of self.distribution_keys. "
f"Unknown keys: {set(score_weights) - set(self.distribution_keys)}"
)
if weight < 0:
raise ValueError(f"All score_weights must be positive. Received {key}: {weight}.")
# Normalize weights to 1
sum = np.sum(list(score_weights.values()))
score_weights = {k: v / sum for k, v in score_weights.items()}
return score_weights
def _apply_inverse_adapter_to_estimates(
self, estimates: Mapping[str, Mapping[str, Tensor]], **kwargs
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
"""Applies the inverse adapter on each inner element of the _estimate output dictionary."""
estimates = keras.tree.map_structure(keras.ops.convert_to_numpy, estimates)
processed = {}
for score_key, score_val in estimates.items():
processed[score_key] = {}
for head_key, estimate in score_val.items():
if head_key in self.inference_network.scoring_rules[score_key].NOT_TRANSFORMING_LIKE_VECTOR_WARNING:
logging.warning(
f"Estimate '{score_key}.{head_key}' is marked to not transform like a vector. "
f"It was treated like a vector by the adapter. Handle '{head_key}' estimates with care."
)
adapted = self.adapter(
{"inference_variables": estimate},
inverse=True,
strict=False,
**kwargs,
)
processed[score_key][head_key] = adapted
return processed
@staticmethod
def _reorder_estimates(
estimates: Mapping[str, Mapping[str, Mapping[str, np.ndarray]]],
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
"""Reorders the nested dictionary so that the inference variable names become the top-level keys."""
sample_inner = next(iter(next(iter(estimates.values())).values()))
variable_names = sample_inner.keys()
reordered = {}
for variable in variable_names:
reordered[variable] = {}
for score_key, inner_dict in estimates.items():
reordered[variable][score_key] = {inner_key: value[variable] for inner_key, value in inner_dict.items()}
return reordered
@staticmethod
def _squeeze_estimates(
estimates: Mapping[str, Mapping[str, Mapping[str, np.ndarray]]],
) -> dict[str, dict[str, np.ndarray]]:
"""Squeezes each inner estimate dictionary to remove unnecessary nesting."""
squeezed = {}
for variable, variable_estimates in estimates.items():
squeezed[variable] = {
score_key: squeeze_inner_estimates_dict(inner_estimate)
for score_key, inner_estimate in variable_estimates.items()
}
return squeezed
@staticmethod
def _squeeze_parametric_score_major_dict(samples: Mapping[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]:
"""Squeezes the dictionary to just the value if there is only one key-value pair."""
if len(samples) == 1:
return next(iter(samples.values()))
return samples