Source code for bayesflow.networks.inference.diffusion.diffusion_model

from collections.abc import Sequence, Mapping
from typing import Any, Literal, Callable

import keras
from keras import ops

from bayesflow.types import Tensor, Shape
from bayesflow.utils import (
    expand_right_as,
    find_network,
    integrate,
    integrate_stochastic,
    jacobian_trace,
    layer_kwargs,
    logging,
    maybe_mask_tensor,
    random_mask,
    randomly_mask_along_axis,
    weighted_mean,
    STOCHASTIC_METHODS,
    DETERMINISTIC_METHODS,
)
from bayesflow.utils.serialization import serialize, serializable

from .schedules.noise_schedule import NoiseSchedule
from .dispatch import find_noise_schedule

from ...inference import InferenceNetwork
from ...defaults import TIME_MLP_DEFAULTS, DIFFUSION_INTEGRATE_DEFAULTS


@serializable("bayesflow.networks")
class DiffusionModel(InferenceNetwork):
    """Score-based diffusion model for simulation-based inference (SBI).

    Implements a score-based diffusion model with configurable subnet architecture, noise schedule,
    and prediction/loss types for amortized SBI as described in [1]. Note that score-based diffusion
    is the most sluggish of all available samplers, so expect slower inference times than flow
    matching and much slower than normalizing flows.

    Parameters
    ----------
    subnet : str, type, or keras.Layer, optional
        A neural network type for the diffusion model, will be instantiated using *subnet_kwargs*.
        If a string is provided, it should be a registered name (e.g., ``"time_mlp"``). If a type or
        ``keras.Layer`` is provided, it will be directly instantiated with the given *subnet_kwargs*.
        Any subnet must accept a tuple of tensors ``(target, time, conditions)``.
    noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional
        Noise schedule controlling the diffusion dynamics. Can be a string identifier, a schedule
        class, or a pre-initialised schedule instance. Default is ``"edm"``.
    prediction_type : {'velocity', 'noise', 'F', 'x'}, optional
        Output format of the model's prediction. Default is ``"F"``.
    loss_type : {'velocity', 'noise', 'F'}, optional
        Loss function used to train the model. Default is ``"noise"``.
    subnet_kwargs : dict[str, Any], optional
        Additional keyword arguments passed to the subnet constructor.
    schedule_kwargs : dict[str, Any], optional
        Additional keyword arguments passed to the noise schedule constructor.
    integrate_kwargs : dict[str, Any], optional
        Configuration dictionary for the ODE/SDE integrator used at inference time.
    drop_cond_prob : float, optional
        Probability of dropping conditions during training (i.e., classifier-free guidance).
        Default is 0.0.
    drop_target_prob : float, optional
        Probability of dropping target values during training (i.e., learning arbitrary
        distributions). Default is 0.0.
    **kwargs
        Additional keyword arguments passed to the base ``InferenceNetwork``.

    References
    ----------
    [1] Arruda, J., Bracher, N., Köthe, U., Hasenauer, J., & Radev, S. T. (2025).
        Diffusion Models in Simulation-Based Inference: A Tutorial Review.
        arXiv preprint arXiv:2512.20685.
    """

    def __init__(
        self,
        *,
        subnet: str | type | keras.Layer = "time_mlp",
        noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm",
        prediction_type: Literal["velocity", "noise", "F", "x"] = "F",
        loss_type: Literal["velocity", "noise", "F"] = "noise",
        subnet_kwargs: dict[str, Any] = None,
        schedule_kwargs: dict[str, Any] = None,
        integrate_kwargs: dict[str, Any] = None,
        drop_cond_prob: float = 0.0,
        drop_target_prob: float = 0.0,
        **kwargs,
    ):
        """Initialize the diffusion model. See the class docstring for parameter details."""
        super().__init__(base_distribution="normal", **kwargs)

        # Validate configuration early so misconfiguration fails at construction time.
        if prediction_type not in ["noise", "velocity", "F", "x"]:
            raise ValueError(f"Unknown prediction type: {prediction_type}")
        if loss_type not in ["noise", "velocity", "F"]:
            raise ValueError(f"Unknown loss type: {loss_type}")
        if loss_type != "noise":
            # The built-in SNR weighting functions assume the noise-prediction loss.
            logging.warning(
                "The standard schedules have weighting functions defined for the noise prediction loss. "
                "You might want to replace them if you are using a different loss function."
            )
        self._prediction_type = prediction_type
        self._loss_type = loss_type

        schedule_kwargs = schedule_kwargs or {}
        self.noise_schedule = find_noise_schedule(noise_schedule, **schedule_kwargs)
        self.noise_schedule.validate()

        # Merge user overrides over the package defaults for inference-time integration.
        self.integrate_kwargs = DIFFUSION_INTEGRATE_DEFAULTS | (integrate_kwargs or {})
        self.seed_generator = keras.random.SeedGenerator()

        subnet_kwargs = subnet_kwargs or {}
        if subnet == "time_mlp":
            subnet_kwargs = TIME_MLP_DEFAULTS | subnet_kwargs
        self.subnet = find_network(subnet, **subnet_kwargs)
        # Created lazily in build() once the output dimensionality is known.
        self.output_projector = None

        # drop_cond_prob enables classifier-free guidance training;
        # drop_target_prob enables learning with partially-fixed targets.
        self.drop_cond_prob = drop_cond_prob
        self.unconditional_mode = False
        self.drop_target_prob = drop_target_prob
    def compute_metrics(
        self,
        x: Tensor | Sequence[Tensor],
        conditions: Tensor = None,
        sample_weight: Tensor = None,
        stage: str = "training",
        **kwargs,
    ) -> dict[str, Tensor]:
        """Compute the diffusion training loss for a batch of targets.

        Diffuses `x` to a random noise level, predicts the clean signal with the subnet,
        and scores the prediction under the configured loss type (noise / velocity / F).

        Parameters
        ----------
        x : Tensor or Sequence[Tensor]
            Batch of target variables to be diffused.
        conditions : Tensor, optional
            Conditioning variables; randomly dropped with probability ``drop_cond_prob``
            (classifier-free guidance). Default is None.
        sample_weight : Tensor, optional
            Per-sample weights applied when averaging the loss. Default is None.
        stage : str, optional
            One of "training", "validation", or "inference"; controls dropout and which
            noise-schedule regime is used. Default is "training".
        **kwargs
            Extra kwargs; mask-related keys are forwarded to the subnet.
            NOTE(review): `_collect_mask_kwargs` / `_SUBNET_MASK_KEYS` are defined outside
            this view — presumably they extract attention/feature masks; confirm.

        Returns
        -------
        dict[str, Tensor]
            Dictionary with a single scalar entry ``"loss"``.
        """
        subnet_kwargs = self._collect_mask_kwargs(self._SUBNET_MASK_KEYS, kwargs)
        training = stage == "training"
        # The noise schedule treats validation like training (same SNR sampling regime).
        noise_schedule_training_stage = stage == "training" or stage == "validation"

        if conditions is not None:
            # Classifier-free guidance: randomly zero out conditions during training.
            conditions = randomly_mask_along_axis(conditions, self.drop_cond_prob, seed_generator=self.seed_generator)

        # Sample training diffusion time as a low discrepancy sequence to decrease variance
        u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator)
        i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x))
        t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1

        # Calculate the noise level
        log_snr_t = self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage)
        log_snr_t = expand_right_as(log_snr_t, x)
        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
        weights_for_snr = self.noise_schedule.get_weights_for_snr(log_snr_t=log_snr_t)

        # Generate noise vector
        eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)

        # Diffuse x to get noisy input to the network: z_t = alpha_t * x + sigma_t * eps
        diffused_x = alpha_t * x + sigma_t * eps_t

        # Generate optional target dropout mask; where dropped, the clean target is kept
        # and the per-dimension loss contribution is zeroed via mask_x below.
        mask_x = random_mask(ops.shape(x), self.drop_target_prob, self.seed_generator)
        diffused_x = maybe_mask_tensor(diffused_x, mask=mask_x, replacement=x)

        # Obtain output of the network and transform to prediction of the clean signal x
        norm_log_snr_t = self._transform_log_snr(log_snr_t)
        subnet_out = self.subnet((diffused_x, norm_log_snr_t, conditions), training=training, **subnet_kwargs)
        pred = self.output_projector(subnet_out)
        x_pred = self.convert_prediction_to_x(
            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
        )

        # Finally, compute the loss according to the configured loss type. Note that the standard
        # weighting functions are defined for the noise prediction loss, so if you use a different
        # loss type, you might want to adjust the weighting accordingly.
        match self._loss_type:
            case "noise":
                # Recover the implied noise prediction from the clean-signal prediction.
                noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t
                loss = weights_for_snr * ops.mean(mask_x * (noise_pred - eps_t) ** 2, axis=-1)
            case "velocity":
                # v-prediction target: v_t = alpha_t * eps - sigma_t * x (variance-preserving form).
                velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t
                v_t = alpha_t * eps_t - sigma_t * x
                loss = weights_for_snr * ops.mean(mask_x * (velocity_pred - v_t) ** 2, axis=-1)
            case "F":
                # EDM-style F-prediction loss; sigma_data defaults to 1 if the schedule has none.
                sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
                x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)
                x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2))
                f_pred = x1 * x_pred - x2 * diffused_x
                f_t = x1 * x - x2 * diffused_x
                loss = weights_for_snr * ops.mean(mask_x * (f_pred - f_t) ** 2, axis=-1)
            case _:
                raise ValueError(f"Unknown loss type: {self._loss_type}")

        loss = weighted_mean(loss, sample_weight)
        return {"loss": loss}
[docs] def build(self, xz_shape: Shape, conditions_shape: Shape = None): if self.built: return self.base_distribution.build(xz_shape) self.output_projector = keras.layers.Dense(units=xz_shape[-1], bias_initializer="zeros") # construct input shape for subnet and subnet projector time_shape = (xz_shape[0], 1) # same batch dims, 1 feature self.subnet.build((xz_shape, time_shape, conditions_shape)) out_shape = self.subnet.compute_output_shape((xz_shape, time_shape, conditions_shape)) self.output_projector.build(out_shape)
[docs] def get_config(self): base_config = super().get_config() base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, "noise_schedule": self.noise_schedule, "prediction_type": self._prediction_type, "loss_type": self._loss_type, "integrate_kwargs": self.integrate_kwargs, "drop_cond_prob": self.drop_cond_prob, "drop_target_prob": self.drop_target_prob, } return base_config | serialize(config)
    def guidance_constraint_term(
        self,
        x: Tensor,
        time: Tensor,
        constraints: Callable | Sequence[Callable],
        guidance_strength: float = 1.0,
        scaling_function: Callable | None = None,
        reduce: Literal["sum", "mean"] = "sum",
    ) -> Tensor:
        """
        Backend-agnostic implementation of:

        `∇_x Σ_k log sigmoid( -s(t) * c_k(x) )`

        Parameters
        ----------
        x : Tensor
            The denoised target at time t.
        time : Tensor
            The time corresponding to x.
        constraints : Callable or Sequence[Callable]
            A single constraint function or a list/tuple of constraint functions. Each function
            should take x as input and return a tensor of constraint values.
        guidance_strength : float, optional
            A positive scaling factor for the guidance term. Default is 1.0.
        scaling_function : Callable, optional
            A function that takes time t as input and returns a scaling factor s(t). If None, a
            default scaling based on the noise schedule is used. Default is None.
        reduce : {'sum', 'mean'}, optional
            Method to reduce the log-probabilities from multiple constraints. Default is 'sum'.

        Returns
        -------
        Tensor
            The computed guidance term of the same shape as x.
        """
        # Wrap a single constraint callable into a list for uniform handling.
        if not isinstance(constraints, Sequence):
            constraints = [constraints]

        if scaling_function is None:
            # Default scaling s(t) = alpha_t^2 / sigma_t^2 (i.e., the SNR at time t).
            def scaling_function(t: Tensor):
                log_snr = self.noise_schedule.get_log_snr(t, training=False)
                alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr)
                return ops.square(alpha_t) / ops.square(sigma_t)

        def objective_fn(z):
            # log sigmoid(-s * c) == -softplus(s * c), which is what is accumulated here.
            st = scaling_function(time)
            logp = keras.ops.zeros((), dtype=z.dtype)
            for c in constraints:
                ck = c(z)
                logp = logp - keras.ops.softplus(st * ck)
            return keras.ops.sum(logp) if reduce == "sum" else keras.ops.mean(logp)

        # The gradient of the scalar objective w.r.t. x must be computed with the
        # native autodiff machinery of the active Keras backend.
        backend = keras.backend.backend()
        match backend:
            case "jax":
                import jax

                grad = jax.grad(objective_fn)(x)
            case "tensorflow":
                import tensorflow as tf

                with tf.GradientTape() as tape:
                    tape.watch(x)
                    objective = objective_fn(x)
                grad = tape.gradient(objective, x)
            case "torch":
                import torch

                # enable_grad guards against inference-mode/no-grad contexts at sampling time.
                with torch.enable_grad():
                    x_grad = x.clone().detach().requires_grad_(True)
                    objective = objective_fn(x_grad)
                    grad = torch.autograd.grad(
                        outputs=objective,
                        inputs=x_grad,
                    )[0]
            case _:
                raise NotImplementedError(f"Unsupported backend: {backend}")

        return guidance_strength * grad
[docs] def convert_prediction_to_x( self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor ) -> Tensor: """ Converts the neural network prediction into the denoised data `x`, depending on the prediction type configured for the model. Parameters ---------- pred : Tensor The output prediction from the neural network, typically representing noise, velocity, or a transformation of the clean signal. z : Tensor The noisy latent variable `z` to be denoised. alpha_t : Tensor The noise schedule's scaling factor for the clean signal at time `t`. sigma_t : Tensor The standard deviation of the noise at time `t`. log_snr_t : Tensor The log signal-to-noise ratio at time `t`. Returns ------- Tensor The reconstructed clean signal `x` from the model prediction. """ match self._prediction_type: case "velocity": return alpha_t * z - sigma_t * pred case "noise": return (z - sigma_t * pred) / alpha_t case "F": sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0) x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2) x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) return x1 * z + x2 * pred case "x": return pred case "score": return (z + sigma_t**2 * pred) / alpha_t case _: raise ValueError(f"Unknown prediction type {self._prediction_type}.")
    def score(
        self,
        xz: Tensor,
        time: float | Tensor = None,
        log_snr_t: Tensor = None,
        conditions: Tensor = None,
        training: bool = False,
        guidance_constraints: Mapping[str, Any] = None,
        guidance_function: Callable[[Tensor, Tensor], Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """
        Computes the score of the target or latent variable `xz`.

        Parameters
        ----------
        xz : Tensor
            The current state of the latent variable `z`, typically of shape (..., D),
            where D is the dimensionality of the latent space.
        time : float or Tensor
            Scalar or tensor representing the time (or noise level) at which the score should be
            computed. Will be broadcasted to xz. If None, log_snr_t must be provided.
        log_snr_t : Tensor
            The log signal-to-noise ratio at time `t`. If None, time must be provided.
        conditions : Tensor, optional
            Conditional inputs to the network, such as conditioning variables or encoder outputs.
            Shape must be broadcastable with `xz`. Default is None.
        training : bool, optional
            Whether the model is in training mode. Affects behavior of dropout, batch norm, or
            other stochastic layers. Default is False.
        guidance_constraints : dict[str, Any], optional
            A dictionary of parameters for computing a guidance constraint term, which is added
            to the score for guided sampling. The specific keys and values depend on the
            implementation of `guidance_constraint_term`.
        guidance_function : Callable[[Tensor, Tensor], Tensor], optional
            A custom function for computing a guidance term, which is added to the score for
            guided sampling. The function should accept the predicted clean signal `x_pred` and
            the current time `time` as inputs and return a tensor of the same shape as `xz`.
        **kwargs
            Subnet kwargs (e.g., attention_mask, mask) for the subnet layer.

        Returns
        -------
        Tensor
            The score tensor of the same shape as `xz` at the given `time` (or `log_snr_t`).
        """
        subnet_kwargs = self._collect_mask_kwargs(self._SUBNET_MASK_KEYS, kwargs)

        # Either `time` or `log_snr_t` may be given; derive the missing one from the schedule.
        if log_snr_t is None:
            log_snr_t = self.noise_schedule.get_log_snr(t=time, training=training)
            log_snr_t = expand_right_as(log_snr_t, xz)
        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
        if time is None:
            time = self.noise_schedule.get_t_from_log_snr(log_snr_t, training=training)
        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

        # Run the subnet on the normalized log-SNR and project to a clean-signal prediction.
        norm_log_snr = self._transform_log_snr(log_snr_t)
        subnet_out = self.subnet((xz, norm_log_snr, conditions), training=training, **subnet_kwargs)
        pred = self.output_projector(subnet_out)
        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

        # Tweedie-style identity: score = (alpha_t * x_hat - z) / sigma_t^2.
        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)

        # Optional guidance terms are simply added to the score.
        if guidance_constraints is not None:
            guidance = self.guidance_constraint_term(x=x_pred, time=time, **guidance_constraints)
            score = score + guidance
        if guidance_function is not None:
            guidance = guidance_function(x=x_pred, time=time)
            score = score + guidance
        return score
    def velocity(
        self,
        xz: Tensor,
        time: float | Tensor,
        stochastic_solver: bool,
        conditions: Tensor = None,
        training: bool = False,
        **kwargs,
    ) -> Tensor:
        """
        Computes the velocity (i.e., time derivative) of the target or latent variable `xz` for
        either a stochastic differential equation (SDE) or ordinary differential equation (ODE).

        Parameters
        ----------
        xz : Tensor
            The current state of the latent variable `z`, typically of shape (..., D),
            where D is the dimensionality of the latent space.
        time : float or Tensor
            Scalar or tensor representing the time (or noise level) at which the velocity
            should be computed. Will be broadcasted to xz.
        stochastic_solver : bool
            If True, computes the velocity for the stochastic formulation (SDE).
            If False, uses the deterministic formulation (ODE).
        conditions : Tensor, optional
            Conditional inputs to the network, such as conditioning variables or encoder outputs.
            Shape must be broadcastable with `xz`. Default is None.
        training : bool, optional
            Whether the model is in training mode. Affects behavior of dropout, batch norm, or
            other stochastic layers. Default is False.
        **kwargs
            Forwarded to `score`; may also contain ``target_mask`` to freeze part of the target
            during inference.

        Returns
        -------
        Tensor
            The velocity tensor of the same shape as `xz`, representing the right-hand side of
            the SDE or ODE at the given `time`.
        """
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
        score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training, **kwargs)

        # compute drift f and squared diffusion g^2 of the SDE or ODE
        f, g_squared = self.noise_schedule.get_drift(log_snr_t=log_snr_t, x=xz, training=training)

        if stochastic_solver:
            # for the SDE: d(z) = [f(z, t) - g(t)^2 * score(z, lambda)] dt + g(t) dW
            out = f - g_squared * score
        else:
            # for the ODE: d(z) = [f(z, t) - 0.5 * g(t)^2 * score(z, lambda)] dt
            out = f - 0.5 * g_squared * score

        # Zero out velocity where target is fixed (during inference only)
        if not training:
            target_mask = kwargs.get("target_mask", None)
            out = maybe_mask_tensor(out, mask=target_mask)
        return out
    def diffusion_term(
        self,
        xz: Tensor,
        time: float | Tensor,
        training: bool = False,
        **kwargs,
    ) -> Tensor:
        """
        Compute the diffusion term (standard deviation of the noise) at a given time.

        Parameters
        ----------
        xz : Tensor
            Input tensor of shape (..., D), typically representing the target or latent
            variables at given time.
        time : float or Tensor
            The diffusion time step(s). Can be a scalar or a tensor broadcastable to the
            shape of `xz`.
        training : bool, optional
            Whether to use the training noise schedule (default is False).
        **kwargs
            May contain ``target_mask`` to zero the diffusion where the target is fixed.

        Returns
        -------
        Tensor
            The diffusion term tensor with shape matching `xz` except for the last dimension,
            which is set to 1.
        """
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))

        # NOTE(review): unlike `velocity`, `get_drift` is called here without `x` and its
        # result is used directly as g^2 — presumably the schedule returns only the squared
        # diffusion term when `x` is omitted; confirm against the NoiseSchedule API.
        g_squared = self.noise_schedule.get_drift(log_snr_t=log_snr_t)
        g = ops.sqrt(g_squared)

        # Zero out diffusion where target is fixed (during inference only)
        if not training:
            target_mask = kwargs.get("target_mask", None)
            g = maybe_mask_tensor(g, mask=target_mask)
        return g
    def _velocity_trace(
        self,
        xz: Tensor,
        time: Tensor,
        conditions: Tensor = None,
        max_steps: int = None,
        training: bool = False,
        **kwargs,
    ) -> tuple[Tensor, Tensor]:
        """Return the ODE velocity at `time` together with (an estimate of) the trace of its
        Jacobian w.r.t. `xz`, as needed for change-of-variables density computation."""

        def f(x):
            return self.velocity(
                x, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs
            )

        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
        # Trace gets a trailing singleton axis so it can ride along in the integrator state.
        return v, ops.expand_dims(trace, axis=-1)

    def _transform_log_snr(self, log_snr: Tensor) -> Tensor:
        """Transform the log_snr to the range [-1, 1] for the diffusion process."""
        log_snr_min = self.noise_schedule.log_snr_min
        log_snr_max = self.noise_schedule.log_snr_max
        normalized_snr = (log_snr - log_snr_min) / (log_snr_max - log_snr_min)
        scaled_value = 2 * normalized_snr - 1
        return scaled_value

    def _forward(
        self,
        x: Tensor,
        conditions: Tensor = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """Integrate data `x` forward (t: 0 -> 1) to the latent `z`, optionally returning
        the log-density via the trace of the velocity Jacobian."""
        # Note: integrators will cherry-pick necessary kwargs, so
        # we can be general (i.e., sloppy) here
        integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0}
        integrate_kwargs |= self.integrate_kwargs
        integrate_kwargs |= kwargs
        # The forward direction always uses a deterministic solver; stochastic methods
        # are replaced so density evaluation stays well-defined.
        if integrate_kwargs["method"] in STOCHASTIC_METHODS:
            logging.warning(
                "Stochastic methods are not supported for density evaluation."
                " Falling back to tsit5 ODE solver."
                " To suppress this warning, explicitly pass a method from " + str(DETERMINISTIC_METHODS) + "."
            )
            integrate_kwargs["method"] = "tsit5"

        # Apply user-provided target mask if available
        target_mask = kwargs.get("target_mask", None)
        targets_fixed = kwargs.get("targets_fixed", None)
        if target_mask is not None:
            target_mask = keras.ops.broadcast_to(target_mask, keras.ops.shape(x))
            targets_fixed = keras.ops.broadcast_to(targets_fixed, keras.ops.shape(x))
            x = maybe_mask_tensor(x, target_mask, replacement=targets_fixed)

        if self.unconditional_mode and conditions is not None:
            conditions = keras.ops.zeros_like(conditions)
            logging.info("Condition masking is applied: conditions are set to zero.")

        if density:
            # Jointly integrate the state and the accumulated Jacobian trace.
            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training, **kwargs)
                return {"xz": v, "trace": trace}

            state = {
                "xz": x,
                "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)),
            }
            state = integrate(deltas, state, **integrate_kwargs)
            z = state["xz"]
            log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1)
            return z, log_density

        def deltas(time, xz):
            return {
                "xz": self.velocity(
                    xz, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs
                )
            }

        state = {"xz": x}
        state = integrate(deltas, state, **integrate_kwargs)
        z = state["xz"]
        return z

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        """Integrate latent `z` backward (t: 1 -> 0) to a sample `x`, using either a
        deterministic ODE solver or a stochastic SDE solver, optionally with density."""
        # Build integrate kwargs: hardcoded defaults -> instance config -> call-time overrides
        integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0}
        integrate_kwargs |= self.integrate_kwargs
        integrate_kwargs |= kwargs

        # Apply user-provided target mask if available
        target_mask = kwargs.get("target_mask", None)
        targets_fixed = kwargs.get("targets_fixed", None)
        if target_mask is not None:
            target_mask = keras.ops.broadcast_to(target_mask, keras.ops.shape(z))
            targets_fixed = keras.ops.broadcast_to(targets_fixed, keras.ops.shape(z))
            z = maybe_mask_tensor(z, target_mask, replacement=targets_fixed)

        if self.unconditional_mode and conditions is not None:
            conditions = keras.ops.zeros_like(conditions)
            logging.info("Condition masking is applied: conditions are set to zero.")

        if density:
            # Density requires a deterministic probability-flow ODE.
            if integrate_kwargs["method"] in STOCHASTIC_METHODS:
                logging.warning(
                    "Stochastic methods are not supported for density computation."
                    " Falling back to ODE solver."
                    " Use one of the deterministic methods: " + str(DETERMINISTIC_METHODS) + "."
                )
                integrate_kwargs["method"] = "tsit5"

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training, **kwargs)
                return {"xz": v, "trace": trace}

            state = {
                "xz": z,
                "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)),
            }
            state = integrate(deltas, state, **integrate_kwargs)
            x = state["xz"]
            # Backward direction: trace is subtracted from the base log-probability.
            log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1)
            return x, log_density

        state = {"xz": z}
        if integrate_kwargs["method"] in STOCHASTIC_METHODS:
            # SDE sampling: drift uses the stochastic velocity, diffusion supplies g(t) dW.
            def deltas(time, xz):
                return {
                    "xz": self.velocity(
                        xz, time=time, stochastic_solver=True, conditions=conditions, training=training, **kwargs
                    )
                }

            def diffusion(time, xz):
                return {"xz": self.diffusion_term(xz, time=time, training=training, **kwargs)}

            # Some stochastic integrators (corrector steps, Langevin) need the raw score.
            score_fn = None
            if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin":

                def score_fn(time, xz):
                    return {"xz": self.score(xz, time=time, conditions=conditions, training=training, **kwargs)}

            state = integrate_stochastic(
                drift_fn=deltas,
                diffusion_fn=diffusion,
                score_fn=score_fn,
                noise_schedule=self.noise_schedule,
                state=state,
                seed=self.seed_generator,
                **integrate_kwargs,
            )
        else:

            def deltas(time, xz):
                return {
                    "xz": self.velocity(
                        xz, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs
                    )
                }

            state = integrate(deltas, state, **integrate_kwargs)

        x = state["xz"]
        return x