Source code for bayesflow.experimental.diffusion_model.schedules.noise_schedule

from abc import ABC, abstractmethod
from typing import Literal

from keras import ops

from bayesflow.types import Tensor
from bayesflow.utils.serialization import deserialize, serializable


# disable module check; use the proper module once this moves out of experimental
@serializable("bayesflow.networks", disable_module_check=True)
class NoiseSchedule(ABC):
    r"""Noise schedule for diffusion models. We follow the notation from [1].

    The diffusion process is defined by a noise schedule, which determines how the noise level
    changes over time. We define the noise schedule as a function of the log signal-to-noise
    ratio (lambda), which can be used interchangeably with the diffusion time (t).

    The noise process is defined as: z = alpha(t) * x + sigma(t) * e, where e ~ N(0, I).
    The schedule is defined as: \lambda(t) = \log \alpha^2(t) - \log \sigma^2(t).

    We can also define a weighting function for each noise level for the loss function.
    Often the noise schedule is the same for the forward and reverse process, but this is not
    necessary and can be changed via the training flag.

    [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO
        with Simple Data Augmentation, Kingma et al. (2023)
    """

    def __init__(
        self,
        name: str,
        variance_type: Literal["preserving", "exploding"],
        weighting: Literal["sigmoid", "likelihood_weighting"] | None = None,
    ):
        """
        Initialize the noise schedule with the given variance and weighting strategy.

        Parameters
        ----------
        name : str
            The name of the noise schedule.
        variance_type : Literal["preserving", "exploding"]
            Use "preserving" if the variance of the noise added to the data should be preserved
            over time, and "exploding" if it should increase over time.
        weighting : Literal["sigmoid", "likelihood_weighting"], optional
            The type of weighting function to use for the noise schedule.
            Default is None, which means no weighting is applied.
        """
        self.name = name
        self._variance_type = variance_type
        self.log_snr_min = None  # should be set in the subclasses
        self.log_snr_max = None  # should be set in the subclasses
        self._weighting = weighting
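
    # The forward (noising) process implied by a schedule, sketched here for orientation
    # (illustrative pseudocode only; not executed by this class):
    #   log_snr_t = schedule.get_log_snr(t, training=True)
    #   alpha_t, sigma_t = schedule.get_alpha_sigma(log_snr_t)
    #   z = alpha_t * x + sigma_t * eps        # with eps ~ N(0, I), as in the class docstring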

    @abstractmethod
    def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
        pass

    @abstractmethod
    def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
        pass

    @abstractmethod
    def derivative_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
        r"""Compute \beta(t) = d/dt \log(1 + e^{-\lambda(t)}).

        This is usually used for the reverse SDE.
        """
        pass
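
    # For implementers: by the chain rule,
    #   beta(t) = d/dt log(1 + e^{-lambda(t)}) = -sigmoid(-lambda(t)) * dlambda/dt,
    # so a subclass that knows dlambda/dt in closed form can implement this directly in terms
    # of the log-SNR value (see the illustrative sketch at the end of this file).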

    def get_drift_diffusion(
        self, log_snr_t: Tensor, x: Tensor = None, training: bool = False
    ) -> Tensor | tuple[Tensor, Tensor]:
        r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.

        Both can be derived from the derivative of the schedule:

        .. math::
            \beta(t) = \frac{d}{dt} \log(1 + e^{-\lambda(t)})

            f(z, t) = -\frac{1}{2} \beta(t) z

            g(t)^2 = \beta(t)

        The corresponding differential equations are::

            SDE: dz = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
            ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt

        For a variance exploding schedule, one should set f(z, t) = 0.
        """
        beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
        if x is None:  # return g^2 only
            return beta

        if self._variance_type == "preserving":
            f = -0.5 * beta * x
        elif self._variance_type == "exploding":
            f = ops.zeros_like(beta)
        else:
            raise ValueError(f"Unknown variance type: {self._variance_type}")
        return f, beta
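
    # Sketch of how a sampler might consume these terms (illustrative only; assumes an
    # external `score_fn` and a negative, backward-in-time step `dt`):
    #   f, g_squared = schedule.get_drift_diffusion(log_snr_t, z, training=False)
    #   drift = f - g_squared * score_fn(z, log_snr_t)
    #   z = z + drift * dt + ops.sqrt(g_squared * abs(dt)) * noise   # Euler-Maruyama step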

    def get_alpha_sigma(self, log_snr_t: Tensor) -> tuple[Tensor, Tensor]:
        """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

        Default is a variance preserving schedule:

            alpha(t) = sqrt(sigmoid(log_snr_t))
            sigma(t) = sqrt(sigmoid(-log_snr_t))

        For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda).
        """
        if self._variance_type == "preserving":
            # variance preserving schedule
            alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
            sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
        elif self._variance_type == "exploding":
            # variance exploding schedule
            alpha_t = ops.ones_like(log_snr_t)
            sigma_t = ops.sqrt(ops.exp(-log_snr_t))
        else:
            raise TypeError(f"Unknown variance type: {self._variance_type}")
        return alpha_t, sigma_t
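
    # Note: in the variance preserving case the identities
    #   alpha(t)^2 + sigma(t)^2 = 1    and    alpha(t)^2 / sigma(t)^2 = e^{lambda(t)}
    # follow directly from sigmoid(x) + sigmoid(-x) = 1 and sigmoid(x) / sigmoid(-x) = e^x.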

    def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
        """
        Compute loss weights based on the log signal-to-noise ratio (log-SNR).

        This method returns a tensor of weights used for loss re-weighting in diffusion models,
        depending on the selected strategy. If no weighting is specified, uniform weights (ones)
        are returned.

        Supported weighting strategies:

        - "sigmoid": Based on Kingma et al. (2023), uses a sigmoid of the shifted log-SNR.
        - "likelihood_weighting": Based on Song et al. (2021), uses the ratio of the squared
          diffusion term to the squared noise scale.

        Parameters
        ----------
        log_snr_t : Tensor
            A tensor containing the log signal-to-noise ratio values.

        Returns
        -------
        Tensor
            A tensor of weights corresponding to each log-SNR value.

        Raises
        ------
        TypeError
            If the weighting strategy specified in `self._weighting` is unknown.
        """
        if self._weighting is None:
            return ops.ones_like(log_snr_t)
        elif self._weighting == "sigmoid":
            # sigmoid weighting based on Kingma et al. (2023)
            return ops.sigmoid(-log_snr_t + 2)
        elif self._weighting == "likelihood_weighting":
            # likelihood weighting based on Song et al. (2021)
            g_squared = self.get_drift_diffusion(log_snr_t)
            _, sigma_t = self.get_alpha_sigma(log_snr_t)
            return g_squared / ops.square(sigma_t)
        else:
            raise TypeError(f"Unknown weighting type: {self._weighting}")
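
    # In Kingma et al. (2023), weights of this kind multiply the per-noise-level term of the
    # diffusion loss (e.g. a weighted noise-prediction MSE); how the returned tensor is applied
    # is left to the caller of this method.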

    def get_config(self):
        return {"name": self.name, "variance_type": self._variance_type, "weighting": self._weighting}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**deserialize(config, custom_objects=custom_objects))

    def validate(self):
        """Validate the noise schedule."""
        if self.log_snr_min >= self.log_snr_max:
            raise ValueError("log_snr_min must be less than log_snr_max.")

        # Validate log SNR values and corresponding time mappings for both training and inference
        for training in (True, False):
            if not ops.isfinite(self.get_log_snr(0.0, training=training)):
                raise ValueError(f"log_snr(0.0) must be finite (training={training})")
            if not ops.isfinite(self.get_log_snr(1.0, training=training)):
                raise ValueError(f"log_snr(1.0) must be finite (training={training})")
            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
                raise ValueError(f"t(log_snr_max) must be finite (training={training})")
            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
                raise ValueError(f"t(log_snr_min) must be finite (training={training})")

        # Validate log SNR derivatives at the boundaries
        for boundary, name in [(self.log_snr_max, "log_snr_max (t=0)"), (self.log_snr_min, "log_snr_min (t=1)")]:
            if not ops.isfinite(self.derivative_log_snr(boundary, training=False)):
                raise ValueError(f"derivative_log_snr at {name} must be finite.")
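

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): a minimal concrete schedule.
# The class name `ExampleLinearNoiseSchedule` and its default bounds are
# hypothetical; the sketch only shows how the abstract interface above is
# typically filled in. The library ships its own schedules elsewhere in this
# subpackage.
# ---------------------------------------------------------------------------
class ExampleLinearNoiseSchedule(NoiseSchedule):
    """Log-SNR decreases linearly from log_snr_max (at t=0) to log_snr_min (at t=1)."""

    def __init__(self, log_snr_min: float = -10.0, log_snr_max: float = 10.0):
        super().__init__(name="example_linear_noise_schedule", variance_type="preserving")
        self.log_snr_min = log_snr_min
        self.log_snr_max = log_snr_max

    def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
        return self.log_snr_max + (self.log_snr_min - self.log_snr_max) * t

    def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
        return (log_snr_t - self.log_snr_max) / (self.log_snr_min - self.log_snr_max)

    def derivative_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
        # beta(t) = -sigmoid(-lambda(t)) * dlambda/dt, with a constant dlambda/dt here
        dlambda_dt = self.log_snr_min - self.log_snr_max
        return -ops.sigmoid(-log_snr_t) * dlambda_dt


# Example usage of the sketch above:
#   schedule = ExampleLinearNoiseSchedule()
#   schedule.validate()
#   log_snr_t = schedule.get_log_snr(0.5, training=True)
#   alpha_t, sigma_t = schedule.get_alpha_sigma(log_snr_t)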