EDMNoiseSchedule

class bayesflow.experimental.diffusion_model.EDMNoiseSchedule(sigma_data: float = 1.0, sigma_min: float = 0.0001, sigma_max: float = 80.0)

Bases: NoiseSchedule

EDM noise schedule for diffusion models, based on the EDM paper [1]. It should be used with the F-prediction type of the diffusion model.

[1] Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35, 26565-26577.
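To illustrate what F-prediction refers to, the sketch below computes the preconditioning coefficients from [1]. There, the denoiser is D(x; σ) = c_skip(σ) x + c_out(σ) F(c_in(σ) x; c_noise(σ)), where F is the raw network output. The names `edm_preconditioning`, `denoise`, and the toy network `f_net` are hypothetical; this is a sketch of the paper's formulas, not bayesflow's internal code.

```python
import math


def edm_preconditioning(sigma: float, sigma_data: float = 1.0):
    """Return (c_skip, c_out, c_in, c_noise) as defined in Karras et al. (2022)."""
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise


def denoise(x: float, sigma: float, f_net, sigma_data: float = 1.0) -> float:
    """Combine the raw network output F with the skip connection (F-prediction)."""
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
    return c_skip * x + c_out * f_net(c_in * x, c_noise)


def f_net(x_scaled: float, c_noise: float) -> float:
    """Hypothetical stand-in for a trained network: always predicts zero."""
    return 0.0


# With F == 0 the denoiser reduces to the skip term c_skip * x.
print(denoise(2.0, sigma=1.0, f_net=f_net))  # 1.0
```

The skip connection lets the network predict only a scaled correction, which is why the schedule is meant to be paired with F-prediction.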

Initialize the EDM noise schedule.

Parameters:
sigma_data : float, optional

The standard deviation of the output distribution. Both the scaling of the network input and the weighting function depend on this value. Default is 1.0.

sigma_min : float, optional

The minimum noise level. Only relevant for sampling. Default is 1e-4.

sigma_max : float, optional

The maximum noise level. Only relevant for sampling. Default is 80.0.
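In EDM's parameterization the noisy sample is x + σ n with unit-variance noise n, so the signal-to-noise ratio is 1/σ² and the log-SNR is λ = -2 log σ. Under that relation, the sampling bounds sigma_min and sigma_max correspond to log-SNR bounds, as sketched below. The helper names are hypothetical.

```python
import math


def sigma_to_log_snr(sigma: float) -> float:
    """log-SNR for x + sigma * n: SNR = 1 / sigma**2, so lambda = -2 log(sigma)."""
    return -2.0 * math.log(sigma)


def log_snr_to_sigma(log_snr: float) -> float:
    """Inverse relation: sigma = exp(-lambda / 2)."""
    return math.exp(-0.5 * log_snr)


sigma_min, sigma_max = 1e-4, 80.0  # the documented defaults

# The smallest noise level gives the largest log-SNR, and vice versa.
log_snr_max = sigma_to_log_snr(sigma_min)
log_snr_min = sigma_to_log_snr(sigma_max)
print(f"lambda range: [{log_snr_min:.3f}, {log_snr_max:.3f}]")
# lambda range: [-8.764, 18.421]
```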

get_log_snr(t: float | Tensor, training: bool) → Tensor

Get the log signal-to-noise ratio (lambda) for a given diffusion time.
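One plausible reading of the `training` flag, following [1], is that training and sampling use different noise distributions. During training, log σ is drawn from N(P_mean, P_std²) with P_mean = -1.2 and P_std = 1.2. During sampling, noise levels follow a continuous version of the paper's schedule σ = (σ_max^(1/ρ) + u (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ with ρ = 7. The sketch below maps a diffusion time t to λ = -2 log σ in both modes. It assumes that larger t means more noise (so u = 1 - t) and that training maps a uniform t through the inverse normal CDF. These constants and conventions come from the paper; bayesflow's implementation may differ, for example by clipping. The function name is hypothetical.

```python
import math
from statistics import NormalDist

# Constants from Karras et al. (2022); assumed, not read from bayesflow.
P_MEAN, P_STD = -1.2, 1.2  # training: log(sigma) ~ N(P_MEAN, P_STD**2)
RHO = 7.0                  # sampling: exponent of the sigma schedule


def edm_log_snr(t: float, training: bool,
                sigma_min: float = 1e-4, sigma_max: float = 80.0) -> float:
    """Map diffusion time t to log-SNR lambda = -2 log(sigma).

    Assumed convention: larger t means more noise. In training mode t must lie
    strictly inside (0, 1), because the inverse normal CDF is undefined at 0 and 1.
    """
    if training:
        # Inverse-CDF sampling: uniform t -> log(sigma) ~ N(P_MEAN, P_STD**2).
        log_sigma = NormalDist(P_MEAN, P_STD).inv_cdf(t)
    else:
        # EDM sampling schedule, interpolated linearly in sigma**(1/rho) space.
        a = sigma_max ** (1.0 / RHO)
        b = sigma_min ** (1.0 / RHO)
        u = 1.0 - t  # u = 0 -> sigma_max, u = 1 -> sigma_min
        log_sigma = RHO * math.log(a + u * (b - a))
    return -2.0 * log_sigma


print(edm_log_snr(0.5, training=True))   # median: -2 * P_MEAN = 2.4
print(edm_log_snr(1.0, training=False))  # approximately -2 log(sigma_max)
```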

get_t_from_log_snr(log_snr_t: float | Tensor, training: bool) → Tensor

Get the diffusion time (t) from the log signal-to-noise ratio (lambda).
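A sketch of the inverse map, assuming the EDM constants from [1] (P_mean = -1.2, P_std = 1.2, ρ = 7), that larger t means more noise, and that training draws log σ by passing a uniform t through the inverse normal CDF. Each step is undone in turn: σ = exp(-λ/2), then either the normal CDF (training) or the inverse of the ρ-schedule (sampling). The function name is hypothetical and bayesflow's implementation may differ.

```python
import math
from statistics import NormalDist

P_MEAN, P_STD, RHO = -1.2, 1.2, 7.0  # EDM paper constants (assumed)


def edm_t_from_log_snr(log_snr: float, training: bool,
                       sigma_min: float = 1e-4, sigma_max: float = 80.0) -> float:
    """Invert lambda(t) for the two assumed EDM modes; larger t means more noise."""
    log_sigma = -0.5 * log_snr
    if training:
        # log(sigma) ~ N(P_MEAN, P_STD**2), so t is the CDF value of log(sigma).
        return NormalDist(P_MEAN, P_STD).cdf(log_sigma)
    a = sigma_max ** (1.0 / RHO)
    b = sigma_min ** (1.0 / RHO)
    # Solve sigma**(1/rho) = a + u * (b - a) for u, then t = 1 - u.
    u = (math.exp(log_sigma / RHO) - a) / (b - a)
    return 1.0 - u


print(edm_t_from_log_snr(2.4, training=True))  # 0.5 (the median)
```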

derivative_log_snr(log_snr_t: Tensor, training: bool = False) → Tensor

Compute d/dt log(1 + e^(-lambda(t))), where lambda(t) is the log signal-to-noise ratio. This term is used in the reverse SDE.
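By the chain rule, d/dt log(1 + e^(-λ(t))) = -λ'(t) · sigmoid(-λ(t)), so only λ and its time derivative are needed. The sketch below applies this to an EDM-style sampling schedule λ(t) = -2ρ log(A + (1 - t)(B - A)), with A = σ_max^(1/ρ), B = σ_min^(1/ρ), ρ = 7, and larger t meaning more noise. These are assumptions taken from [1], not bayesflow's code. Because the method receives λ rather than t, λ'(t) is rewritten as a function of λ alone. The result is checked against a central finite difference. All names are hypothetical.

```python
import math

RHO, SIGMA_MIN, SIGMA_MAX = 7.0, 1e-4, 80.0  # assumed EDM sampling settings
A = SIGMA_MAX ** (1.0 / RHO)
B = SIGMA_MIN ** (1.0 / RHO)


def log_snr(t: float) -> float:
    """Sampling schedule: lambda(t) = -2 rho log(A + (1 - t)(B - A))."""
    return -2.0 * RHO * math.log(A + (1.0 - t) * (B - A))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def beta_from_log_snr(lam: float) -> float:
    """d/dt log(1 + exp(-lambda(t))), written in terms of lambda only."""
    s = math.exp(-lam / (2.0 * RHO))    # recovers s = A + (1 - t)(B - A)
    dlam_dt = -2.0 * RHO * (A - B) / s  # lambda decreases as t grows
    return -dlam_dt * sigmoid(-lam)     # chain rule; positive


def softplus_neg_log_snr(t: float) -> float:
    """log(1 + exp(-lambda(t))), the quantity being differentiated."""
    return math.log1p(math.exp(-log_snr(t)))


# Compare against a central finite difference at one time point.
t, h = 0.3, 1e-6
numeric = (softplus_neg_log_snr(t + h) - softplus_neg_log_snr(t - h)) / (2 * h)
print(beta_from_log_snr(log_snr(t)), numeric)
```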

get_weights_for_snr(log_snr_t: Tensor) → Tensor

Get the loss weights for a given log signal-to-noise ratio (lambda).
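For reference, [1] weights its denoising loss by w(σ) = (σ² + σ_data²)/(σ σ_data)². Substituting σ² = exp(-λ) gives a weight that depends only on the log-SNR: exp(λ) + 1/σ_data². The sketch below computes this paper weighting. The function name is hypothetical, and the weights bayesflow returns may be expressed for a different loss parameterization.

```python
import math


def edm_loss_weight(log_snr: float, sigma_data: float = 1.0) -> float:
    """EDM weight w(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.

    With sigma^2 = exp(-log_snr) this equals exp(log_snr) + 1 / sigma_data^2.
    """
    sigma_sq = math.exp(-log_snr)
    return (sigma_sq + sigma_data**2) / (sigma_sq * sigma_data**2)


print(edm_loss_weight(0.0))  # sigma = 1, sigma_data = 1 -> 2.0
```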

get_config()
classmethod from_config(config, custom_objects=None)
get_alpha_sigma(log_snr_t: Tensor) → tuple[Tensor, Tensor]

Get alpha and sigma for a given log signal-to-noise ratio (lambda).

Default is a variance preserving schedule:

alpha(t) = sqrt(sigmoid(log_snr_t))
sigma(t) = sqrt(sigmoid(-log_snr_t))

For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda).
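The two parameterizations described above can be checked numerically. In the variance-preserving case α² + σ² = 1, and in both cases α²/σ² = exp(λ), so they share the same log-SNR. The function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def alpha_sigma_vp(log_snr: float) -> tuple[float, float]:
    """Variance preserving: alpha^2 = sigmoid(lambda), sigma^2 = sigmoid(-lambda)."""
    return math.sqrt(sigmoid(log_snr)), math.sqrt(sigmoid(-log_snr))


def alpha_sigma_ve(log_snr: float) -> tuple[float, float]:
    """Variance exploding: alpha^2 = 1, sigma^2 = exp(-lambda)."""
    return 1.0, math.exp(-0.5 * log_snr)


for lam in (-4.0, 0.0, 4.0):
    a_vp, s_vp = alpha_sigma_vp(lam)
    a_ve, s_ve = alpha_sigma_ve(lam)
    # Both give alpha^2 / sigma^2 = exp(lambda); only VP keeps unit total variance.
    print(lam, a_vp**2 + s_vp**2, (a_vp / s_vp) ** 2, (a_ve / s_ve) ** 2, math.exp(lam))
```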

get_drift_diffusion(log_snr_t: Tensor, x: Tensor = None, training: bool = False) → Tensor | tuple[Tensor, Tensor]

Compute the drift and, optionally, the squared diffusion term for the reverse SDE. Both follow from the time derivative of the schedule:

beta(t) = d/dt log(1 + e^{-lambda(t)})

f(z, t) = -0.5 * beta(t) * z

g(t)^2 = beta(t)

The corresponding differential equations are:

SDE: dz = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt

For a variance exploding schedule, one should set f(z, t) = 0.
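For the variance-preserving case, the drift and diffusion terms follow directly from beta(t). The sketch below evaluates f, g² and the probability-flow ODE velocity for a given score, plus one Euler-Maruyama step of the reverse SDE. It takes beta as an input rather than computing it from a specific schedule. As a sanity check it assumes standard normal data: every VP marginal is then also N(0, 1) with score -z, so the ODE velocity should vanish. All names are hypothetical.

```python
import math


def drift_diffusion_vp(z: float, beta: float) -> tuple[float, float]:
    """f(z, t) = -0.5 * beta(t) * z and g(t)^2 = beta(t)."""
    return -0.5 * beta * z, beta


def ode_velocity(z: float, score: float, beta: float) -> float:
    """Probability-flow ODE: dz/dt = f(z, t) - 0.5 * g(t)^2 * score."""
    f, g2 = drift_diffusion_vp(z, beta)
    return f - 0.5 * g2 * score


def reverse_sde_step(z: float, score: float, beta: float, dt: float, noise: float) -> float:
    """One Euler-Maruyama step of the reverse SDE from t to t - dt (dt > 0).

    `noise` is a standard normal draw, passed in explicitly for reproducibility.
    """
    f, g2 = drift_diffusion_vp(z, beta)
    drift = f - g2 * score
    return z - drift * dt + math.sqrt(g2 * dt) * noise


# Standard normal data: every VP marginal is N(0, 1), so score(z) = -z.
z, beta = 1.5, 0.8
print(ode_velocity(z, score=-z, beta=beta))  # 0.0
print(reverse_sde_step(z, score=-z, beta=beta, dt=0.01, noise=0.0))  # drift part only
```

For a variance-exploding schedule, the same functions apply with f set to zero.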

validate()

Validate the noise schedule.