from collections.abc import Callable, Sequence
from typing import Literal
import time
import copy
import numpy as np
import keras
from bayesflow.adapters import Adapter
from bayesflow.approximators import CompositionalApproximator
from bayesflow.networks import InferenceNetwork, SummaryNetwork, DiffusionModel
from bayesflow.simulators import Simulator
from bayesflow.types import Tensor
from bayesflow.utils import find_inference_network, find_summary_network, logging, format_duration, filter_kwargs
from .basic_workflow import BasicWorkflow
class CompositionalWorkflow(BasicWorkflow):
"""
This class extends :class:`BasicWorkflow` to support compositional inference, allowing samples to be
generated conditioned on multiple datasets or compositional conditions.

Parameters
----------
simulator : Simulator, optional
A Simulator object to generate synthetic data for inference (default is None).
adapter : Adapter, optional
Adapter for data processing. If not provided (default is None), a default adapter will be used; in that
case, make sure to provide the correct names for `inference_variables` and/or `inference_conditions`
and/or `summary_variables`.
inference_network : InferenceNetwork or str, optional
The inference network used for posterior approximation, specified as an instance or by
name (default is "diffusion_model").
summary_network : SummaryNetwork or str, optional
The summary network used for data summarization, specified as an instance or by name (default is None).
initial_learning_rate : float, optional
Initial learning rate for the optimizer (default is 5e-4).
optimizer : type, optional
The optimizer to be used for training. If None, a default Adam optimizer will be selected (default is None).
checkpoint_filepath : str, optional
Directory path where model checkpoints will be saved (default is None).
checkpoint_name : str, optional
Name of the checkpoint file (default is "model").
save_weights_only : bool, optional
If True, only the model weights will be saved during checkpointing (default is False).
save_best_only : bool, optional
If True, only the best model according to the quantity monitored (loss or validation loss) at the end of
each epoch will be saved instead of the last model (default is False). Use with caution,
as some losses (e.g., flow matching) do not reliably reflect model performance, and outliers in the
validation data can cause unwanted effects.
inference_variables : Sequence[str] or str, optional
Variables for inference as a sequence of strings or a single string (default is None).
Important for automating diagnostics!
inference_conditions : Sequence[str] or str, optional
Variables used as direct conditions for inference (default is None).
summary_variables : Sequence[str] or str, optional
Variables to be summarized through the summary network before being used as conditions (default is None).
standardize : Sequence[str] or str, optional
Variables to standardize during preprocessing (default is "inference_variables"). These will be
passed to the corresponding approximator constructor and can be either "all" or any subset of
["inference_variables", "summary_variables", "inference_conditions"].
**kwargs : dict, optional
Additional arguments for configuring networks, adapters, optimizers, etc.
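
Examples
--------
A minimal construction sketch; ``my_simulator`` and the variable names are illustrative
placeholders, not part of the library:

>>> workflow = CompositionalWorkflow(
...     simulator=my_simulator,
...     inference_network="diffusion_model",
...     inference_variables="parameters",
...     inference_conditions="observables",
... )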
"""
def __init__(
self,
simulator: Simulator | None = None,
adapter: Adapter | None = None,
inference_network: InferenceNetwork | str = "diffusion_model",
summary_network: SummaryNetwork | str | None = None,
initial_learning_rate: float = 5e-4,
optimizer: keras.optimizers.Optimizer | type | None = None,
checkpoint_filepath: str | None = None,
checkpoint_name: str = "model",
save_weights_only: bool = False,
save_best_only: bool = False,
inference_variables: Sequence[str] | str | None = None,
inference_conditions: Sequence[str] | str | None = None,
summary_variables: Sequence[str] | str | None = None,
standardize: Sequence[str] | str | None = "inference_variables",
**kwargs,
):
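# Resolve the inference network from a string name or instance; compositional
# sampling relies on score-based diffusion, hence the DiffusionModel check below.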
self.inference_network = find_inference_network(inference_network, **kwargs.get("inference_kwargs", {}))
if not isinstance(self.inference_network, DiffusionModel):
raise ValueError("Inference network currently must be a DiffusionModel for compositional inference.")
if summary_network is not None:
self.summary_network = find_summary_network(summary_network, **kwargs.get("summary_kwargs", {}))
else:
self.summary_network = None
self.simulator = simulator
adapter = adapter or BasicWorkflow.default_adapter(inference_variables, inference_conditions, summary_variables)
self.approximator = CompositionalApproximator(
inference_network=self.inference_network,
summary_network=self.summary_network,
adapter=adapter,
standardize=standardize,
**filter_kwargs(kwargs, keras.Model.__init__),
)
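# Configure the optimizer and optional checkpointing; `_needs_compile` defers
# compilation of the approximator until training starts.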
self._init_optimizer(initial_learning_rate, optimizer, **kwargs.get("optimizer_kwargs", {}))
self._init_checkpointing(checkpoint_filepath, checkpoint_name, save_weights_only, save_best_only)
self.history = None
self._needs_compile = True
def compositional_sample(
self,
*,
num_samples: int,
conditions: dict[str, np.ndarray] | None = None,
compute_prior_score: Callable[[dict[str, np.ndarray], np.ndarray | None], dict[str, np.ndarray]] | None = None,
summaries: Tensor | np.ndarray | None = None,
split: bool = False,
batch_size: int | None = None,
sample_shape: Literal["infer"] | tuple[int, ...] | int = "infer",
**kwargs,
) -> dict[str, np.ndarray]:
"""
Draws `num_samples` samples from the approximator given the specified compositional conditions.
Each array in the `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...).

Parameters
----------
num_samples : int
The number of samples to generate.
conditions : dict[str, np.ndarray], optional
A dictionary where keys represent variable names and values are
NumPy arrays containing the adapted simulated variables. Keys used as summary or inference
conditions during training should be present.
Each array should have shape (n_datasets, n_compositional_conditions, ...).
compute_prior_score : Callable[[dict[str, np.ndarray], np.ndarray | None], dict[str, np.ndarray]], optional
A function that computes the score of the prior distribution, i.e., the gradient of the log prior
density. Optionally, the function can accept a time argument; otherwise, the prior score is multiplied
by (1 - t), where t is the diffusion time. If not provided, the unconditional score is used.
summaries : Tensor | np.ndarray | None, optional
Precomputed summary outputs to be used as conditions for sampling. If provided, these will be used instead
of the conditions. Should have shape (n_datasets, n_compositional_conditions, ...).
split : bool, default=False
Whether to split the output arrays along the last axis and return one sample array per target variable.
batch_size : int or None, optional
If provided, the conditions are split into batches of size `batch_size`, for which samples are generated
sequentially. Can help with memory management for large sample sizes.
sample_shape : str or tuple of int, optional
Trailing structural dimensions of each generated sample, excluding the batch and target (intrinsic)
dimension. For example, use `(time,)` for time series or `(height, width)` for images.
If set to `"infer"` (default), the structural dimensions are inferred from the `inference_conditions`.
In that case, all non-vector dimensions except the last (channel) dimension are treated as structural
dimensions. For example, if the final `inference_conditions` have shape `(batch_size, time, channels)`,
then `sample_shape` is inferred as `(time,)`, and the generated samples will have shape
`(num_conditions, num_samples, time, target_dim)`.
**kwargs : dict
Additional keyword arguments passed to the approximator's sampling function.

Returns
-------
dict[str, np.ndarray]
A dictionary where keys correspond to variable names and
values are arrays containing the generated samples.
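
Examples
--------
A hedged sketch; the variable names and the standard normal prior are illustrative placeholders:

>>> def prior_score(samples, time=None):
...     # Score of a standard normal prior: d/dx log N(x; 0, 1) = -x
...     return {"parameters": -samples["parameters"]}
>>> samples = workflow.compositional_sample(
...     num_samples=500,
...     conditions={"observables": data},  # shape (n_datasets, n_comp_conditions, ...)
...     compute_prior_score=prior_score,
... )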
"""
start_time = time.perf_counter()
samples = self.approximator.compositional_sample(
num_samples=num_samples,
conditions=conditions,
compute_prior_score=compute_prior_score,
split=split,
batch_size=batch_size,
sample_shape=sample_shape,
summary_outputs=summaries,
**kwargs,
)
elapsed = time.perf_counter() - start_time
logging.info(f"Sampling completed in {format_duration(elapsed)}.")
return samples
@classmethod
def from_basic_workflow(
cls,
workflow: BasicWorkflow,
**kwargs,
) -> "CompositionalWorkflow":
"""
Build a :class:`CompositionalWorkflow` from a trained :class:`BasicWorkflow`.

The trained ``DiffusionModel`` inference network (and, if present, the summary
network) are cloned with their trained weights, so no re-training is needed.

Parameters
----------
workflow : BasicWorkflow
A fitted workflow whose ``approximator.inference_network`` is a
:class:`~bayesflow.networks.DiffusionModel`.
**kwargs
Override any constructor argument of :class:`CompositionalWorkflow`,
e.g. ``optimizer``, ``simulator``, ``adapter``, etc.
The following checkpointing attributes are not transferred from the source
workflow; set them via kwargs if needed:

- ``checkpoint_filepath``
- ``checkpoint_name``
- ``save_weights_only``
- ``save_best_only``

Returns
-------
compositional_workflow : CompositionalWorkflow
The newly created compositional workflow with attributes from ``workflow`` and
``kwargs``.
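
Examples
--------
A sketch assuming ``basic_workflow`` was trained with a diffusion model; ``my_simulator`` is a
placeholder:

>>> comp_workflow = CompositionalWorkflow.from_basic_workflow(
...     basic_workflow,
...     simulator=my_simulator,
... )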
"""
if not isinstance(workflow, BasicWorkflow):
raise TypeError(f"Expected a BasicWorkflow instance, got {type(workflow).__name__!r}.")
approximator = workflow.approximator
# Validate the network type before doing any work, so we fail fast
if not isinstance(approximator.inference_network, DiffusionModel):
raise ValueError(
f"The inference network must be a DiffusionModel for compositional inference, "
f"got {type(approximator.inference_network).__name__!r}."
)
# Clone the networks so the two workflows have independent weights
cloned_inference_network = keras.models.clone_model(approximator.inference_network)
cloned_inference_network.set_weights(approximator.inference_network.get_weights())
if approximator.summary_network is not None:
cloned_summary_network = keras.models.clone_model(approximator.summary_network)
cloned_summary_network.set_weights(approximator.summary_network.get_weights())
else:
cloned_summary_network = approximator.summary_network
# Collect all attributes from the basic workflow that can be passed to the constructor.
init_kwargs = dict(
simulator=workflow.simulator,
adapter=approximator.adapter,
inference_network=cloned_inference_network,
summary_network=cloned_summary_network,
initial_learning_rate=workflow.initial_learning_rate,
optimizer=workflow.optimizer,
standardize=approximator.standardizer.standardize,
)
# Override with caller-supplied kwargs and create new workflow
compositional_workflow = cls(**(init_kwargs | kwargs))
# Replace the fresh (unfitted) standardizer with a deep copy of the source one
compositional_workflow.approximator.standardizer = copy.deepcopy(approximator.standardizer)
return compositional_workflow