Source code for bayesflow.experimental.diffusion_model.schedules.edm_noise_schedule

import math

from keras import ops

from bayesflow.types import Tensor
from bayesflow.utils.serialization import deserialize, serializable

from .noise_schedule import NoiseSchedule


# disable module check; use the proper module path once this class moves out of experimental
@serializable("bayesflow.networks", disable_module_check=True)
class EDMNoiseSchedule(NoiseSchedule):
    """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
    It should be used with the F-prediction type in the diffusion model.

    [1] Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the design space of
        diffusion-based generative models. Advances in Neural Information Processing Systems, 35, 26565-26577.
    """

    def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0):
        """
        Initialize the EDM noise schedule.

        Parameters
        ----------
        sigma_data : float, optional
            The standard deviation of the output distribution. The input of the network is scaled by this
            factor, and the weighting function is scaled by it as well. Default is 1.0.
        sigma_min : float, optional
            The minimum noise level. Only relevant for sampling. Default is 1e-4.
        sigma_max : float, optional
            The maximum noise level. Only relevant for sampling. Default is 80.0.
        """
        super().__init__(name="edm_noise_schedule", variance_type="preserving")
        self.sigma_data = sigma_data
        # training settings
        self.p_mean = -1.2
        self.p_std = 1.2
        # sampling settings
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = 7
        # convert the EDM parameters to the signal-to-noise ratio formulation
        self.log_snr_min = -2 * ops.log(sigma_max)
        self.log_snr_max = -2 * ops.log(sigma_min)
        # t is not truncated for EDM by definition of the sampling schedule,
        # but training bounds should be set to avoid numerical issues
        self._log_snr_min_training = self.log_snr_min - 1  # t = 1 is never sampled during training
        self._log_snr_max_training = self.log_snr_max + 1  # t = 0 is almost surely never sampled during training
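    # Worked example of the conversion above: with the defaults sigma_max = 80.0 and
    # sigma_min = 1e-4, log_snr = -2 * log(sigma) gives
    #   log_snr_min = -2 * log(80.0) ~ -8.76
    #   log_snr_max = -2 * log(1e-4) ~ 18.42
    # so log-SNR values are clipped to roughly [-9.76, 19.42] during training.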
    def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
        if training:
            # SNR = -dist.icdf(t_trunc)  # negative seems to be wrong in the Kingma paper
            loc = -2 * self.p_mean
            scale = 2 * self.p_std
            snr = loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2)
            snr = ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
        else:
            sigma_min_rho = self.sigma_min ** (1 / self.rho)
            sigma_max_rho = self.sigma_max ** (1 / self.rho)
            snr = -2 * self.rho * ops.log(sigma_max_rho + (1 - t) * (sigma_min_rho - sigma_max_rho))
        return snr
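    # In training mode, t ~ U(0, 1) is pushed through the inverse CDF of a normal
    # distribution with loc = -2 * p_mean = 2.4 and scale = 2 * p_std = 2.4, so
    # log_snr ~ N(2.4, 2.4^2). Under log_snr = -2 * log(sigma), this reproduces
    # EDM's log-normal noise sampling, log(sigma) ~ N(p_mean, p_std^2).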
    def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
        if training:
            # SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)  # negative seems to be wrong in the Kingma paper
            loc = -2 * self.p_mean
            scale = 2 * self.p_std
            x = log_snr_t
            t = 0.5 * (1 + ops.erf((x - loc) / (scale * math.sqrt(2.0))))
        else:  # sampling
            # SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
            # => t = 1 - ((exp(-snr / (2 * rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
            sigma_min_rho = self.sigma_min ** (1 / self.rho)
            sigma_max_rho = self.sigma_max ** (1 / self.rho)
            t = 1 - ((ops.exp(-log_snr_t / (2 * self.rho)) - sigma_max_rho) / (sigma_min_rho - sigma_max_rho))
        return t
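    # For a fixed mode, get_t_from_log_snr inverts get_log_snr: in sampling mode it
    # inverts the rho-schedule algebraically, and in training mode it applies the
    # normal CDF matching the inverse CDF above (exact only where the log-SNR was
    # not clipped to the training bounds).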
    def derivative_log_snr(self, log_snr_t: Tensor, training: bool = False) -> Tensor:
        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
        if training:
            raise NotImplementedError("Derivative of log SNR is not implemented for training mode.")

        # sampling mode
        t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)

        # SNR = -2 * rho * log(s_max + (1 - t) * (s_min - s_max))
        s_max = self.sigma_max ** (1 / self.rho)
        s_min = self.sigma_min ** (1 / self.rho)
        u = s_max + (1 - t) * (s_min - s_max)
        # d/dt snr = 2 * rho * (s_min - s_max) / u
        dsnr_dt = 2 * self.rho * (s_min - s_max) / u

        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
        # f'(t) = -(e^(-snr(t)) / (1 + e^(-snr(t)))) * dsnr_dt
        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
        return -factor * dsnr_dt
    def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
        """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
        # for F-loss: w = (exp(-log_snr_t) + sigma_data^2) / (exp(-log_snr_t) * sigma_data^2)
        return 1 + ops.exp(-log_snr_t) / ops.square(self.sigma_data)
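    # Equivalently, with sigma^2 = exp(-log_snr_t), the returned weight is
    # (sigma^2 + sigma_data^2) / sigma_data^2, so high-noise (low-SNR) samples
    # receive larger weights.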
    def get_config(self):
        config = {"sigma_data": self.sigma_data, "sigma_min": self.sigma_min, "sigma_max": self.sigma_max}
        return config
    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**deserialize(config, custom_objects=custom_objects))
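
# A minimal usage sketch, assuming the parent NoiseSchedule requires no further setup
# and that a Keras backend is configured; the values are illustrative, not reference outputs.
if __name__ == "__main__":
    schedule = EDMNoiseSchedule(sigma_data=0.5, sigma_min=1e-4, sigma_max=80.0)

    # Sampling mode: t = 0 corresponds to sigma_min (log_snr_max) and
    # t = 1 to sigma_max (log_snr_min).
    t = ops.convert_to_tensor([0.0, 0.5, 1.0])
    log_snr = schedule.get_log_snr(t, training=False)

    # Round trip: recover the diffusion times from the log-SNR values.
    t_back = schedule.get_t_from_log_snr(log_snr, training=False)

    # Per-sample loss weights for the F-prediction parameterization.
    weights = schedule.get_weights_for_snr(log_snr)

    print(log_snr, t_back, weights)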