Source code for bayesflow.experimental.diffusion_model.dispatch
from functools import singledispatch
from .schedules.noise_schedule import NoiseSchedule
[docs]
@singledispatch
def find_noise_schedule(arg, *args, **kwargs):
    """Resolve ``arg`` to a noise schedule via type-based dispatch.

    This body is the fallback that runs when no handler is registered for
    the type of ``arg``; it always fails. Handlers for concrete types are
    attached through ``find_noise_schedule.register``.

    Parameters
    ----------
    arg
        The object to resolve. Its type selects the handler.
    *args, **kwargs
        Accepted so the signature matches handlers that forward them;
        unused by this fallback.

    Raises
    ------
    TypeError
        Always, because ``arg`` has a type without a registered handler.
    """
    message = f"Not a noise schedule: {arg!r}. Please pass an object of type 'NoiseSchedule'."
    raise TypeError(message)
@find_noise_schedule.register(NoiseSchedule)
def _(noise_schedule: NoiseSchedule):
    """Return an already-constructed noise schedule unchanged.

    Parameters
    ----------
    noise_schedule : NoiseSchedule
        The schedule instance to pass through.

    Returns
    -------
    NoiseSchedule
        The very same object that was passed in.
    """
    # This handler takes no extra positional or keyword arguments, so
    # supplying any alongside an instance fails at call time.
    return noise_schedule
@find_noise_schedule.register(str)
def _(name: str, *args, **kwargs):
    """Build a noise schedule from its case-insensitive string name.

    Parameters
    ----------
    name : str
        Schedule identifier; ``"cosine"`` and ``"edm"`` are recognised
        after lowercasing.
    *args, **kwargs
        Forwarded unchanged to the selected schedule's constructor.

    Returns
    -------
    CosineNoiseSchedule or EDMNoiseSchedule
        A new instance of the schedule class matching ``name``.

    Raises
    ------
    ValueError
        If the lowercased ``name`` is not a recognised schedule.
    """
    key = name.lower()

    # Schedule classes are imported only in the branch that needs them
    # (presumably to avoid an import cycle with ``.schedules`` -- confirm).
    if key == "cosine":
        from .schedules import CosineNoiseSchedule

        return CosineNoiseSchedule(*args, **kwargs)

    if key == "edm":
        from .schedules import EDMNoiseSchedule

        return EDMNoiseSchedule(*args, **kwargs)

    # The message reports the lowercased name, not the caller's spelling.
    raise ValueError(f"Unsupported noise schedule name: '{key}'.")
@find_noise_schedule.register(type)
def _(cls: type, *args, **kwargs):
    """Instantiate a noise schedule from its class.

    Parameters
    ----------
    cls : type
        The class to instantiate; must be a subclass of ``NoiseSchedule``.
    *args, **kwargs
        Forwarded unchanged to ``cls``.

    Returns
    -------
    NoiseSchedule
        The result of ``cls(*args, **kwargs)``.

    Raises
    ------
    TypeError
        If ``cls`` is not a subclass of ``NoiseSchedule``.
    """
    # Reject unrelated classes up front so only schedules get constructed.
    if not issubclass(cls, NoiseSchedule):
        raise TypeError(f"Expected subclass of NoiseSchedule, got {cls}")

    return cls(*args, **kwargs)