Source code for bayesflow.networks.inference.diffusion.diffusion_model

from collections.abc import Sequence, Mapping
from typing import Any, Literal, Callable

import keras
from keras import ops

from bayesflow.types import Tensor, Shape
from bayesflow.utils import (
    expand_right_as,
    find_network,
    integrate,
    integrate_stochastic,
    jacobian_trace,
    layer_kwargs,
    logging,
    linsolve_batched,
    maybe_mask_tensor,
    random_mask,
    randomly_mask_along_axis,
    resolve_seed,
    weighted_mean,
    DETERMINISTIC_METHODS,
    STOCHASTIC_METHODS,
)
from bayesflow.utils.serialization import serialize, serializable

from .schedules.noise_schedule import NoiseSchedule
from .dispatch import find_noise_schedule

from ...inference import InferenceNetwork
from ...defaults import TIME_MLP_DEFAULTS, DIFFUSION_INTEGRATE_DEFAULTS


[docs] @serializable("bayesflow.networks") class DiffusionModel(InferenceNetwork): """Score-based diffusion model for simulation-based inference (SBI). Implements a score-based diffusion model with configurable subnet architecture, noise schedule, and prediction/loss types for amortized SBI as described in [1]. The diffusion model allows for post-hoc guidance (see [1]) and composition [2], implementing, stabilizing, and scaling initial ideas from [3] and [4]. Note that score-based diffusion is the most sluggish of all available samplers, so expect slower inference times than flow matching and much slower than normalizing flows. Parameters ---------- subnet : str, type, or keras.Layer, optional A neural network type for the diffusion model, will be instantiated using *subnet_kwargs*. If a string is provided, it should be a registered name (e.g., ``"time_mlp"``). If a type or ``keras.Layer`` is provided, it will be directly instantiated with the given *subnet_kwargs*. Any subnet must accept a tuple of tensors ``(target, time, conditions)``. noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional Noise schedule controlling the diffusion dynamics. Can be a string identifier, a schedule class, or a pre-initialised schedule instance. Default is ``"edm"``. prediction_type : {'velocity', 'noise', 'F', 'x', 'score', 'potential'}, optional Output format of the model's prediction. Default is ``"F"``. loss_type : {'velocity', 'noise', 'F'}, optional Loss function used to train the model. Default is ``"noise"``. subnet_kwargs : dict[str, Any], optional Additional keyword arguments passed to the subnet constructor. schedule_kwargs : dict[str, Any], optional Additional keyword arguments passed to the noise schedule constructor. integrate_kwargs : dict[str, Any], optional Configuration dictionary for the ODE/SDE integrator used at inference time. drop_cond_prob : float, optional Probability of dropping conditions during training (i.e., classifier-free guidance). Default is 0.0. drop_target_prob : float, optional Probability of dropping target values during training (i.e., learning arbitrary distributions). Default is 0.0. **kwargs Additional keyword arguments passed to the base ``InferenceNetwork``. References ---------- [1] Arruda, J., Bracher, N., Köthe, U., Hasenauer, J., & Radev, S. T. (2025). Diffusion Models in Simulation-Based Inference: A Tutorial Review. arXiv preprint arXiv:2512.20685. [2] Arruda et al., (2025). Compositional amortized inference for large-scale hierarchical Bayesian models. https://arxiv.org/abs/2505.14429 [3] Geffner et al., (2023). Compositional score modeling for simulation-based inference. https://arxiv.org/abs/2209.14249 [4] Linhart et al., (2024). Diffusion posterior sampling for simulation-based inference in tall data settings. 
https://arxiv.org/abs/2404.07593 """ def __init__( self, *, subnet: str | type | keras.Layer = "time_mlp", noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm", prediction_type: Literal["velocity", "noise", "F", "x", "score", "potential"] = "F", loss_type: Literal["velocity", "noise", "F"] = "noise", subnet_kwargs: dict[str, Any] = None, schedule_kwargs: dict[str, Any] = None, integrate_kwargs: dict[str, Any] = None, drop_cond_prob: float = 0.0, drop_target_prob: float = 0.0, **kwargs, ): super().__init__(base_distribution="normal", **kwargs) if prediction_type not in ["noise", "velocity", "F", "x", "score", "potential"]: raise ValueError(f"Unknown prediction type: {prediction_type}") if loss_type not in ["noise", "velocity", "F"]: raise ValueError(f"Unknown loss type: {loss_type}") if loss_type != "noise": logging.warning( "The standard schedules have weighting functions defined for the noise prediction loss. " "You might want to replace them if you are using a different loss function." ) self._prediction_type = prediction_type self._loss_type = loss_type schedule_kwargs = schedule_kwargs or {} self.noise_schedule = find_noise_schedule(noise_schedule, **schedule_kwargs) self.noise_schedule.validate() self.integrate_kwargs = DIFFUSION_INTEGRATE_DEFAULTS | (integrate_kwargs or {}) self.seed_generator = keras.random.SeedGenerator() subnet_kwargs = subnet_kwargs or {} if subnet == "time_mlp": subnet_kwargs = TIME_MLP_DEFAULTS | subnet_kwargs self.subnet = find_network(subnet, **subnet_kwargs) self.output_projector = None self.drop_cond_prob = drop_cond_prob self.unconditional_mode = False self.drop_target_prob = drop_target_prob self.compositional_bridge_d0 = 1.0 self.compositional_bridge_d1 = 1.0
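# --- Illustrative construction sketch (not part of this module's source). It shows
# how a DiffusionModel could be instantiated with the arguments documented above.
# The import path is an assumption based on the "bayesflow.networks" serialization
# namespace, and the integrate_kwargs keys/values ("method", "steps", "euler") are
# assumed solver settings, not values taken from this file.
from bayesflow.networks import DiffusionModel

diffusion_net = DiffusionModel(
    subnet="time_mlp",
    noise_schedule="cosine",
    prediction_type="F",
    loss_type="noise",
    integrate_kwargs={"method": "euler", "steps": 100},
)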
[docs] def compute_metrics( self, x: Tensor | Sequence[Tensor], conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training", **kwargs, ) -> dict[str, Tensor]: subnet_kwargs = self._collect_mask_kwargs(self._SUBNET_MASK_KEYS, kwargs) training = stage == "training" noise_schedule_training_stage = stage == "training" or stage == "validation" if conditions is not None: conditions = randomly_mask_along_axis(conditions, self.drop_cond_prob, seed_generator=self.seed_generator) # Sample training diffusion time as a low discrepancy sequence to decrease variance u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator) i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x)) t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1 # Calculate the noise level log_snr_t = self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage) log_snr_t = expand_right_as(log_snr_t, x) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) weights_for_snr = self.noise_schedule.get_weights_for_snr(log_snr_t=log_snr_t) # Generate noise vector eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator) # Diffuse x to get noisy input to the network diffused_x = alpha_t * x + sigma_t * eps_t # Generate optional target dropout mask mask_x = random_mask(ops.shape(x), self.drop_target_prob, self.seed_generator) diffused_x = maybe_mask_tensor(diffused_x, mask=mask_x, replacement=x) # Obtain output of the network and transform to prediction of the clean signal x norm_log_snr_t = self._transform_log_snr(log_snr_t) if self._prediction_type == "potential": pred = self._compute_grad_of_potential( xz=diffused_x, norm_log_snr=norm_log_snr_t, conditions=conditions, training=training, **subnet_kwargs ) x_pred = self.convert_prediction_to_x( pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, prediction_type="velocity", ) else: subnet_out = self.subnet((diffused_x, norm_log_snr_t, conditions), training=training, **subnet_kwargs) pred = self.output_projector(subnet_out) x_pred = self.convert_prediction_to_x( pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, prediction_type=self._prediction_type, ) # Finally, compute the loss according to the configured loss type. Note that the standard weighting # functions are defined for the noise prediction loss, so if you use a different loss type, you might want # to adjust the weighting accordingly. match self._loss_type: case "noise": noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t loss = weights_for_snr * ops.mean(mask_x * (noise_pred - eps_t) ** 2, axis=-1) case "velocity": velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t v_t = alpha_t * eps_t - sigma_t * x loss = weights_for_snr * ops.mean(mask_x * (velocity_pred - v_t) ** 2, axis=-1) case "F": sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0 x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data) x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)) f_pred = x1 * x_pred - x2 * diffused_x f_t = x1 * x - x2 * diffused_x loss = weights_for_snr * ops.mean(mask_x * (f_pred - f_t) ** 2, axis=-1) case _: raise ValueError(f"Unknown loss type: {self._loss_type}") loss = weighted_mean(loss, sample_weight) return {"loss": loss}
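# --- Standalone sketch of the stratified ("low discrepancy") time sampling used in
# compute_metrics above: one shared uniform offset plus an evenly spaced grid, so each
# batch covers [0, 1) roughly uniformly and the loss estimate has lower variance.
import keras
from keras import ops

batch_size = 8
u0 = keras.random.uniform(shape=(1,))              # shared random offset in [0, 1)
i = ops.arange(0, batch_size, dtype="float32")     # 0, 1, ..., batch_size - 1
t = (u0 + i / batch_size) % 1.0                    # stratified diffusion times in [0, 1)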
[docs] def build(self, xz_shape: Shape, conditions_shape: Shape = None): if self.built: return self.base_distribution.build(xz_shape) units = 1 if self._prediction_type == "potential" else xz_shape[-1] self.output_projector = keras.layers.Dense(units=units, bias_initializer="zeros") # construct input shape for subnet and subnet projector time_shape = (xz_shape[0], 1) self.subnet.build((xz_shape, time_shape, conditions_shape)) out_shape = self.subnet.compute_output_shape((xz_shape, time_shape, conditions_shape)) self.output_projector.build(out_shape)
[docs] def get_config(self): base_config = super().get_config() base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, "noise_schedule": self.noise_schedule, "prediction_type": self._prediction_type, "loss_type": self._loss_type, "integrate_kwargs": self.integrate_kwargs, "drop_cond_prob": self.drop_cond_prob, "drop_target_prob": self.drop_target_prob, } return base_config | serialize(config)
[docs]
    def guidance_constraint_term(
        self,
        x: Tensor,
        time: Tensor,
        constraints: Callable | Sequence[Callable],
        guidance_strength: float = 1.0,
        scaling_function: Callable | None = None,
        reduce: Literal["sum", "mean"] = "sum",
    ) -> Tensor:
        """
        Backend-agnostic implementation of: `∇_x Σ_k log sigmoid( -s(t) * c_k(x) )`

        Parameters
        ----------
        x : Tensor
            The denoised target at time t.
        time : Tensor
            The time corresponding to x.
        constraints : Callable or Sequence[Callable]
            A single constraint function or a list/tuple of constraint functions. Each function should take
            x as input and return a tensor of constraint values.
        guidance_strength : float, optional
            A positive scaling factor for the guidance term. Default is 1.0.
        scaling_function : Callable, optional
            A function that takes time t as input and returns a scaling factor s(t). If None, a default scaling
            based on the noise schedule is used. Default is None.
        reduce : {'sum', 'mean'}, optional
            Method to reduce the log-probabilities from multiple constraints. Default is 'sum'.

        Returns
        -------
        Tensor
            The computed guidance term of the same shape as `x`.
        """
        if not isinstance(constraints, Sequence):
            constraints = [constraints]

        if scaling_function is None:

            def scaling_function(t: Tensor):
                log_snr = self.noise_schedule.get_log_snr(t, training=False)
                alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr)
                return ops.square(alpha_t) / ops.square(sigma_t)

        def objective_fn(z):
            st = scaling_function(time)
            logp = keras.ops.zeros((), dtype=z.dtype)
            for c in constraints:
                ck = c(z)
                logp = logp - keras.ops.softplus(st * ck)
            return keras.ops.sum(logp) if reduce == "sum" else keras.ops.mean(logp)

        backend = keras.backend.backend()
        match backend:
            case "jax":
                import jax

                grad = jax.grad(objective_fn)(x)
            case "tensorflow":
                import tensorflow as tf

                with tf.GradientTape() as tape:
                    tape.watch(x)
                    objective = objective_fn(x)
                grad = tape.gradient(objective, x)
            case "torch":
                import torch

                with torch.enable_grad():
                    x_grad = x.clone().detach().requires_grad_(True)
                    objective = objective_fn(x_grad)
                    grad = torch.autograd.grad(
                        outputs=objective,
                        inputs=x_grad,
                    )[0]
            case _:
                raise NotImplementedError(f"Unsupported backend: {backend}")

        return guidance_strength * grad
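# --- Illustrative constraint set for guidance_constraint_term above. Each callable
# returns c_k(x), where c_k(x) <= 0 means "constraint satisfied" and violations are
# penalised via log sigmoid(-s(t) * c_k(x)). The dictionary layout mirrors the method's
# keyword arguments; how such a dictionary is forwarded at sampling time (e.g. as a
# `guidance_constraints` kwarg) is an assumption about the calling code.
from keras import ops

def positivity_constraint(x):
    # positive output (= violation) whenever the first target dimension is negative
    return -x[..., :1]

def box_constraint(x):
    # violated whenever any component leaves the interval [-3, 3]
    return ops.abs(x) - 3.0

guidance_constraints = {
    "constraints": [positivity_constraint, box_constraint],
    "guidance_strength": 1.0,
    "reduce": "sum",
}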
[docs] def convert_prediction_to_x( self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, prediction_type: str ) -> Tensor: """ Converts the neural network prediction into the denoised data `x`, depending on the prediction type configured for the model. Parameters ---------- pred : Tensor The output prediction from the neural network, typically representing noise, velocity, or a transformation of the clean signal. z : Tensor The noisy latent variable `z` to be denoised. alpha_t : Tensor The noise schedule's scaling factor for the clean signal at time `t`. sigma_t : Tensor The standard deviation of the noise at time `t`. log_snr_t : Tensor The log signal-to-noise ratio at time `t`. prediction_type : str The type of prediction made by the model, which determines how to convert it to `x`. Must be one of {'velocity', 'noise', 'F', 'x', 'score'}. Returns ------- Tensor The reconstructed clean signal `x` from the model prediction. """ match prediction_type: case "velocity": return alpha_t * z - sigma_t * pred case "noise": return (z - sigma_t * pred) / alpha_t case "F": sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0) x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2) x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) return x1 * z + x2 * pred case "x": return pred case "score": return (z + sigma_t**2 * pred) / alpha_t case "potential": # Score is computed as ∇_z phi via _compute_grad_of_potential. raise RuntimeError("convert_prediction_to_x should not be called for prediction_type='potential'. ") case _: raise ValueError(f"Unknown prediction type {prediction_type}.")
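# --- Quick standalone check of the 'noise' and 'score' branches above: if
# z = alpha_t * x + sigma_t * eps, then recovering x from the exact noise or from the
# exact Gaussian score must reproduce x up to floating-point error. Purely illustrative.
from keras import ops

x_true = ops.convert_to_tensor([[0.5, -1.2]])
eps = ops.convert_to_tensor([[0.3, 0.7]])
alpha_t, sigma_t = 0.8, 0.6
z = alpha_t * x_true + sigma_t * eps

x_from_noise = (z - sigma_t * eps) / alpha_t        # 'noise' parameterization
score = -eps / sigma_t                              # score of N(z; alpha_t * x, sigma_t^2 I)
x_from_score = (z + sigma_t**2 * score) / alpha_t   # 'score' parameterization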
[docs]
    def score(
        self,
        xz: Tensor,
        time: float | Tensor = None,
        log_snr_t: Tensor = None,
        conditions: Tensor = None,
        training: bool = False,
        guidance_constraints: Mapping[str, Any] = None,
        guidance_function: Callable[[Tensor, Tensor], Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """
        Computes the score of the target or latent variable `xz`.

        Parameters
        ----------
        xz : Tensor
            The current state of the latent variable `z`, typically of shape (..., D),
            where D is the dimensionality of the latent space.
        time : float or Tensor
            Scalar or tensor representing the time (or noise level) at which the score should be computed.
            Will be broadcasted to xz. If None, log_snr_t must be provided.
        log_snr_t : Tensor
            The log signal-to-noise ratio at time `t`. If None, time must be provided.
        conditions : Tensor, optional
            Conditional inputs to the network, such as conditioning variables or encoder outputs.
            Shape must be broadcastable with `xz`. Default is None.
        training : bool, optional
            Whether the model is in training mode. Affects behavior of dropout, batch norm,
            or other stochastic layers. Default is False.
        guidance_constraints : dict[str, Any], optional
            A dictionary of parameters for computing a guidance constraint term, which is added to the score
            for guided sampling. The specific keys and values depend on the implementation of
            `guidance_constraint_term`.
        guidance_function : Callable[[Tensor, Tensor], Tensor], optional
            A custom function for computing a guidance term, which is added to the score for guided sampling.
            The function should accept the predicted clean signal `x_pred` and the current time `time`
            as inputs and return a tensor of the same shape as `xz`.
        **kwargs
            Subnet kwargs (e.g., attention_mask, mask) for the subnet layer.

        Returns
        -------
        Tensor
            The score tensor of the same shape as `xz` at the given `time` (or noise level),
            including any guidance terms.
        """
        subnet_kwargs = self._collect_mask_kwargs(self._SUBNET_MASK_KEYS, kwargs)

        if log_snr_t is None:
            log_snr_t = self.noise_schedule.get_log_snr(t=time, training=training)
            log_snr_t = expand_right_as(log_snr_t, xz)
            log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
        if time is None:
            time = self.noise_schedule.get_t_from_log_snr(log_snr_t, training=training)

        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
        norm_log_snr = self._transform_log_snr(log_snr_t)

        if self._prediction_type == "potential":
            pred = self._compute_grad_of_potential(
                xz=xz, norm_log_snr=norm_log_snr, conditions=conditions, training=training, **subnet_kwargs
            )
            x_pred = self.convert_prediction_to_x(
                pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, prediction_type="velocity"
            )
        else:
            subnet_out = self.subnet((xz, norm_log_snr, conditions), training=training, **subnet_kwargs)
            pred = self.output_projector(subnet_out)
            x_pred = self.convert_prediction_to_x(
                pred=pred,
                z=xz,
                alpha_t=alpha_t,
                sigma_t=sigma_t,
                log_snr_t=log_snr_t,
                prediction_type=self._prediction_type,
            )

        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)

        if guidance_constraints is not None:
            guidance = self.guidance_constraint_term(x=x_pred, time=time, **guidance_constraints)
            score = score + guidance
        if guidance_function is not None:
            guidance = guidance_function(x=x_pred, time=time)
            score = score + guidance

        return score
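# --- Hypothetical guidance_function matching the signature documented in score():
# it receives the denoised prediction x_pred and the current time and returns an
# additive score correction (here, a mild shrinkage of x_pred towards zero). It would
# be passed as, e.g., score(..., guidance_function=shrink_towards_zero); the name and
# the 0.1 factor are illustrative only.
def shrink_towards_zero(x, time):
    return -0.1 * x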
[docs] def velocity( self, xz: Tensor, time: float | Tensor, stochastic_solver: bool, conditions: Tensor = None, training: bool = False, **kwargs, ) -> Tensor: """ Computes the velocity (i.e., time derivative) of the target or latent variable `xz` for either a stochastic differential equation (SDE) or ordinary differential equation (ODE). Parameters ---------- xz : Tensor The current state of the latent variable `z`, typically of shape (..., D), where D is the dimensionality of the latent space. time : float or Tensor Scalar or tensor representing the time (or noise level) at which the velocity should be computed. Will be broadcasted to xz. stochastic_solver : bool If True, computes the velocity for the stochastic formulation (SDE). If False, uses the deterministic formulation (ODE). conditions : Tensor, optional Conditional inputs to the network, such as conditioning variables or encoder outputs. Shape must be broadcastable with `xz`. Default is None. training : bool, optional Whether the model is in training mode. Affects behavior of dropout, batch norm, or other stochastic layers. Default is False. Returns ------- Tensor The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`. """ log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training, **kwargs) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift(log_snr_t=log_snr_t, x=xz, training=training) if stochastic_solver: # for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW out = f - g_squared * score else: # for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt out = f - 0.5 * g_squared * score # Zero out velocity where target is fixed (during inference only) if not training: target_mask = kwargs.get("target_mask", None) out = maybe_mask_tensor(out, mask=target_mask) return out
[docs] def diffusion_term( self, xz: Tensor, time: float | Tensor, training: bool = False, **kwargs, ) -> Tensor: """ Compute the diffusion term (standard deviation of the noise) at a given time. Parameters ---------- xz : Tensor Input tensor of shape (..., D), typically representing the target or latent variables at given time. time : float or Tensor The diffusion time step(s). Can be a scalar or a tensor broadcastable to the shape of `xz`. training : bool, optional Whether to use the training noise schedule (default is False). Returns ------- Tensor The diffusion term tensor with shape matching `xz` except for the last dimension, which is set to 1. """ log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) g_squared = self.noise_schedule.get_drift(log_snr_t=log_snr_t) g = ops.sqrt(g_squared) # Zero out diffusion where target is fixed (during inference only) if not training: target_mask = kwargs.get("target_mask", None) g = maybe_mask_tensor(g, mask=target_mask) return g
def _velocity_trace( self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False, **kwargs, ) -> tuple[Tensor, Tensor]: def f(x): return self.velocity( x, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs ) v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True) return v, ops.expand_dims(trace, axis=-1) def _compute_grad_of_potential( self, xz: Tensor, norm_log_snr: Tensor, conditions: Tensor = None, training: bool = False, **subnet_kwargs, ) -> Tensor: """ Compute the gradient of the predicted potential phi w.r.t. xz. The network outputs a scalar potential phi(z, t, cond) per sample. The output is then defined as ∇_z phi, i.e. the gradient of the potential w.r.t. the noisy input. Parameters ---------- xz : Tensor The noisy latent variable z at which to evaluate the score. norm_log_snr : Tensor Normalised log-SNR fed to the subnet as the time embedding. conditions : Tensor, optional Conditioning variables passed to the subnet. Default is None. training : bool, optional Whether the subnet is in training mode. **subnet_kwargs Additional keyword arguments forwarded to the subnet (e.g. attention masks). Returns ------- Tensor Output of the same shape as *xz*, equal to some parameterization of ∇_xz phi(xz, t, cond). """ backend = keras.backend.backend() match backend: case "jax": import jax import jax.numpy as jnp def phi_fn(x): subnet_out = self.subnet( (x, norm_log_snr, conditions), training=training, **subnet_kwargs, ) return self.output_projector(subnet_out) phi, vjp_fn = jax.vjp(phi_fn, xz) out = vjp_fn(jnp.ones_like(phi))[0] return out case "tensorflow": import tensorflow as tf with tf.GradientTape() as tape: tape.watch(xz) subnet_out = self.subnet( (xz, norm_log_snr, conditions), training=training, **subnet_kwargs, ) phi = self.output_projector(subnet_out) out = tape.gradient( target=phi, sources=xz, output_gradients=tf.ones_like(phi), ) return out case "torch": import torch with torch.enable_grad(): xz_leaf = xz.requires_grad_(True) subnet_out = self.subnet( (xz_leaf, norm_log_snr, conditions), training=training, **subnet_kwargs, ) phi = self.output_projector(subnet_out) out = torch.autograd.grad( outputs=phi, inputs=xz_leaf, grad_outputs=torch.ones_like(phi), create_graph=training, retain_graph=training, )[0] return out case _: raise NotImplementedError(f"Unsupported backend for potential prediction type: {backend}") def _transform_log_snr(self, log_snr: Tensor) -> Tensor: """Transform the log_snr to the range [-1, 1] for the diffusion process.""" log_snr_min = self.noise_schedule.log_snr_min log_snr_max = self.noise_schedule.log_snr_max normalized_snr = (log_snr - log_snr_min) / (log_snr_max - log_snr_min) scaled_value = 2 * normalized_snr - 1 return scaled_value def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: # Note: integrators will cherry-pick necessary kwargs, so # we can be general (i.e., sloppy) here integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} integrate_kwargs |= self.integrate_kwargs integrate_kwargs |= kwargs if integrate_kwargs["method"] in STOCHASTIC_METHODS: logging.warning( "Stochastic methods are not supported for density evaluation." " Falling back to tsit5 ODE solver." " To suppress this warning, explicitly pass a method from " + str(DETERMINISTIC_METHODS) + "." 
) integrate_kwargs["method"] = "tsit5" # Apply user-provided target mask if available target_mask = kwargs.get("target_mask", None) targets_fixed = kwargs.get("targets_fixed", None) if target_mask is not None: target_mask = keras.ops.broadcast_to(target_mask, keras.ops.shape(x)) targets_fixed = keras.ops.broadcast_to(targets_fixed, keras.ops.shape(x)) x = maybe_mask_tensor(x, target_mask, replacement=targets_fixed) if self.unconditional_mode and conditions is not None: conditions = keras.ops.zeros_like(conditions) logging.info("Condition masking is applied: conditions are set to zero.") if density: def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training, **kwargs) return {"xz": v, "trace": trace} state = { "xz": x, "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)), } state = integrate(deltas, state, **integrate_kwargs) z = state["xz"] log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1) return z, log_density def deltas(time, xz): return { "xz": self.velocity( xz, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs ) } state = {"xz": x} state = integrate(deltas, state, **integrate_kwargs) z = state["xz"] return z def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: seed = resolve_seed(kwargs.pop("seed", None)) or self.seed_generator # Build integrate kwargs: hardcoded defaults -> instance config -> call-time overrides integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs |= self.integrate_kwargs integrate_kwargs |= kwargs # Apply user-provided target mask if available target_mask = kwargs.get("target_mask", None) targets_fixed = kwargs.get("targets_fixed", None) if target_mask is not None: target_mask = keras.ops.broadcast_to(target_mask, keras.ops.shape(z)) targets_fixed = keras.ops.broadcast_to(targets_fixed, keras.ops.shape(z)) z = maybe_mask_tensor(z, target_mask, replacement=targets_fixed) if self.unconditional_mode and conditions is not None: conditions = keras.ops.zeros_like(conditions) logging.info("Condition masking is applied: conditions are set to zero.") if density: if integrate_kwargs["method"] in STOCHASTIC_METHODS: logging.warning( "Stochastic methods are not supported for density computation." " Falling back to ODE solver." " Use one of the deterministic methods: " + str(DETERMINISTIC_METHODS) + "." 
) integrate_kwargs["method"] = "tsit5" def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training, **kwargs) return {"xz": v, "trace": trace} state = { "xz": z, "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), } state = integrate(deltas, state, **integrate_kwargs) x = state["xz"] log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1) return x, log_density state = {"xz": z} if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { "xz": self.velocity( xz, time=time, stochastic_solver=True, conditions=conditions, training=training, **kwargs ) } def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training, **kwargs)} score_fn = None if "corrector_steps" in integrate_kwargs or integrate_kwargs["method"] == "langevin": def score_fn(time, xz): return {"xz": self.score(xz, time=time, conditions=conditions, training=training, **kwargs)} state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, score_fn=score_fn, noise_schedule=self.noise_schedule, state=state, seed=seed, **integrate_kwargs, ) else: def deltas(time, xz): return { "xz": self.velocity( xz, time=time, stochastic_solver=False, conditions=conditions, training=training, **kwargs ) } state = integrate(deltas, state, **integrate_kwargs) x = state["xz"] return x
[docs] def compositional_bridge(self, time: Tensor) -> Tensor: """Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 = d1 = 1. Otherwise, it can be used to stabilize the compositional score over time. Parameters ---------- time: Tensor Time step for the diffusion process. Returns ------- Tensor Bridge function value with same shape as time. """ return self.compositional_bridge_d0 * ops.exp( -ops.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time )
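# --- Standalone check of the bridge above: b(t) = d0 * exp(-log(d0 / d1) * t) is a
# geometric interpolation, b(t) = d0**(1 - t) * d1**t, so b(0) = d0 and b(1) = d1.
# The values of d0 and d1 below are arbitrary, for illustration only.
import math

d0, d1 = 2.0, 0.5
for t in (0.0, 0.5, 1.0):
    b = d0 * math.exp(-math.log(d0 / d1) * t)
    assert abs(b - d0 ** (1 - t) * d1**t) < 1e-12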
[docs] def compositional_velocity( self, xz: Tensor, time: float | Tensor, stochastic_solver: bool, conditions: Tensor, seed: keras.random.SeedGenerator, compute_prior_score: Callable[[Tensor], Tensor] = None, mini_batch_size: int | None = None, training: bool = False, **kwargs, ) -> Tensor: """ Computes the compositional velocity for multiple datasets using the formula: s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) Parameters ---------- xz : Tensor The current state of the latent variable, shape (num_datasets, num_items, ...) time : float or Tensor Time step for the diffusion process stochastic_solver : bool Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (num_datasets, num_items, ...) seed: SeedGenerator For reproducibility. compute_prior_score: Callable, optional Function to compute the prior score ∇_θ log p(θ). Otherwise, the unconditional score is used. mini_batch_size : int or None Mini batch size for computing individual scores. If None, use all conditions. training : bool, optional Whether in training mode **kwargs Additional keyword arguments passed to the individual score computation. Returns ------- Tensor Compositional velocity of same shape as input xz """ compositional_score = self.compositional_score( xz=xz, time=time, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) # Compute velocity using standard drift-diffusion formulation f, g_squared = self.noise_schedule.get_drift(log_snr_t=log_snr_t, x=xz, training=training) if stochastic_solver: # for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW out = f - g_squared * compositional_score else: # for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt out = f - 0.5 * g_squared * compositional_score # Zero out velocity where target is fixed (during inference only) if not training: target_mask = kwargs.get("target_mask", None) out = maybe_mask_tensor(out, mask=target_mask) return out
[docs] def compositional_score( self, xz: Tensor, time: float | Tensor, conditions: Tensor, seed: keras.random.SeedGenerator | None, compute_prior_score: Callable[[Tensor], Tensor] = None, mini_batch_size: int | None = None, training: bool = False, clip: tuple[float, float] | None = (-3, 3), use_jac: bool = False, guidance_constraints: Mapping[str, Any] = None, guidance_function: Callable[[Tensor, Tensor], Tensor] = None, **kwargs, ) -> Tensor: """ Computes the compositional score for multiple datasets. Parameters ---------- xz : Tensor The current state of the latent variable, shape (num_datasets, num_items, ...) time : float or Tensor Time step for the diffusion process conditions : Tensor Conditional inputs with compositional structure (num_datasets, num_items, ...) seed: keras.random.SeedGenerator or None Optional seed for reproducibility. compute_prior_score: Callable, optional Function to compute the prior score ∇_θ log p(θ). Otherwise, the unconditional score is used. mini_batch_size : int or None Mini batch size for computing individual scores. If None, use all conditions. training : bool, optional Whether in training mode clip: (float, float), optional Whether to clip the predicted x for numerical stability at given values. use_jac: bool, optional Whether to use the Jacobian-based compositional score instead of the direct sum. guidance_constraints : dict[str, Any], optional A dictionary of parameters for computing a guidance constraint term, which is added to the score for guided sampling. The specific keys and values depend on the implementation of `guidance_constraint_term`. guidance_function : Callable[[Tensor, Tensor], Tensor], optional A custom function for computing a guidance term, which is added to the score for guided sampling. The function should accept the predicted clean signal `x_pred` and the current time `time` as inputs and return a tensor of the same shape as `xz`. 
**kwargs Additional keyword arguments passed to the individual score computation Returns ------- Tensor Compositional velocity of same shape as input xz """ if conditions is None: raise ValueError("Conditions are required for compositional sampling") log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) time = ops.cast(time, dtype=ops.dtype(xz)) if not use_jac: compositional_score = self._compositional_score_direct( xz=xz, time=time, log_snr_t=log_snr_t, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) else: # uses a Gaussian approximation, can be very slow compositional_score = self._compositional_score_jac( xz=xz, time=time, log_snr_t=log_snr_t, sigma_t=sigma_t, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) compositional_score = self.compositional_bridge(time) * compositional_score if guidance_constraints is not None or guidance_function is not None: # x_pred = (z + sigma_t ** 2 * score) / alpha_t x_pred = (xz + sigma_t**2 * compositional_score) / alpha_t if guidance_constraints is not None: guidance = self.guidance_constraint_term(x=x_pred, time=time, **guidance_constraints) compositional_score = compositional_score + guidance if guidance_function is not None: guidance = guidance_function(x=x_pred, time=time) compositional_score = compositional_score + guidance compositional_score = self._maybe_clip_score(compositional_score, clip, alpha_t, sigma_t, xz) return compositional_score
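# --- Example of a compute_prior_score callback with the documented signature: for a
# standard normal prior p(theta) = N(0, I) the score is simply -theta. Whether an
# additional (1 - t) weight is applied by the caller or internally depends on whether
# the callback accepts a time argument (see the note in _compositional_score_direct);
# this sketch accepts the time argument and leaves any weighting to the caller.
def standard_normal_prior_score(theta, time=None):
    return -theta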
def _compositional_score_direct( self, xz: Tensor, time: float | Tensor, log_snr_t: Tensor, conditions: Tensor, seed: keras.random.SeedGenerator | None = None, compute_prior_score: Callable[[Tensor], Tensor] = None, mini_batch_size: int | None = None, training: bool = False, **kwargs, ) -> Tensor: """ Computes the compositional score for multiple datasets using the formula: s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) with possible weighting of the scores. Parameters ---------- xz : Tensor The current state of the latent variable, shape (num_datasets, num_items, ...) time : float or Tensor Time step for the diffusion process. log_snr_t : Tensor Log SNR at time t, broadcastable to shape of xz. conditions : Tensor Conditional inputs with compositional structure (num_datasets, num_items, ...) seed: keras.random.SeedGenerator or None Optional seed for reproducibility. compute_prior_score: Callable, optional Function to compute the prior score ∇_θ log p(θ). Otherwise, the unconditional score is estimated. mini_batch_size : int or None Mini batch size for computing individual scores. If None, use all conditions. training : bool, optional Whether in training mode. **kwargs Additional keyword arguments passed to the individual score computation. Returns ------- Tensor Compositional score of same shape as input xz """ batch_size, num_items = ops.shape(conditions)[:2] # Sample item indices for mini-batching or keep all items if mini_batch_size is not None and mini_batch_size < num_items: ranks = keras.random.uniform((batch_size, num_items), seed=seed) per_row_idx = ops.top_k(-ranks, mini_batch_size).indices conditions_batch = ops.take_along_axis(conditions, per_row_idx[..., None], axis=1) else: conditions_batch = conditions mini_batch_size = num_items # Determine scale of summed posterior score needs_network_prior = compute_prior_score is None if needs_network_prior: zero_cond = ops.zeros_like(ops.take(conditions, 0, axis=1)) cond_with_prior = ops.concatenate([conditions_batch, ops.expand_dims(zero_cond, 1)], axis=1) num_total = mini_batch_size + 1 else: cond_with_prior = conditions_batch num_total = mini_batch_size scale = num_items / mini_batch_size # Expand and flatten compositional dimension (i.e., num items) for score computation dims = tuple(ops.shape(xz)[1:]) snr_dims = tuple(ops.shape(log_snr_t)[1:]) conditions_dims = tuple(ops.shape(cond_with_prior)[2:]) xz_reshaped = ops.reshape( ops.repeat(ops.expand_dims(xz, 1), num_total, axis=1), (batch_size * num_total,) + dims ) log_snr_reshaped = ops.reshape( ops.repeat(ops.expand_dims(log_snr_t, 1), num_total, axis=1), (batch_size * num_total,) + snr_dims, ) conditions_flat = ops.reshape(cond_with_prior, (batch_size * num_total,) + conditions_dims) scores_flat = self.score( xz_reshaped, log_snr_t=log_snr_reshaped, conditions=conditions_flat, training=training, **kwargs, ) all_scores = ops.reshape(scores_flat, (batch_size, num_total) + dims) individual_scores = all_scores[:, :mini_batch_size] if needs_network_prior: prior_score = all_scores[:, -1] else: # internally uses a (1-time) weight if prior score has no time argument prior_score = compute_prior_score(xz, time) # Combined score using compositional formula: (1-n) prior_score + Σᵢ₌₁ⁿ posterior_score delta = individual_scores - ops.expand_dims(prior_score, axis=1) update_delta = scale * ops.sum(delta, axis=1) compositional_score = prior_score + update_delta return compositional_score def _compositional_score_jac( self, xz: Tensor, time: float | Tensor, log_snr_t: Tensor, sigma_t: 
Tensor, conditions: Tensor, seed: keras.random.SeedGenerator | None = None, compute_prior_score: Callable[[Tensor], Tensor] | None = None, mini_batch_size: int | None = None, training: bool = False, regularize_precision: float = 1e-6, **kwargs, ) -> Tensor: """Stabilized version of JAC compositional score (Linhart et al. 2026) with mini-batching from (Arruda et al. 2026). Precision-weighted compositional posterior score where both individual posterior precisions and the prior precision are estimated via the Jacobian of the score network. Parameters ---------- xz : Tensor The current state of the latent variable, shape (num_datasets, num_items, ...) time : float or Tensor Time step for the diffusion process. log_snr_t : Tensor Log SNR at time t, broadcastable to shape of xz. sigma_t : Tensor Sigma component of noise schedule at time t, broadcastable to shape of xz. conditions : Tensor Conditional inputs with compositional structure (num_datasets, num_items, ...) seed: SeedGenerator or None Optional seed for reproducibility. compute_prior_score: Callable, optional Function to compute the prior score ∇_θ log p(θ). Otherwise, the unconditional score is used. mini_batch_size : int or None, optional Mini batch size for computing individual scores. If None, use all conditions. training : bool Whether in training mode. regularize_precision : float Tikhonov regularization added to Λ before solving for numerical stability. **kwargs Additional keyword arguments passed to the individual score computation. Returns ------- Tensor Compositional score of same shape as input xz """ batch_size, num_items = ops.shape(conditions)[:2] m = ops.shape(xz)[-1] upsilon_t = ops.maximum(ops.reshape(sigma_t, (-1,))[0] ** 2, 1e-8) if mini_batch_size is not None and mini_batch_size < num_items: ranks = keras.random.uniform((batch_size, num_items), seed=seed) per_row_idx = ops.top_k(-ranks, mini_batch_size).indices conditions_batch = ops.take_along_axis(conditions, per_row_idx[..., None], axis=1) else: conditions_batch = conditions mini_batch_size = num_items # Append prior as extra observation (zero-conditioned) zero_cond = ops.zeros_like(ops.take(conditions, 0, axis=1)) cond_with_prior = ops.concatenate([conditions_batch, ops.expand_dims(zero_cond, 1)], axis=1) num_total = mini_batch_size + 1 dims = tuple(ops.shape(xz)[1:]) snr_dims = tuple(ops.shape(log_snr_t)[1:]) cond_dims = tuple(ops.shape(cond_with_prior)[2:]) # Repeat xz and log_snr_t for each observation+prior xz_rep = ops.reshape( ops.repeat(ops.expand_dims(xz, 1), num_total, axis=1), (batch_size * num_total,) + dims, ) lsnr_rep = ops.reshape( ops.repeat(ops.expand_dims(log_snr_t, 1), num_total, axis=1), (batch_size * num_total,) + snr_dims, ) cond_flat = ops.reshape( cond_with_prior, (batch_size * num_total,) + cond_dims, ) all_scores, all_jacs = self._compute_score_and_jacobian( xz=xz_rep, log_snr_t=lsnr_rep, conditions=cond_flat, training=training, **kwargs, ) # Reshape: (B, num_total, m) and (B, num_total, m, m) all_scores = ops.reshape(all_scores, (batch_size, num_total, m)) all_jacs = ops.reshape(all_jacs, (batch_size, num_total, m, m)) # compute P_k = (I + υ J_k)⁻¹ for all k at once I_m = ops.eye(m, dtype=ops.dtype(xz)) I_4d = ops.broadcast_to( ops.reshape(I_m, (1, 1, m, m)), (batch_size, num_total, m, m), ) A_all = I_4d + upsilon_t * all_jacs P_all = ops.inv(A_all) # P_k @ s_k for all k: (B, num_total, m, 1) → (B, num_total, m) scores_col = ops.expand_dims(all_scores, -1) Ps_all = ops.squeeze(ops.matmul(P_all, scores_col), axis=-1) # (B, num_total, m) # 
split observations and prior P_obs = P_all[:, :mini_batch_size] # (B, n_obs, m, m) Ps_obs = Ps_all[:, :mini_batch_size] # (B, n_obs, m) P_lambda = P_all[:, -1] # (B, m, m) # Sum over observations sum_P = ops.sum(P_obs, axis=1) # (B, m, m) sum_Ps = ops.sum(Ps_obs, axis=1) # (B, m) # Scale for mini-batch estimator if mini_batch_size < num_items: scale = num_items / mini_batch_size sum_P = scale * sum_P sum_Ps = scale * sum_Ps # prior score if compute_prior_score is not None: s_lambda = compute_prior_score(xz, time) P_lambda_s = ops.squeeze(ops.matmul(P_lambda, ops.expand_dims(s_lambda, -1)), axis=-1) else: P_lambda_s = Ps_all[:, -1] # already computed # solve: Λ x = s̃ w_prior = 1 - num_items I_m_batch = ops.broadcast_to(I_m, (batch_size, m, m)) lambda_matrix = sum_P + w_prior * P_lambda + regularize_precision * I_m_batch s_tilde = sum_Ps + w_prior * P_lambda_s compositional_score = linsolve_batched(lambda_matrix, s_tilde) return compositional_score def _compute_score_and_jacobian( self, xz: Tensor, log_snr_t: Tensor, conditions: Tensor, training: bool = False, **kwargs, ) -> tuple[Tensor, Tensor]: """Compute score and JAC precision Σ̂⁻¹ for one conditioning input using full Jacobian J = ∇_{θ_t} s(θ_t). Returns ------- s : Tensor (B, m) prec : Tensor (B, m, m) Σ̂⁻¹ = (α_t/υ_t)(I + υ_t J)⁻¹ """ m = ops.shape(xz)[-1] backend = keras.backend.backend() match backend: case "jax": import jax import jax.numpy as jnp def phi_fn(x): return self.score( xz=x, log_snr_t=log_snr_t, conditions=conditions, training=training, **kwargs, ) score, vjp_fn = jax.vjp(phi_fn, xz) rows = [] for i in range(m): e_i = jnp.zeros_like(score).at[:, i].set(1.0) (row_i,) = vjp_fn(e_i) rows.append(jnp.expand_dims(row_i, axis=1)) jacobian = jnp.concatenate(rows, axis=1) case "tensorflow": import tensorflow as tf with tf.GradientTape(persistent=True) as tape: tape.watch(xz) score = self.score( xz=xz, log_snr_t=log_snr_t, conditions=conditions, training=training, **kwargs, ) if m <= 32: # for larger dimension rows = [] for i in range(m): e_i = tf.one_hot(tf.fill([tf.shape(score)[0]], i), depth=m, dtype=score.dtype) row_i = tape.gradient(score, xz, output_gradients=e_i) rows.append(tf.expand_dims(row_i, axis=1)) jacobian = tf.concat(rows, axis=1) else: jacobian = tape.batch_jacobian(score, xz, experimental_use_pfor=False) del tape case "torch": import torch with torch.enable_grad(): xz_leaf = xz.detach().requires_grad_(True) score = self.score( xz=xz_leaf, log_snr_t=log_snr_t, conditions=conditions, training=training, **kwargs, ) rows = [] for i in range(m): e_i = torch.zeros_like(score) e_i[:, i] = 1.0 (row_i,) = torch.autograd.grad( outputs=score, inputs=xz_leaf, grad_outputs=e_i, create_graph=False, retain_graph=(i < m - 1), ) rows.append(row_i.unsqueeze(1)) jacobian = torch.cat(rows, dim=1) score = score.detach() case _: raise NotImplementedError(f"JAC Jacobian not implemented for backend '{backend}'. ") return score, jacobian @staticmethod def _maybe_clip_score(compositional_score, clip, alpha_t, sigma_t, xz): if clip is not None: min_clip, max_clip = clip x_pred = (xz + sigma_t**2 * compositional_score) / alpha_t x_pred = ops.clip(x_pred, min_clip, max_clip) compositional_score = (x_pred * alpha_t - xz) / (sigma_t**2) return compositional_score def _inverse_compositional( self, z: Tensor, conditions: Tensor, compute_prior_score: Callable[[Tensor], Tensor] = None, density: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: """ Inverse pass for compositional diffusion sampling. 
""" seed = resolve_seed(kwargs.pop("seed", None)) or self.seed_generator integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs |= self.integrate_kwargs integrate_kwargs |= kwargs num_items = ops.shape(conditions)[1] mini_batch_size = int(integrate_kwargs.pop("mini_batch_size", max(num_items * 0.1, 2))) if "mini_batch_size" in kwargs: kwargs.pop("mini_batch_size") if mini_batch_size is None: mini_batch_size = num_items mini_batch_size = max(mini_batch_size, 1) if keras.backend.backend() == "jax" and mini_batch_size != num_items: mini_batch_size = num_items logging.warning("Setting mini_batch_size to num_items as jax does not support mini-batching.") self.compositional_bridge_d0 = float( integrate_kwargs.pop("compositional_bridge_d0", self.compositional_bridge_d0) ) self.compositional_bridge_d1 = float( integrate_kwargs.pop("compositional_bridge_d1", self.compositional_bridge_d1) ) if integrate_kwargs["method"] == "langevin": # Geffner et al. (2023) z_scaling = num_items * self.compositional_bridge_d1 z = z / ops.sqrt(ops.cast(z_scaling, dtype=ops.dtype(z))) # Apply user-provided target mask if available target_mask = kwargs.get("target_mask", None) targets_fixed = kwargs.get("targets_fixed", None) if target_mask is not None: target_mask = keras.ops.broadcast_to(target_mask, keras.ops.shape(z)) targets_fixed = keras.ops.broadcast_to(targets_fixed, keras.ops.shape(z)) z = maybe_mask_tensor(z, target_mask, replacement=targets_fixed) if density: if integrate_kwargs["method"] in STOCHASTIC_METHODS: logging.warning( "Stochastic methods are not supported for density computation." " Falling back to ODE solver." " Use one of the deterministic methods: " + str(DETERMINISTIC_METHODS) + "." ) integrate_kwargs["method"] = "tsit5" def deltas(time, xz): v = self.compositional_velocity( xz, time=time, stochastic_solver=False, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) return {"xz": v, "trace": trace} state = { "xz": z, "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), } state = integrate(deltas, state, **integrate_kwargs) x = state["xz"] log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1) return x, log_density state = {"xz": z} if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { "xz": self.compositional_velocity( xz, time=time, stochastic_solver=True, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) } def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training, **kwargs)} score_fn = None if "corrector_steps" in integrate_kwargs or integrate_kwargs["method"] == "langevin": def score_fn(time, xz): return { "xz": self.compositional_score( xz=xz, time=time, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) } state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, score_fn=score_fn, noise_schedule=self.noise_schedule, state=state, seed=seed, **integrate_kwargs, ) else: def deltas(time, xz): return { "xz": self.compositional_velocity( xz, time=time, stochastic_solver=False, conditions=conditions, compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, seed=seed, **kwargs, ) } state = 
integrate(deltas, state, **integrate_kwargs) x = state["xz"] return x
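# --- Hypothetical call-time configuration for _inverse_compositional above. The keys
# mini_batch_size and compositional_bridge_d0/d1 are popped by that method, and
# "langevin" is the stochastic sampler it special-cases (Geffner et al., 2023). How
# these kwargs reach the network (e.g. through an approximator's sample call) depends
# on the surrounding API and is not shown here.
compositional_sampling_kwargs = {
    "method": "langevin",
    "mini_batch_size": 8,
    "compositional_bridge_d0": 1.0,
    "compositional_bridge_d1": 1.0,
}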