# Source code for bayesflow.utils.dispatch.find_distribution
from functools import singledispatch
import keras
# [docs]
@singledispatch
def find_distribution(arg, *args, **kwargs):
    """Resolve ``arg`` into a distribution.

    This is a :func:`functools.singledispatch` generic function: the handler
    is chosen by the type of ``arg``. This fallback runs only when no handler
    is registered for ``type(arg)``. In this module handlers are registered
    for ``str`` names, ``None`` and ``keras.Layer`` instances.

    Parameters
    ----------
    arg
        The object to resolve.
    *args, **kwargs
        Extra arguments that type-specific handlers forward to a distribution
        constructor. They are accepted (and ignored) here so that an
        unsupported ``arg`` always yields the descriptive error below, rather
        than a signature-mismatch ``TypeError`` when positional extras are
        passed.

    Raises
    ------
    TypeError
        Always, because no handler knows how to interpret ``arg``.
    """
    raise TypeError(f"Cannot infer distribution from {arg!r}.")
@find_distribution.register
def _(name: str, *args, **kwargs):
match name.lower():
case "normal":
from bayesflow.distributions import DiagonalNormal
distribution = DiagonalNormal(*args, **kwargs)
case "student" | "student-t" | "student_t":
from bayesflow.distributions import DiagonalStudentT
distribution = DiagonalStudentT(*args, **kwargs)
case "mixture":
raise ValueError(
"Mixture distributions need to be explicitly defined as bf.distributions.Mixture(...) "
"and passed to the constructor."
)
case "none":
distribution = None
case other:
raise ValueError(f"Unsupported distribution name '{other}'.")
return distribution
@find_distribution.register
def _(none: None, *args, **kwargs) -> None:
    """Pass ``None`` through unchanged.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    return
@find_distribution.register
def _(distribution: keras.Layer, *args, **kwargs) -> keras.Layer:
    """Return an already-constructed ``keras.Layer`` as-is.

    The layer is not copied or modified; extra positional and keyword
    arguments are accepted and ignored.
    """
    layer = distribution
    return layer