Source code for bayesflow.approximators.ratio_approximator

from collections.abc import Sequence, Mapping

import numpy as np
import keras

from bayesflow.types import Tensor
from bayesflow.adapters import Adapter
from bayesflow.utils import expand_tile, concatenate_valid_shapes, weighted_mean
from bayesflow.utils.serialization import serialize, serializable

from .approximator import Approximator
from .helpers import ConditionBuilder

from ..networks.helpers import Standardization


[docs]
@serializable("bayesflow.approximators")
class RatioApproximator(Approximator):
    """
    Implements contrastive neural likelihood-to-evidence ratio estimation (NRE-C)
    as described in https://arxiv.org/pdf/2210.06170.

    The estimation target is the ratio of likelihood and evidence:
    p(x | theta) / p(x).

    Parameters
    ----------
    inference_network : keras.Layer
        A network backbone to perform contrastive learning. A final logits layer
        is automatically added on top of the inference network.
    adapter : bayesflow.adapters.Adapter, optional
        Adapter for data processing. You can use :py:meth:`build_adapter` to create it.
        If ``None`` (default), an identity adapter is used that makes a shallow copy
        and passes data through unchanged.
    summary_network : SummaryNetwork, optional
        The summary network used for data summarization of summary_variables
        (default is None). When present, summary outputs are automatically
        concatenated with inference_conditions.
    gamma : float, optional
        Odds of a pair being drawn jointly (dependently) versus completely
        independently. Default is 1.
    K : int, optional
        Number of parameter candidates used for contrastive learning. Default is 5.
    standardize : str | Sequence[str] | None
        The variables to standardize before passing to the networks. Can be either
        "all" or any subset of ["inference_variables", "inference_conditions",
        "summary_variables"] (default is "inference_variables").
    **kwargs : dict, optional
        Additional arguments passed to the
        :py:class:`bayesflow.approximators.Approximator` class.
    """

    def __init__(
        self,
        *,
        inference_network: keras.Layer,
        adapter: Adapter = None,
        summary_network: keras.Layer = None,
        gamma: float = 1.0,
        K: int = 5,
        standardize: str | Sequence[str] | None = "inference_variables",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.adapter = adapter if adapter is not None else Adapter()
        self.inference_network = inference_network
        self.summary_network = summary_network
        self.condition_builder = ConditionBuilder()

        if gamma <= 0:
            raise ValueError(f"Gamma must be positive, got {gamma}.")
        if gamma == float("inf"):
            raise NotImplementedError("NRE-B is not yet supported.")
        if K <= 0:
            raise ValueError(f"K must be positive, got {K}.")

        self.gamma = gamma
        self.K = K
        self.projector = keras.layers.Dense(units=1)
        self.seed_generator = keras.random.SeedGenerator()
        self.standardizer = Standardization(standardize)

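    # Note on the contrastive setup (descriptive comment, values illustrative):
    # each observation is paired with its true theta plus K - 1 thetas resampled
    # from the batch (the "joint" candidates) and with K resampled thetas (the
    # "marginal" candidates). With the defaults gamma = 1 and K = 5, joint and
    # marginal pairs enter the loss with equal prior odds.
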
[docs]
    def build(self, data_shapes: Mapping[str, Tensor]):
        self._build_standardization_layers(data_shapes)

        # Build summary network once at template level and cache output shape
        summary_outputs_shape = self._build_summary_network(data_shapes)

        classifier_inputs_shape = concatenate_valid_shapes(
            [data_shapes["inference_variables"], data_shapes["inference_conditions"], summary_outputs_shape],
            axis=-1,
        )

        if not self.inference_network.built:
            self.inference_network.build(classifier_inputs_shape)

        classifier_outputs_shape = self.inference_network.compute_output_shape(classifier_inputs_shape)
        if not self.projector.built:
            self.projector.build(classifier_outputs_shape)

        self.built = True

[docs]
    def compute_metrics(
        self,
        inference_variables: Tensor,
        inference_conditions: Tensor = None,
        summary_variables: Tensor = None,
        sample_weight: Tensor = None,
        summary_attention_mask: Tensor = None,
        summary_mask: Tensor = None,
        inference_attention_mask: Tensor = None,
        inference_mask: Tensor = None,
        stage: str = "training",
    ) -> dict[str, Tensor]:
        """
        Computes the NRE-C loss following https://arxiv.org/pdf/2210.06170.

        Handles both summary network outputs (if present) and inference conditions,
        combining them via ConditionBuilder.resolve().
        """
        inference_variables = self.standardizer.maybe_standardize(
            inference_variables, key="inference_variables", stage=stage
        )

        masks = dict(
            summary_attention_mask=summary_attention_mask,
            summary_mask=summary_mask,
            inference_attention_mask=inference_attention_mask,
            inference_mask=inference_mask,
        )
        summary_kwargs = self._collect_mask_kwargs(self._SUMMARY_MASK_KEYS, masks)
        inference_kwargs = self._collect_mask_kwargs(self._INFERENCE_MASK_KEYS, masks)

        resolved_conditions, summary_metrics = self._standardize_and_resolve(
            inference_conditions, summary_variables, stage=stage, purpose="metrics", **summary_kwargs
        )

        batch_size = keras.ops.shape(inference_variables)[0]
        log_gamma = keras.ops.broadcast_to(keras.ops.log(self.gamma), (batch_size,))
        log_K = keras.ops.broadcast_to(keras.ops.log(self.K), (batch_size,))
        marginal_weight = 1 / (1 + self.gamma)
        joint_weight = self.gamma / (1 + self.gamma)

        # Get (batch_size, K + 1, dim) inference variables (theta)
        bootstrap_inference_variables = self._sample_from_batch(inference_variables)
        bootstrap_inference_variables = keras.ops.concatenate(
            [inference_variables[:, None, :], bootstrap_inference_variables], axis=1
        )

        # Get (batch_size, K, dim) conditions (already resolved from condition builder)
        conditions = expand_tile(resolved_conditions, n=self.K, axis=1)

        marginal_logits = self.logits(
            bootstrap_inference_variables[:, 1:, :], conditions, stage=stage, **inference_kwargs
        )
        joint_logits = self.logits(
            bootstrap_inference_variables[:, :-1, :], conditions, stage=stage, **inference_kwargs
        )

        # Eq. 7 (https://arxiv.org/abs/2210.06170) - we use a trick for numerical stability:
        # log(K + gamma * sum_{i=1}^{K} exp(h_i)) = log(exp(log K) + sum_{i=1}^{K} exp(h_i + log gamma))
        # so if we absorb log gamma into the network outputs and concatenate log K, we can use logsumexp
        log_numerator_joint = log_gamma + joint_logits[:, 0]
        log_denominator_joint = keras.ops.concatenate([log_gamma[:, None] + joint_logits, log_K[:, None]], axis=-1)
        log_denominator_joint = keras.ops.logsumexp(log_denominator_joint, axis=-1)

        log_numerator_marginal = log_K
        log_denominator_marginal = keras.ops.concatenate(
            [log_gamma[:, None] + marginal_logits, log_K[:, None]], axis=-1
        )
        log_denominator_marginal = keras.ops.logsumexp(log_denominator_marginal, axis=-1)

        joint_loss = log_denominator_joint - log_numerator_joint
        marginal_loss = log_denominator_marginal - log_numerator_marginal
        loss = marginal_weight * marginal_loss + joint_weight * joint_loss
        inference_loss = weighted_mean(loss, sample_weight)

        # Handle summary network metrics if present
        if "loss" in summary_metrics:
            total_loss = inference_loss + summary_metrics["loss"]
        else:
            total_loss = inference_loss

        # Format metrics with prefixes
        inference_metrics = {"loss": total_loss}
        summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

        metrics = inference_metrics | summary_metrics
        return metrics

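    # Worked form of the stability trick above (descriptive comment only):
    # with candidate logits h_1, ..., h_K, the joint-class probability in Eq. 7 is
    #     gamma * exp(h_1) / (K + gamma * sum_k exp(h_k)),
    # so its negative log is
    #     logsumexp([h_1 + log gamma, ..., h_K + log gamma, log K]) - (log gamma + h_1),
    # which is exactly log_denominator_joint - log_numerator_joint computed above;
    # the marginal class replaces the numerator with log K.
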
[docs]
    def fit(self, *args, **kwargs):
        """
        Trains the approximator on the provided dataset or on-demand data generated
        from the given simulator. If `dataset` is not provided, a dataset is built
        from the `simulator`. If the model has not been built, it will be built
        using a batch from the dataset.

        Parameters
        ----------
        dataset : keras.utils.PyDataset, optional
            A dataset containing simulations for training. If provided, `simulator`
            must be None.
        simulator : Simulator, optional
            A simulator used to generate a dataset. If provided, `dataset` must be None.
        **kwargs
            Additional keyword arguments passed to `keras.Model.fit()`, as described in:
            https://github.com/keras-team/keras/blob/v3.13.2/keras/src/backend/tensorflow/trainer.py#L314

        Returns
        -------
        keras.callbacks.History
            A history object containing the training loss and metrics values.

        Raises
        ------
        ValueError
            If both `dataset` and `simulator` are provided or neither is provided.
        """
        return super().fit(*args, **kwargs, adapter=self.adapter)

[docs]
    def log_ratio(self, data: Mapping[str, np.ndarray], **kwargs) -> Tensor:
        """
        Computes the log likelihood-to-evidence ratio. The `data` dictionary is
        preprocessed using the `adapter`.

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Dictionary of observed data as NumPy arrays.
        **kwargs : dict
            Additional keyword arguments for the adapter and log-ratio computation.

        Returns
        -------
        log_ratio : Tensor
            The estimated log ratios.
        """
        resolved_conditions, adapted, _ = self._prepare_conditions(data, **kwargs)
        inference_variables = self.standardizer.maybe_standardize(
            adapted.get("inference_variables"), key="inference_variables", stage="inference"
        )
        inference_kwargs = self._collect_mask_kwargs(self._INFERENCE_MASK_KEYS, adapted)
        log_ratio = self.logits(inference_variables, resolved_conditions, stage="inference", **inference_kwargs)
        return log_ratio

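    # Hedged usage sketch (assumes the default identity adapter; `theta` and `x`
    # are hypothetical NumPy arrays of shape (batch, dim)):
    #     log_r = approximator.log_ratio(
    #         {"inference_variables": theta, "inference_conditions": x}
    #     )
    # Positive entries of `log_r` indicate that x is more likely under theta
    # than under the marginal p(x).
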
[docs]
    def logits(self, inference_variables: Tensor, inference_conditions: Tensor, stage: str, **kwargs) -> Tensor:
        """Computes logits for K batches of variable-condition pairs."""
        classifier_inputs = keras.ops.concatenate([inference_variables, inference_conditions], axis=-1)
        logits = self.inference_network(classifier_inputs, training=stage == "training", **kwargs)
        logits = self.projector(logits)
        logits = keras.ops.squeeze(logits, axis=-1)
        return logits

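    # Shape sketch (descriptive comment, dimensions illustrative): inputs of
    # shape (B, K, D_theta) and (B, K, D_cond) concatenate to
    # (B, K, D_theta + D_cond); the backbone maps the last axis to features,
    # the Dense(1) projector to a single logit, and the final squeeze yields
    # logits of shape (B, K). Unbatched candidate axes, e.g. (B, D_theta) in
    # `log_ratio`, pass through the same way and come out as (B,).
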
[docs]
    def get_config(self):
        base_config = super().get_config()
        config = {
            "adapter": self.adapter,
            "summary_network": self.summary_network,
            "inference_network": self.inference_network,
            "gamma": self.gamma,
            "K": self.K,
            "standardize": self.standardizer.standardize,
        }
        return base_config | serialize(config)

    def _sample_from_batch(self, inference_variables: Tensor) -> Tensor:
        """Samples K batches of inference variables with replacement.

        Ensures that no self-sampling occurs (i.e., all samples are negative examples).
        """
        B = keras.ops.shape(inference_variables)[0]
        r = keras.random.randint(
            shape=(B, self.K),
            minval=0,
            maxval=B - 1,
            dtype="int32",
            seed=self.seed_generator,
        )
        i = keras.ops.expand_dims(keras.ops.arange(B, dtype="int32"), axis=1)
        idx = r + keras.ops.cast(r >= i, "int32")
        return keras.ops.take(inference_variables, idx, axis=0)
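

if __name__ == "__main__":
    # Minimal self-contained check of the index trick in `_sample_from_batch`
    # (illustrative sketch, not part of the library API): drawing r uniformly
    # from [0, B - 2] and shifting every value >= the row index i up by one
    # yields indices uniform over [0, B - 1] excluding i, so no row samples itself.
    B, K = 4, 3
    rng = np.random.default_rng(0)
    r = rng.integers(0, B - 1, size=(B, K))  # values in [0, B - 2]
    i = np.arange(B)[:, None]
    idx = r + (r >= i).astype(int)  # skip over the row's own index
    assert not np.any(idx == i), "self-sampling should be impossible"
    print(idx)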