Source code for bayesflow.experimental.diffusion_model.schedules.cosine_noise_schedule

import math
from typing import Literal

from keras import ops

from bayesflow.types import Tensor
from bayesflow.utils.serialization import deserialize, serializable

from .noise_schedule import NoiseSchedule


# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class CosineNoiseSchedule(NoiseSchedule):
    """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].

    A cosine schedule is a popular technique for controlling how the variance (noise level) or learning rate
    evolves during the training of diffusion models. It was proposed as an improvement over the original
    linear beta schedule in [2].

    [1] Dhariwal, P., & Nichol, A. (2021). Diffusion models beat GANs on image synthesis.
        Advances in Neural Information Processing Systems, 34, 8780-8794.
    [2] Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models.
        Advances in Neural Information Processing Systems, 33, 6840-6851.
    """

    def __init__(
        self,
        min_log_snr: float = -15,
        max_log_snr: float = 15,
        shift: float = 0.0,
        weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
    ):
        """
        Initialize the cosine noise schedule.

        Parameters
        ----------
        min_log_snr : float, optional
            The minimum log signal-to-noise ratio (lambda). Default is -15.
        max_log_snr : float, optional
            The maximum log signal-to-noise ratio (lambda). Default is 15.
        shift : float, optional
            Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0.
            For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
        weighting : Literal["sigmoid", "likelihood_weighting"], optional
            The type of weighting function to use for the noise schedule. Default is "sigmoid".
        """
        super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
        self._shift = shift
        self._weighting = weighting

        self.log_snr_min = min_log_snr
        self.log_snr_max = max_log_snr

        # Truncation bounds in t-space: the maximum log-SNR maps to the smallest time, the minimum to the largest.
        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)

    def _truncated_t(self, t: Tensor) -> Tensor:
        # Map t in [0, 1] linearly onto the truncated interval [t_min, t_max].
        return self._t_min + (self._t_max - self._t_min) * t
    def get_log_snr(self, t: Tensor | float, training: bool) -> Tensor:
        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
        t_trunc = self._truncated_t(t)
        return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift
    def get_t_from_log_snr(self, log_snr_t: Tensor | float, training: bool) -> Tensor:
        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
        # log SNR = -2 * log(tan(pi * t / 2)) + 2 * shift
        # => t = (2 / pi) * arctan(exp((2 * shift - log SNR) / 2))
        return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))
    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
        t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
        # Compute the truncated time t_trunc
        t_trunc = self._truncated_t(t)

        # d/dt_trunc [-2 * log(tan(pi * t_trunc / 2))] = -pi / (sin(pi * t_trunc / 2) * cos(pi * t_trunc / 2))
        #                                              = -2 * pi / sin(pi * t_trunc)
        dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)
        # Chain rule through the linear map t -> t_trunc = t_min + (t_max - t_min) * t
        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)

        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
        # f'(t) = -(e^(-snr(t)) / (1 + e^(-snr(t)))) * dsnr_dt
        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
        return -factor * dsnr_dt
    def get_config(self):
        config = {
            "min_log_snr": self.log_snr_min,
            "max_log_snr": self.log_snr_max,
            "shift": self._shift,
            "weighting": self._weighting,
        }
        return config
    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**deserialize(config, custom_objects=custom_objects))
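

# The demo below is not part of the original module: it is a minimal usage sketch of the class
# defined above, run under a `__main__` guard so it has no effect on import. The constructor
# arguments and the chosen time value are illustrative only, and the default Keras backend is assumed.
if __name__ == "__main__":
    schedule = CosineNoiseSchedule(min_log_snr=-15, max_log_snr=15, shift=0.0)

    # Query the log signal-to-noise ratio at a mid-trajectory diffusion time t in [0, 1].
    t = ops.convert_to_tensor(0.5)
    log_snr = schedule.get_log_snr(t, training=True)

    # get_t_from_log_snr inverts the un-truncated mapping, so it recovers the truncated time
    # t_trunc = t_min + (t_max - t_min) * t rather than t itself.
    t_trunc = schedule.get_t_from_log_snr(log_snr, training=True)

    # Derivative term used for the reverse SDE.
    d_log = schedule.derivative_log_snr(log_snr, training=True)

    # Serialization round trip via get_config / from_config.
    clone = CosineNoiseSchedule.from_config(schedule.get_config())

    print(log_snr, t_trunc, d_log, clone.log_snr_min)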