# Source code for bayesflow.networks.diffusion_model.diffusion_model

from collections.abc import Sequence
from typing import Any, Callable, Literal

import keras
from keras import ops

from bayesflow.types import Tensor, Shape
from bayesflow.utils import (
    expand_right_as,
    find_network,
    jacobian_trace,
    layer_kwargs,
    weighted_mean,
    integrate,
    integrate_stochastic,
    logging,
    STOCHASTIC_METHODS,
    DETERMINISTIC_METHODS,
)
from bayesflow.utils.serialization import serialize, deserialize, serializable

from ..inference_network import InferenceNetwork
from .schedules.noise_schedule import NoiseSchedule
from .dispatch import find_noise_schedule

@serializable("bayesflow.networks")
class DiffusionModel(InferenceNetwork):
    """Score-based diffusion model for simulation-based inference as described in [1]:

    [1] Arruda, J., Bracher, N., Köthe, U., Hasenauer, J., & Radev, S. T. (2025).
    Diffusion Models in Simulation-Based Inference: A Tutorial Review.
    arXiv preprint arXiv:2512.20685.
    """

    # Default architecture for the built-in "time_mlp" subnet; user-supplied
    # subnet_kwargs are merged on top of (and override) these values.
    TIME_MLP_DEFAULT_CONFIG = {
        "widths": (256, 256, 256, 256, 256),
        "activation": "mish",
        "kernel_initializer": "he_normal",
        "residual": True,
        "dropout": 0.05,
        "spectral_normalization": False,
        "time_embedding_dim": 32,
        "merge": "concat",
        "norm": "layer",
    }

    # Default solver configuration used by the forward/inverse integrators.
    INTEGRATE_DEFAULT_CONFIG = {
        "method": "two_step_adaptive",
        "steps": "adaptive",
    }

    def __init__(
        self,
        *,
        subnet: str | type | keras.Layer = "time_mlp",
        noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm",
        prediction_type: Literal["velocity", "noise", "F", "x"] = "F",
        loss_type: Literal["velocity", "noise", "F"] = "noise",
        subnet_kwargs: dict[str, Any] = None,
        schedule_kwargs: dict[str, Any] = None,
        integrate_kwargs: dict[str, Any] = None,
        **kwargs,
    ):
        """
        Initializes a diffusion model with configurable subnet architecture, noise schedule,
        and prediction/loss types for amortized Bayesian inference.

        Note, that score-based diffusion is the most sluggish of all available samplers,
        so expect slower inference times than flow matching and much slower than normalizing flows.

        Parameters
        ----------
        subnet : str, type or keras.Layer, optional
            A neural network type for the diffusion model, will be instantiated using
            ``subnet_kwargs``. If a string is provided, it should be a registered name
            (e.g., "time_mlp"). If a type or keras.Layer is provided, it will be directly
            instantiated with the given ``subnet_kwargs``. Any subnet must accept a tuple
            of tensors (target, time, conditions).
        noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional
            Noise schedule controlling the diffusion dynamics. Can be a string identifier,
            a schedule class, or a pre-initialized schedule instance. Default is "edm".
        prediction_type : {'velocity', 'noise', 'F', 'x'}, optional
            Output format of the model's prediction. Default is "F".
        loss_type : {'velocity', 'noise', 'F'}, optional
            Loss function used to train the model. Default is "noise".
        subnet_kwargs : dict[str, Any], optional
            Additional keyword arguments passed to the subnet constructor. Default is None.
        schedule_kwargs : dict[str, Any], optional
            Additional keyword arguments passed to the noise schedule constructor. Default is None.
        integrate_kwargs : dict[str, Any], optional
            Configuration dictionary for integration during training or inference. Default is None.
        **kwargs
            Additional keyword arguments passed to the base class and internal components.

        Raises
        ------
        ValueError
            If ``prediction_type`` or ``loss_type`` is not one of the supported values.
        """
        super().__init__(base_distribution="normal", **kwargs)

        if prediction_type not in ["noise", "velocity", "F", "x"]:
            raise ValueError(f"Unknown prediction type: {prediction_type}")
        if loss_type not in ["noise", "velocity", "F"]:
            raise ValueError(f"Unknown loss type: {loss_type}")
        if loss_type != "noise":
            # the bundled schedules define SNR weightings for the noise-prediction loss only
            logging.warning(
                "The standard schedules have weighting functions defined for the noise prediction loss. "
                "You might want to replace them if you are using a different loss function."
            )

        self._prediction_type = prediction_type
        self._loss_type = loss_type

        schedule_kwargs = schedule_kwargs or {}
        self.noise_schedule = find_noise_schedule(noise_schedule, **schedule_kwargs)
        self.noise_schedule.validate()

        self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
        self.seed_generator = keras.random.SeedGenerator()

        subnet_kwargs = subnet_kwargs or {}
        if subnet == "time_mlp":
            subnet_kwargs = DiffusionModel.TIME_MLP_DEFAULT_CONFIG | subnet_kwargs
        self.subnet = find_network(subnet, **subnet_kwargs)

        # created lazily in build() once the target feature dimension is known
        self.output_projector = None
[docs] def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: if self.built: return self.base_distribution.build(xz_shape) self.output_projector = keras.layers.Dense( units=xz_shape[-1], bias_initializer="zeros", name="output_projector", ) # construct input shape for subnet and subnet projector time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature self.subnet.build((xz_shape, time_shape, conditions_shape)) out_shape = self.subnet.compute_output_shape((xz_shape, time_shape, conditions_shape)) self.output_projector.build(out_shape)
[docs] def get_config(self): base_config = super().get_config() base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, "noise_schedule": self.noise_schedule, "prediction_type": self._prediction_type, "loss_type": self._loss_type, "integrate_kwargs": self.integrate_kwargs, # we do not need to store subnet_kwargs } return base_config | serialize(config)
[docs] @classmethod def from_config(cls, config, custom_objects=None): return cls(**deserialize(config, custom_objects=custom_objects))
[docs] def convert_prediction_to_x( self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor ) -> Tensor: """ Converts the neural network prediction into the denoised data `x`, depending on the prediction type configured for the model. Parameters ---------- pred : Tensor The output prediction from the neural network, typically representing noise, velocity, or a transformation of the clean signal. z : Tensor The noisy latent variable `z` to be denoised. alpha_t : Tensor The noise schedule's scaling factor for the clean signal at time `t`. sigma_t : Tensor The standard deviation of the noise at time `t`. log_snr_t : Tensor The log signal-to-noise ratio at time `t`. Returns ------- Tensor The reconstructed clean signal `x` from the model prediction. """ match self._prediction_type: case "velocity": return alpha_t * z - sigma_t * pred case "noise": return (z - sigma_t * pred) / alpha_t case "F": sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0) x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2) x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) return x1 * z + x2 * pred case "x": return pred case "score": return (z + sigma_t**2 * pred) / alpha_t case _: raise ValueError(f"Unknown prediction type {self._prediction_type}.")
    def score(
        self,
        xz: Tensor,
        time: float | Tensor = None,
        log_snr_t: Tensor = None,
        conditions: Tensor = None,
        training: bool = False,
        **kwargs,
    ) -> Tensor:
        """
        Computes the score of the target or latent variable `xz`.

        Parameters
        ----------
        xz : Tensor
            The current state of the latent variable `z`, typically of shape (..., D),
            where D is the dimensionality of the latent space.
        time : float or Tensor
            Scalar or tensor representing the time (or noise level) at which the score
            should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided.
        log_snr_t : Tensor
            The log signal-to-noise ratio at time `t`. If None, time must be provided.
        conditions : Tensor, optional
            Conditional inputs to the network, such as conditioning variables
            or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
        training : bool, optional
            Whether the model is in training mode. Affects behavior of dropout, batch norm,
            or other stochastic layers. Default is False.
        **kwargs
            Additional keyword arguments for custom guidance. The following dicts can be provided:
            guidance_constraints, containing any of the following keys:
            - constraints: Required constraint functions or a single function of xz
            - guidance_strength (float, optional): Strength of the constraint guidance. Defaults to 1.0.
            - scaling_function (callable, optional): Optional function to scale constraint values.
              Defaults to None.
            - reduce (str or callable, optional): Reduction method applied to constraints. Defaults to "sum".
            guidance_function: a function that takes x and time as keyword arguments and returns
            a guidance signal.

        Returns
        -------
        Tensor
            The score tensor of the same shape as `xz` at the given `time`.
        """
        # derive log-SNR from time if not given, then broadcast to (..., 1)
        if log_snr_t is None:
            log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
            log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
        # conversely, guidance terms below need the time corresponding to log_snr_t
        if time is None:
            time = self.noise_schedule.get_t_from_log_snr(log_snr_t, training=training)
        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
        # network consumes (noisy target, normalized log-SNR, conditions)
        subnet_out = self.subnet((xz, self._transform_log_snr(log_snr_t), conditions), training=training)
        pred = self.output_projector(subnet_out)
        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
        # score of the marginal: (alpha_t * x_hat - z) / sigma_t^2
        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
        # Optional constraints for guidance
        guidance_constraints = kwargs.get("guidance_constraints", None)
        if guidance_constraints is not None:
            guidance = self.guidance_constraint_term(x=x_pred, time=time, **guidance_constraints)
            score = score + guidance
        # Optional guidance function
        guidance_function = kwargs.get("guidance_function", None)
        if guidance_function is not None:
            guidance = guidance_function(x=x_pred, time=time)
            score = score + guidance
        return score
[docs] def velocity( self, xz: Tensor, time: float | Tensor, stochastic_solver: bool, conditions: Tensor = None, training: bool = False, **kwargs, ) -> Tensor: """ Computes the velocity (i.e., time derivative) of the target or latent variable `xz` for either a stochastic differential equation (SDE) or ordinary differential equation (ODE). Parameters ---------- xz : Tensor The current state of the latent variable `z`, typically of shape (..., D), where D is the dimensionality of the latent space. time : float or Tensor Scalar or tensor representing the time (or noise level) at which the velocity should be computed. Will be broadcasted to xz. stochastic_solver : bool If True, computes the velocity for the stochastic formulation (SDE). If False, uses the deterministic formulation (ODE). conditions : Tensor, optional Conditional inputs to the network, such as conditioning variables or encoder outputs. Shape must be broadcastable with `xz`. Default is None. training : bool, optional Whether the model is in training mode. Affects behavior of dropout, batch norm, or other stochastic layers. Default is False. Returns ------- Tensor The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`. """ log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training, **kwargs) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) if stochastic_solver: # for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW out = f - g_squared * score else: # for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt out = f - 0.5 * g_squared * score return out
[docs] def diffusion_term( self, xz: Tensor, time: float | Tensor, training: bool = False, ) -> Tensor: """ Compute the diffusion term (standard deviation of the noise) at a given time. Parameters ---------- xz : Tensor Input tensor of shape (..., D), typically representing the target or latent variables at given time. time : float or Tensor The diffusion time step(s). Can be a scalar or a tensor broadcastable to the shape of `xz`. training : bool, optional Whether to use the training noise schedule (default is False). Returns ------- Tensor The diffusion term tensor with shape matching `xz` except for the last dimension, which is set to 1. """ log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t) return ops.sqrt(g_squared)
    def _velocity_trace(
        self,
        xz: Tensor,
        time: Tensor,
        conditions: Tensor = None,
        max_steps: int = None,
        training: bool = False,
        **kwargs,
    ) -> (Tensor, Tensor):
        # Deterministic (ODE) velocity together with an estimate of its Jacobian trace,
        # needed for the change-of-variables term in density computations.
        def f(x):
            return self.velocity(
                x, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs
            )

        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
        # trace gets a trailing singleton axis so it can travel in the integrator state dict
        return v, ops.expand_dims(trace, axis=-1)

    def _transform_log_snr(self, log_snr: Tensor) -> Tensor:
        """Transform the log_snr to the range [-1, 1] for the diffusion process."""
        log_snr_min = self.noise_schedule.log_snr_min
        log_snr_max = self.noise_schedule.log_snr_max
        # min-max normalize to [0, 1], then rescale to [-1, 1]
        normalized_snr = (log_snr - log_snr_min) / (log_snr_max - log_snr_min)
        scaled_value = 2 * normalized_snr - 1
        return scaled_value

    def _forward(
        self,
        x: Tensor,
        conditions: Tensor = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        # Map data x (t=0) to latent z (t=1) by integrating the probability-flow ODE.
        # Returns (z, log_density) when density=True, otherwise just z.
        # kwargs override both the instance defaults and the start/stop times.
        integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0}
        integrate_kwargs = integrate_kwargs | self.integrate_kwargs
        integrate_kwargs = integrate_kwargs | kwargs
        # NOTE(review): the forward pass always forces a deterministic solver (even when
        # density=False) — SDE sampling is only supported in _inverse
        if integrate_kwargs["method"] in STOCHASTIC_METHODS:
            logging.warning(
                "Stochastic methods are not supported for density evaluation."
                " Falling back to tsit5 ODE solver."
                " To suppress this warning, explicitly pass a method from " + str(DETERMINISTIC_METHODS) + "."
            )
            integrate_kwargs["method"] = "tsit5"
        if density:

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            # integrate state and accumulated Jacobian trace jointly
            state = {
                "xz": x,
                "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)),
            }
            state = integrate(
                deltas,
                state,
                **integrate_kwargs,
            )
            z = state["xz"]
            # change of variables: log p(x) = log p(z) + integral of the trace
            log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1)
            return z, log_density

        def deltas(time, xz):
            return {
                "xz": self.velocity(xz, time=time, stochastic_solver=False, conditions=conditions, training=training)
            }

        state = {"xz": x}
        state = integrate(
            deltas,
            state,
            **integrate_kwargs,
        )
        z = state["xz"]
        return z

    def _inverse(
        self,
        z: Tensor,
        conditions: Tensor = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        # Map latent z (t=1) back to data x (t=0), either deterministically (ODE) or
        # stochastically (SDE), depending on the configured integration method.
        integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} | self.integrate_kwargs | kwargs
        if density:
            # density requires the deterministic probability-flow ODE
            if integrate_kwargs["method"] in STOCHASTIC_METHODS:
                logging.warning(
                    "Stochastic methods are not supported for density computation."
                    " Falling back to ODE solver."
                    " Use one of the deterministic methods: " + str(DETERMINISTIC_METHODS) + "."
                )
                integrate_kwargs["method"] = "tsit5"

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training, **kwargs)
                return {"xz": v, "trace": trace}

            state = {
                "xz": z,
                "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)),
            }
            state = integrate(deltas, state, **integrate_kwargs)
            x = state["xz"]
            # integrating backwards, so the trace contribution enters with a minus sign
            log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1)
            return x, log_density
        state = {"xz": z}
        if integrate_kwargs["method"] in STOCHASTIC_METHODS:

            def deltas(time, xz):
                return {
                    "xz": self.velocity(
                        xz, time=time, stochastic_solver=True, conditions=conditions, training=training, **kwargs
                    )
                }

            def diffusion(time, xz):
                return {"xz": self.diffusion_term(xz, time=time, training=training)}

            # corrector steps / Langevin sampling additionally need the raw score
            score_fn = None
            if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin":

                def score_fn(time, xz):
                    return {"xz": self.score(xz, time=time, conditions=conditions, training=training, **kwargs)}

            state = integrate_stochastic(
                drift_fn=deltas,
                diffusion_fn=diffusion,
                score_fn=score_fn,
                noise_schedule=self.noise_schedule,
                state=state,
                seed=self.seed_generator,
                **integrate_kwargs,
            )
        else:

            def deltas(time, xz):
                return {
                    "xz": self.velocity(
                        xz, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs
                    )
                }

            state = integrate(
                deltas,
                state,
                **integrate_kwargs,
            )
        x = state["xz"]
        return x
    def compute_metrics(
        self,
        x: Tensor | Sequence[Tensor, ...],
        conditions: Tensor = None,
        sample_weight: Tensor = None,
        stage: str = "training",
    ) -> dict[str, Tensor]:
        """Compute the diffusion training loss and base metrics for a batch.

        Diffuses `x` to a randomly sampled noise level, predicts the clean target with
        the subnet, and scores the prediction under the configured loss type.
        """
        training = stage == "training"
        # use same noise schedule for training and validation to keep them comparable
        noise_schedule_training_stage = stage == "training" or stage == "validation"
        if not self.built:
            xz_shape = ops.shape(x)
            conditions_shape = None if conditions is None else ops.shape(conditions)
            self.build(xz_shape, conditions_shape)
        # sample training diffusion time as low discrepancy sequence to decrease variance:
        # one uniform offset shifted by i/batch_size, wrapped into [0, 1)
        u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator)
        i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x))
        t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1
        # calculate the noise level
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage), x)
        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
        weights_for_snr = self.noise_schedule.get_weights_for_snr(log_snr_t=log_snr_t)
        # generate noise vector
        eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)
        # diffuse x
        diffused_x = alpha_t * x + sigma_t * eps_t
        # calculate output of the network
        subnet_out = self.subnet((diffused_x, self._transform_log_snr(log_snr_t), conditions), training=training)
        pred = self.output_projector(subnet_out)
        x_pred = self.convert_prediction_to_x(
            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
        )
        # convert predicted target (x_pred) to corresponding diffusion prediction
        match self._loss_type:
            case "noise":
                # recover implied noise prediction: eps_hat = (z - alpha * x_hat) / sigma
                noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t
                loss = weights_for_snr * ops.mean((noise_pred - eps_t) ** 2, axis=-1)
            case "velocity":
                # NOTE(review): this identity presumably assumes alpha_t^2 + sigma_t^2 = 1
                # (variance-preserving schedule) — confirm for custom schedules
                velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t
                v_t = alpha_t * eps_t - sigma_t * x
                loss = weights_for_snr * ops.mean((velocity_pred - v_t) ** 2, axis=-1)
            case "F":
                # EDM-style preconditioned target; sigma_data defaults to 1 if the
                # schedule does not define it
                sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
                x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)
                x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2))
                f_pred = x1 * x_pred - x2 * diffused_x
                f_t = x1 * x - x2 * diffused_x
                loss = weights_for_snr * ops.mean((f_pred - f_t) ** 2, axis=-1)
            case _:
                raise ValueError(f"Unknown loss type: {self._loss_type}")
        loss = weighted_mean(loss, sample_weight)
        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
        return base_metrics | {"loss": loss}
[docs] def guidance_constraint_term( self, x: Tensor, time: Tensor, constraints: Callable | Sequence[Callable], guidance_strength: float = 1.0, scaling_function: Callable | None = None, reduce: Literal["sum", "mean"] = "sum", ) -> Tensor: """ Backend-agnostic implementation of: `∇_x Σ_k log sigmoid( -s(t) * c_k(x) )` Parameters ---------- x : Tensor The denoised target at time t. time : Tensor The time corresponding to x. constraints : Callable or Sequence[Callable] A single constraint function or a list/tuple of constraint functions. Each function should take x as input and return a tensor of constraint values. guidance_strength : float, optional A positive scaling factor for the guidance term. Default is 1.0. scaling_function : Callable, optional A function that takes time t as input and returns a scaling factor s(t). If None, a default scaling based on the noise schedule is used. Default is None. reduce : {'sum', 'mean'}, optional Method to reduce the log-probabilities from multiple constraints. Default is 'sum'. Returns ------- Tensor The computed guidance term of the same shape as zt. 
""" if not isinstance(constraints, Sequence): constraints = [constraints] if scaling_function is None: def scaling_function(t: Tensor): log_snr = self.noise_schedule.get_log_snr(t, training=False) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr) return ops.square(alpha_t) / ops.square(sigma_t) def objective_fn(z): st = scaling_function(time) logp = keras.ops.zeros((), dtype=z.dtype) for c in constraints: ck = c(z) logp = logp - keras.ops.softplus(st * ck) return keras.ops.sum(logp) if reduce == "sum" else keras.ops.mean(logp) backend = keras.backend.backend() match backend: case "jax": import jax grad = jax.grad(objective_fn)(x) case "tensorflow": import tensorflow as tf with tf.GradientTape() as tape: tape.watch(x) objective = objective_fn(x) grad = tape.gradient(objective, x) case "torch": import torch with torch.enable_grad(): x_grad = x.clone().detach().requires_grad_(True) objective = objective_fn(x_grad) grad = torch.autograd.grad( outputs=objective, inputs=x_grad, )[0] case _: raise NotImplementedError(f"Unsupported backend: {backend}") return guidance_strength * grad