from collections.abc import Mapping, Sequence

import keras
import numpy as np
from keras.saving import (
    deserialize_keras_object as deserialize,
    register_keras_serializable as serializable,
    serialize_keras_object as serialize,
)

from bayesflow.adapters import Adapter
from bayesflow.datasets import OnlineDataset
from bayesflow.networks import SummaryNetwork
from bayesflow.simulators import ModelComparisonSimulator, Simulator
from bayesflow.types import Shape, Tensor
from bayesflow.utils import filter_kwargs, logging

from .approximator import Approximator
@serializable(package="bayesflow.approximators")
class ModelComparisonApproximator(Approximator):
    """
    Defines an approximator for model (simulator) comparison, where the (discrete) posterior model
    probabilities are learned with a classifier.

    Parameters
    ----------
    adapter : bf.adapters.Adapter
        Adapter for data pre-processing.
    num_models : int
        Number of models (simulators) that the approximator will compare.
    classifier_network : keras.Layer
        The network backbone (e.g., an MLP) that is used for model classification.
        The input of the classifier network is created by concatenating `classifier_conditions`
        and the (optional) output of the summary network.
    summary_network : bf.networks.SummaryNetwork, optional
        The summary network used for data summarization (default is None).
        The input of the summary network is `summary_variables`.
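
    Examples
    --------
    A minimal sketch of constructing the approximator. The plain ``Dense`` classifier backbone and
    the variable name ``"x"`` are illustrative assumptions, as is importing the class directly from
    ``bayesflow.approximators``; any ``keras.Layer`` backbone and any variable names produced by
    your simulators work equally well.

    >>> import keras
    >>> from bayesflow.approximators import ModelComparisonApproximator
    >>> adapter = ModelComparisonApproximator.build_adapter(
    ...     num_models=2, classifier_conditions=["x"]
    ... )
    >>> approximator = ModelComparisonApproximator(
    ...     num_models=2,
    ...     classifier_network=keras.layers.Dense(64, activation="relu"),
    ...     adapter=adapter,
    ... )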
"""

    def __init__(
        self,
        *,
        num_models: int,
        classifier_network: keras.Layer,
        adapter: Adapter,
        summary_network: SummaryNetwork = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.classifier_network = classifier_network
        self.adapter = adapter
        self.summary_network = summary_network
        self.num_models = num_models

        self.logits_projector = keras.layers.Dense(num_models)

    def build(self, data_shapes: Mapping[str, Shape]):
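        """
        Builds the approximator from a dictionary of data shapes by passing a batch of zeros with
        the given shapes through `compute_metrics`, which in turn builds all constituent networks.
        """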
        data = {key: keras.ops.zeros(value) for key, value in data_shapes.items()}
        self.compute_metrics(**data, stage="training")

    @classmethod
    def build_adapter(
        cls,
        num_models: int,
        classifier_conditions: Sequence[str] = None,
        summary_variables: Sequence[str] = None,
        model_index_name: str = "model_indices",
    ):
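        """
        Builds a default adapter that converts inputs to float32 arrays, concatenates the given
        variables into `classifier_conditions` and/or `summary_variables`, standardizes everything
        except the model indices, and one-hot encodes the model indices.

        Parameters
        ----------
        num_models : int
            Number of models (classes) for the one-hot encoding of the model indices.
        classifier_conditions : Sequence[str], optional
            Names of the variables to concatenate into the classifier conditions.
        summary_variables : Sequence[str], optional
            Names of the variables to concatenate into the summary network inputs.
        model_index_name : str, default="model_indices"
            Name of the variable holding the model indices.

        Examples
        --------
        A minimal sketch; the variable name ``"x"`` is an illustrative assumption.

        >>> adapter = ModelComparisonApproximator.build_adapter(
        ...     num_models=2, classifier_conditions=["x"]
        ... )
        """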
        if classifier_conditions is None and summary_variables is None:
            raise ValueError("At least one of `classifier_conditions` or `summary_variables` must be provided.")

        adapter = Adapter().to_array().convert_dtype("float64", "float32")

        if classifier_conditions is not None:
            adapter = adapter.concatenate(classifier_conditions, into="classifier_conditions")

        if summary_variables is not None:
            adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")

        adapter = (
            adapter.rename(model_index_name, "model_indices")
            .keep(["classifier_conditions", "summary_variables", "model_indices"])
            .standardize(exclude="model_indices")
            .one_hot("model_indices", num_models)
        )

        return adapter

    @classmethod
    def build_dataset(
        cls,
        *,
        dataset: keras.utils.PyDataset = None,
        simulator: ModelComparisonSimulator = None,
        simulators: Sequence[Simulator] = None,
        **kwargs,
    ) -> OnlineDataset:
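        """
        Builds an `OnlineDataset` from exactly one of the given data sources: a ready-made dataset,
        a multi-model simulator, or a sequence of per-model simulators (which is first wrapped in a
        `ModelComparisonSimulator`).
        """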
        if sum(arg is not None for arg in (dataset, simulator, simulators)) != 1:
            raise ValueError("Exactly one of dataset, simulator, or simulators must be provided.")

        if simulators is not None:
            simulator = ModelComparisonSimulator(simulators)

        return super().build_dataset(dataset=dataset, simulator=simulator, **kwargs)

    def compile(
        self,
        *args,
        classifier_metrics: Sequence[keras.Metric] = None,
        summary_metrics: Sequence[keras.Metric] = None,
        **kwargs,
    ):
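        """
        Compiles the approximator, optionally attaching separate metrics to the classifier network
        and the summary network. Summary metrics are ignored (with a warning) when no summary
        network is present. All other arguments are forwarded to `keras.Model.compile()`.
        """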
        if classifier_metrics:
            self.classifier_network._metrics = classifier_metrics

        if summary_metrics:
            if self.summary_network is None:
                logging.warning("Ignoring summary metrics because there is no summary network.")
            else:
                self.summary_network._metrics = summary_metrics

        return super().compile(*args, **kwargs)

    def compute_metrics(
        self,
        *,
        classifier_conditions: Tensor = None,
        model_indices: Tensor,
        summary_variables: Tensor = None,
        stage: str = "training",
    ) -> dict[str, Tensor]:
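        """
        Computes the classification loss (categorical cross-entropy over model indices) and, if a
        summary network is present, the summary network's metrics. Outside the training stage, any
        metrics attached to the classifier network are also evaluated on the argmax predictions.
        Returns a dictionary with the total loss under the key "loss" plus per-network metrics.
        """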
        if self.summary_network is None:
            summary_metrics = {}
        else:
            summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
            summary_outputs = summary_metrics.pop("outputs")

            # append the learned summary statistics to any direct classifier conditions
            if classifier_conditions is None:
                classifier_conditions = summary_outputs
            else:
                classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)

        # we could move this into its own class
        logits = self.classifier_network(classifier_conditions)
        logits = self.logits_projector(logits)

        cross_entropy = keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True)
        cross_entropy = keras.ops.mean(cross_entropy)

        classifier_metrics = {"loss": cross_entropy}

        if stage != "training" and any(self.classifier_network.metrics):
            # compute sample-based metrics
            predictions = keras.ops.argmax(logits, axis=-1)
            classifier_metrics |= {
                metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
            }

        loss = classifier_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))

        classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()}
        summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

        metrics = {"loss": loss} | classifier_metrics | summary_metrics

        return metrics

    def fit(
        self,
        *,
        adapter: Adapter = "auto",
        dataset: keras.utils.PyDataset = None,
        simulator: ModelComparisonSimulator = None,
        simulators: Sequence[Simulator] = None,
        **kwargs,
    ):
"""
Trains the approximator on the provided dataset or on-demand generated from the given (multi-model) simulator.
If `dataset` is not provided, a dataset is built from the `simulator`.
If `simulator` is not provided, it will be build from a list of `simulators`.
If the model has not been built, it will be built using a batch from the dataset.
Parameters
----------
dataset : keras.utils.PyDataset, optional
A dataset containing simulations for training. If provided, `simulator` must be None.
simulator : ModelComparisonSimulator, optional
A simulator used to generate a dataset. If provided, `dataset` must be None.
simulators: Sequence[Simulator], optional
A list of simulators (one simulator per model). If provided, `dataset` must be None.
**kwargs : dict
Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
batch_size : int or None, default='auto'
Number of samples per gradient update. Do not specify if `dataset` is provided as a
`keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
epochs : int, default=1
Number of epochs to train the model.
verbose : {"auto", 0, 1, 2}, default="auto"
Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks : list of keras.callbacks.Callback, optional
List of callbacks to apply during training.
validation_split : float, optional
Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
or tensors).
validation_data : tuple or dataset, optional
Data for validation, overriding `validation_split`.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch (ignored for dataset generators).
initial_epoch : int, default=0
Epoch at which to start training (useful for resuming training).
steps_per_epoch : int or None, optional
Number of steps (batches) before declaring an epoch finished.
validation_steps : int or None, optional
Number of validation steps per validation epoch.
validation_batch_size : int or None, optional
Number of samples per validation batch (defaults to `batch_size`).
validation_freq : int, default=1
Specifies how many training epochs to run before performing validation.
Returns
-------
keras.callbacks.History
A history object containing the training loss and metrics values.
Raises
------
ValueError
If both `dataset` and `simulator` or `simulators` are provided or neither is provided.
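
        Examples
        --------
        A minimal sketch of training from per-model simulators. The simulator names and the number
        of epochs are illustrative assumptions.

        >>> history = approximator.fit(
        ...     simulators=[simulator_a, simulator_b],
        ...     epochs=10,
        ...     batch_size=32,
        ... )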
"""
        if dataset is not None:
            if simulator is not None or simulators is not None:
                raise ValueError(
                    "Received conflicting arguments. Please provide either a dataset or a simulator, but not both."
                )

            return super().fit(dataset=dataset, **kwargs)

        if adapter == "auto":
            logging.info("Building automatic data adapter.")
            adapter = self.build_adapter(num_models=self.num_models, **filter_kwargs(kwargs, self.build_adapter))

        if simulator is not None:
            return super().fit(simulator=simulator, adapter=adapter, **kwargs)

        logging.info(f"Building model comparison simulator from {len(simulators)} simulators.")

        simulator = ModelComparisonSimulator(simulators=simulators)

        return super().fit(simulator=simulator, adapter=adapter, **kwargs)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config["num_models"] = deserialize(config["num_models"], custom_objects=custom_objects)
        config["adapter"] = deserialize(config["adapter"], custom_objects=custom_objects)
        config["classifier_network"] = deserialize(config["classifier_network"], custom_objects=custom_objects)
        config["summary_network"] = deserialize(config["summary_network"], custom_objects=custom_objects)

        return super().from_config(config, custom_objects=custom_objects)

    def get_config(self):
        base_config = super().get_config()

        config = {
            "num_models": serialize(self.num_models),
            "adapter": serialize(self.adapter),
            "classifier_network": serialize(self.classifier_network),
            "summary_network": serialize(self.summary_network),
        }

        return base_config | config

    def predict(
        self,
        *,
        conditions: dict[str, np.ndarray],
        logits: bool = False,
        **kwargs,
    ) -> np.ndarray:
"""
Predicts posterior model probabilities given input conditions. The `conditions` dictionary is preprocessed
using the `adapter`. The output is converted to NumPy array after inference.
Parameters
----------
conditions : dict[str, np.ndarray]
Dictionary of conditioning variables as NumPy arrays.
logits: bool, default=False
Should the posterior model probabilities be on the (unconstrained) logit space?
If `False`, the output is a unit simplex instead.
**kwargs : dict
Additional keyword arguments for the adapter and classification process.
Returns
-------
np.ndarray
Predicted posterior model probabilities given `conditions`.
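
        Examples
        --------
        A minimal sketch; the variable name ``"x"`` and the array ``x_data`` are illustrative
        assumptions. The result has one probability per model, so its trailing dimension equals
        `num_models`.

        >>> probs = approximator.predict(conditions={"x": x_data})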
"""
        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)

        # at inference time, model_indices are predicted by the networks and thus ignored in conditions
        conditions.pop("model_indices", None)

        conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

        output = self._predict(**conditions, **kwargs)

        if not logits:
            output = keras.ops.softmax(output)

        output = keras.ops.convert_to_numpy(output)

        return output

    def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
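        """
        Computes the classifier logits from the (optional) direct conditions and the (optional)
        summary network outputs. The inputs must match the configured networks: summary variables
        require a summary network, and a summary network requires summary variables.
        """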
        if self.summary_network is None:
            if summary_variables is not None:
                raise ValueError("Cannot use summary variables without a summary network.")
        else:
            if summary_variables is None:
                raise ValueError("Summary variables are required when a summary network is present.")

            summary_outputs = self.summary_network(
                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
            )

            if classifier_conditions is None:
                classifier_conditions = summary_outputs
            else:
                classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)

        output = self.classifier_network(classifier_conditions)
        output = self.logits_projector(output)

        return output