Source code for bayesflow.experimental.diffusion_model.diffusion_model

from collections.abc import Sequence
from typing import Any, Literal

import keras
from keras import ops

from bayesflow.networks import InferenceNetwork
from bayesflow.types import Tensor, Shape
from bayesflow.utils import (
    expand_right_as,
    find_network,
    jacobian_trace,
    layer_kwargs,
    weighted_mean,
    integrate,
    integrate_stochastic,
    logging,
    tensor_utils,
)
from bayesflow.utils.serialization import serialize, deserialize, serializable

from .schedules.noise_schedule import NoiseSchedule
from .dispatch import find_noise_schedule


# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class DiffusionModel(InferenceNetwork):
    """Diffusion Model as described in this overview paper [1].

    [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
    Augmentation: Kingma et al. (2023)

    [2] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021)
    """

    # Defaults used when ``subnet == "mlp"``; user-supplied ``subnet_kwargs`` override individual keys.
    MLP_DEFAULT_CONFIG = {
        "widths": (256, 256, 256, 256, 256),
        "activation": "mish",
        "kernel_initializer": "he_normal",
        "residual": True,
        "dropout": 0.0,
        "spectral_normalization": False,
    }

    # Defaults for the ODE/SDE integrator; user-supplied ``integrate_kwargs`` override individual keys.
    INTEGRATE_DEFAULT_CONFIG = {
        "method": "euler",
        "steps": 100,
    }

    def __init__(
        self,
        *,
        subnet: str | type | keras.Layer = "mlp",
        noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm",
        prediction_type: Literal["velocity", "noise", "F", "x", "score"] = "F",
        loss_type: Literal["velocity", "noise", "F"] = "noise",
        subnet_kwargs: dict[str, Any] | None = None,
        schedule_kwargs: dict[str, Any] | None = None,
        integrate_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        """
        Initializes a diffusion model with configurable subnet architecture, noise schedule, and prediction/loss types
        for amortized Bayesian inference.

        Note, that score-based diffusion is the most sluggish of all available samplers, so expect slower inference
        times than flow matching and much slower than normalizing flows.

        Parameters
        ----------
        subnet : str, type or keras.Layer, optional
            Architecture for the transformation network. Can be "mlp", a custom network class, or a Layer object,
            e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp".
        noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional
            Noise schedule controlling the diffusion dynamics. Can be a string identifier, a schedule class, or a
            pre-initialized schedule instance. Default is "edm".
        prediction_type : {'velocity', 'noise', 'F', 'x', 'score'}, optional
            Output format of the model's prediction. Default is "F".
        loss_type : {'velocity', 'noise', 'F'}, optional
            Loss function used to train the model. Default is "noise".
        subnet_kwargs : dict[str, Any], optional
            Additional keyword arguments passed to the subnet constructor. Default is None.
        schedule_kwargs : dict[str, Any], optional
            Additional keyword arguments passed to the noise schedule constructor. Default is None.
        integrate_kwargs : dict[str, Any], optional
            Configuration dictionary for integration during training or inference. Default is None.
        **kwargs
            Additional keyword arguments passed to the base class and internal components.

        Raises
        ------
        ValueError
            If `prediction_type` or `loss_type` is not one of the supported values.
        """
        super().__init__(base_distribution="normal", **kwargs)

        # "score" is accepted because `convert_prediction_to_x` implements the matching conversion.
        if prediction_type not in ("noise", "velocity", "F", "x", "score"):
            raise ValueError(f"Unknown prediction type: {prediction_type}")

        if loss_type not in ("noise", "velocity", "F"):
            raise ValueError(f"Unknown loss type: {loss_type}")

        if loss_type != "noise":
            logging.warning(
                "The standard schedules have weighting functions defined for the noise prediction loss. "
                "You might want to replace them if you are using a different loss function."
            )

        self._prediction_type = prediction_type
        self._loss_type = loss_type

        schedule_kwargs = schedule_kwargs or {}
        self.noise_schedule = find_noise_schedule(noise_schedule, **schedule_kwargs)
        self.noise_schedule.validate()

        self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
        self.seed_generator = keras.random.SeedGenerator()

        subnet_kwargs = subnet_kwargs or {}
        if subnet == "mlp":
            subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
        self.subnet = find_network(subnet, **subnet_kwargs)

        # The number of units is only known once the target dimensionality is seen in `build`.
        self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

    def build(self, xz_shape: Shape, conditions_shape: Shape | None = None) -> None:
        """
        Build the base distribution, the subnet and the output projector.

        The subnet input is the concatenation (along the last axis) of the target variables, one
        extra feature for the transformed log-SNR, and, if present, the conditions.

        Parameters
        ----------
        xz_shape : Shape
            Shape of the target / latent variables.
        conditions_shape : Shape, optional
            Shape of the conditioning variables. Default is None.
        """
        if self.built:
            return

        self.base_distribution.build(xz_shape)

        self.output_projector.units = xz_shape[-1]
        input_shape = list(xz_shape)

        # one additional feature for the (transformed) log-SNR / time input
        input_shape[-1] += 1
        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]
        input_shape = tuple(input_shape)

        self.subnet.build(input_shape)
        out_shape = self.subnet.compute_output_shape(input_shape)
        self.output_projector.build(out_shape)

    def get_config(self):
        """Return the serializable configuration needed to re-create this model via `from_config`."""
        base_config = super().get_config()
        base_config = layer_kwargs(base_config)

        config = {
            "subnet": self.subnet,
            "noise_schedule": self.noise_schedule,
            "prediction_type": self._prediction_type,
            "loss_type": self._loss_type,
            "integrate_kwargs": self.integrate_kwargs,
        }
        return base_config | serialize(config)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Re-create a model from a configuration produced by `get_config`."""
        return cls(**deserialize(config, custom_objects=custom_objects))

    def convert_prediction_to_x(
        self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor
    ) -> Tensor:
        """
        Converts the neural network prediction into the denoised data `x`, depending on
        the prediction type configured for the model.

        Parameters
        ----------
        pred : Tensor
            The output prediction from the neural network, typically representing noise,
            velocity, or a transformation of the clean signal.
        z : Tensor
            The noisy latent variable `z` to be denoised.
        alpha_t : Tensor
            The noise schedule's scaling factor for the clean signal at time `t`.
        sigma_t : Tensor
            The standard deviation of the noise at time `t`.
        log_snr_t : Tensor
            The log signal-to-noise ratio at time `t`.

        Returns
        -------
        Tensor
            The reconstructed clean signal `x` from the model prediction.

        Raises
        ------
        ValueError
            If the configured prediction type is unknown.
        """
        if self._prediction_type == "velocity":
            return alpha_t * z - sigma_t * pred
        elif self._prediction_type == "noise":
            return (z - sigma_t * pred) / alpha_t
        elif self._prediction_type == "F":
            # schedules without a `sigma_data` attribute fall back to unit data scale
            sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0)
            x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
            x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
            return x1 * z + x2 * pred
        elif self._prediction_type == "x":
            return pred
        elif self._prediction_type == "score":
            # inverse of score = (alpha_t * x - z) / sigma_t**2, as used in `velocity`
            return (z + sigma_t**2 * pred) / alpha_t
        raise ValueError(f"Unknown prediction type {self._prediction_type}.")

    def _subnet_prediction(
        self,
        xz: Tensor,
        log_snr_t: Tensor,
        conditions: Tensor | None = None,
        training: bool = False,
    ) -> Tensor:
        """Concatenate state, transformed log-SNR and optional conditions, then apply subnet and projector."""
        inputs = [xz, self._transform_log_snr(log_snr_t)]
        if conditions is not None:
            inputs.append(conditions)
        xtc = tensor_utils.concatenate_valid(inputs, axis=-1)
        return self.output_projector(self.subnet(xtc, training=training), training=training)

    def velocity(
        self,
        xz: Tensor,
        time: float | Tensor,
        stochastic_solver: bool,
        conditions: Tensor | None = None,
        training: bool = False,
    ) -> Tensor:
        """
        Computes the velocity (i.e., time derivative) of the target or latent variable `xz` for either
        a stochastic differential equation (SDE) or ordinary differential equation (ODE).

        Parameters
        ----------
        xz : Tensor
            The current state of the latent variable `z`, typically of shape (..., D),
            where D is the dimensionality of the latent space.
        time : float or Tensor
            Scalar or tensor representing the time (or noise level) at which the velocity
            should be computed. Will be broadcasted to xz.
        stochastic_solver : bool
            If True, computes the velocity for the stochastic formulation (SDE).
            If False, uses the deterministic formulation (ODE).
        conditions : Tensor, optional
            Optional conditional inputs to the network, such as conditioning variables
            or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
        training : bool, optional
            Whether the model is in training mode. Affects behavior of dropout, batch norm,
            or other stochastic layers. Default is False.

        Returns
        -------
        Tensor
            The velocity tensor of the same shape as `xz`, representing the right-hand
            side of the SDE or ODE at the given `time`.
        """
        # calculate the current noise level and transform into correct shape
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

        pred = self._subnet_prediction(xz, log_snr_t, conditions=conditions, training=training)
        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)

        # compute velocity f, g of the SDE or ODE
        f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)

        if stochastic_solver:
            # for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW
            out = f - g_squared * score
        else:
            # for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt
            out = f - 0.5 * g_squared * score

        return out

    def diffusion_term(
        self,
        xz: Tensor,
        time: float | Tensor,
        training: bool = False,
    ) -> Tensor:
        """
        Compute the diffusion term (standard deviation of the noise) at a given time.

        Parameters
        ----------
        xz : Tensor
            Input tensor of shape (..., D), typically representing the target or latent variables at given time.
        time : float or Tensor
            The diffusion time step(s). Can be a scalar or a tensor broadcastable to the shape of `xz`.
        training : bool, optional
            Whether to use the training noise schedule (default is False).

        Returns
        -------
        Tensor
            The diffusion term tensor with shape matching `xz` except for the last dimension, which is set to 1.
        """
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
        # Forward `training` so that g(t) is taken from the same schedule as the drift in `velocity`.
        # Without `x`, the schedule is expected to return only g(t)^2.
        g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, training=training)
        return ops.sqrt(g_squared)

    def _velocity_trace(
        self,
        xz: Tensor,
        time: Tensor,
        conditions: Tensor | None = None,
        max_steps: int | None = None,
        training: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Return the ODE velocity at `xz` together with the trace of its Jacobian (trailing axis of size 1)."""

        def f(x):
            return self.velocity(x, time=time, stochastic_solver=False, conditions=conditions, training=training)

        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

        return v, ops.expand_dims(trace, axis=-1)

    def _transform_log_snr(self, log_snr: Tensor) -> Tensor:
        """Transform the log_snr to the range [-1, 1] for the diffusion process."""
        log_snr_min = self.noise_schedule.log_snr_min
        log_snr_max = self.noise_schedule.log_snr_max
        normalized_snr = (log_snr - log_snr_min) / (log_snr_max - log_snr_min)
        scaled_value = 2 * normalized_snr - 1
        return scaled_value

    def _forward(
        self,
        x: Tensor,
        conditions: Tensor | None = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Integrate the probability-flow ODE from time 0 to time 1, mapping `x` to the latent `z`.

        Extra keyword arguments override the stored integration settings. If `density` is True,
        the log-density of `x` is returned alongside `z`.

        Raises
        ------
        ValueError
            If the integration method is "euler_maruyama".
        """
        # precedence: call-site kwargs > instance configuration > default time interval
        integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} | self.integrate_kwargs | kwargs

        if integrate_kwargs["method"] == "euler_maruyama":
            raise ValueError("Stochastic methods are not supported for forward integration.")

        if density:

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            state = {
                "xz": x,
                "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)),
            }
            state = integrate(
                deltas,
                state,
                **integrate_kwargs,
            )

            z = state["xz"]
            log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1)

            return z, log_density

        def deltas(time, xz):
            return {
                "xz": self.velocity(xz, time=time, stochastic_solver=False, conditions=conditions, training=training)
            }

        state = {"xz": x}
        state = integrate(
            deltas,
            state,
            **integrate_kwargs,
        )

        z = state["xz"]

        return z

    def _inverse(
        self,
        z: Tensor,
        conditions: Tensor | None = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Integrate from time 1 to time 0, mapping the latent `z` to a sample `x`.

        Uses the SDE when the integration method is "euler_maruyama" and the ODE otherwise.
        Extra keyword arguments override the stored integration settings. If `density` is True,
        the log-density of `x` is returned alongside `x`.

        Raises
        ------
        ValueError
            If `density` is True and the integration method is "euler_maruyama".
        """
        # precedence: call-site kwargs > instance configuration > default time interval
        integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} | self.integrate_kwargs | kwargs

        if density:
            if integrate_kwargs["method"] == "euler_maruyama":
                raise ValueError("Stochastic methods are not supported for density computation.")

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            state = {
                "xz": z,
                "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)),
            }
            state = integrate(deltas, state, **integrate_kwargs)

            x = state["xz"]
            # the trace is accumulated backwards in time (1 -> 0), hence the opposite sign to `_forward`
            log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1)

            return x, log_density

        state = {"xz": z}

        if integrate_kwargs["method"] == "euler_maruyama":

            def deltas(time, xz):
                return {
                    "xz": self.velocity(xz, time=time, stochastic_solver=True, conditions=conditions, training=training)
                }

            def diffusion(time, xz):
                return {"xz": self.diffusion_term(xz, time=time, training=training)}

            state = integrate_stochastic(
                drift_fn=deltas,
                diffusion_fn=diffusion,
                state=state,
                seed=self.seed_generator,
                **integrate_kwargs,
            )
        else:

            def deltas(time, xz):
                return {
                    "xz": self.velocity(
                        xz, time=time, stochastic_solver=False, conditions=conditions, training=training
                    )
                }

            state = integrate(
                deltas,
                state,
                **integrate_kwargs,
            )

        x = state["xz"]
        return x

    def compute_metrics(
        self,
        x: Tensor | Sequence[Tensor],
        conditions: Tensor | None = None,
        sample_weight: Tensor | None = None,
        stage: str = "training",
    ) -> dict[str, Tensor]:
        """
        Compute the weighted denoising loss for a batch of targets.

        A diffusion time is drawn per batch element, `x` is noised accordingly, the network
        prediction is converted to the configured loss parameterization, and the squared error
        is weighted by the noise schedule.

        Parameters
        ----------
        x : Tensor
            The clean target variables; the first axis is treated as the batch axis.
        conditions : Tensor, optional
            Conditioning variables concatenated to the subnet input. Default is None.
        sample_weight : Tensor, optional
            Per-sample weights passed to `weighted_mean`. Default is None.
        stage : str, optional
            Current stage, e.g. "training" or "validation". Default is "training".

        Returns
        -------
        dict[str, Tensor]
            The metrics of the base class, extended by the entry "loss".

        Raises
        ------
        ValueError
            If the configured loss type is unknown.
        """
        training = stage == "training"
        # use same noise schedule for training and validation to keep them comparable
        noise_schedule_training_stage = stage == "training" or stage == "validation"
        if not self.built:
            xz_shape = ops.shape(x)
            conditions_shape = None if conditions is None else ops.shape(conditions)
            self.build(xz_shape, conditions_shape)

        # sample training diffusion time as low discrepancy sequence to decrease variance
        u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator)
        i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x))
        t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1

        # calculate the noise level
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage), x)
        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
        weights_for_snr = self.noise_schedule.get_weights_for_snr(log_snr_t=log_snr_t)

        # generate noise vector
        eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)

        # diffuse x
        diffused_x = alpha_t * x + sigma_t * eps_t

        # calculate output of the network
        pred = self._subnet_prediction(diffused_x, log_snr_t, conditions=conditions, training=training)

        x_pred = self.convert_prediction_to_x(
            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
        )

        if self._loss_type == "noise":
            # convert x to epsilon prediction
            noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t
            squared_error = (noise_pred - eps_t) ** 2
        elif self._loss_type == "velocity":
            # convert x to velocity prediction
            velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t
            v_t = alpha_t * eps_t - sigma_t * x
            squared_error = (velocity_pred - v_t) ** 2
        elif self._loss_type == "F":
            # convert x to F prediction
            sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0)
            x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)
            x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2))
            f_pred = x1 * x_pred - x2 * diffused_x
            f_t = x1 * x - x2 * diffused_x
            squared_error = (f_pred - f_t) ** 2
        else:
            raise ValueError(f"Unknown loss type: {self._loss_type}")

        # Apply the weight BEFORE reducing the feature axis. `log_snr_t` keeps trailing singleton
        # axes (it is broadcast against `x` above); multiplying a weight of that shape with an
        # already-reduced mean would broadcast (B, 1) * (B,) to (B, B) and mix weights across samples.
        # NOTE(review): assumes `get_weights_for_snr` returns a tensor shaped like `log_snr_t` - confirm.
        loss = ops.mean(weights_for_snr * squared_error, axis=-1)

        loss = weighted_mean(loss, sample_weight)

        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
        return base_metrics | {"loss": loss}