Source code for bayesflow.utils.dispatch.find_distribution
from functools import singledispatch
[docs]
@singledispatch
def find_distribution(arg, *args, **kwargs):
    """Build a distribution object from a flexible specification.

    This is the fallback implementation of a ``functools.singledispatch``
    generic function. Overloads registered for specific types of ``arg``
    handle the supported specifications; this body runs only when no
    overload matches ``type(arg)``.

    Parameters
    ----------
    arg
        The specification to dispatch on.
    *args, **kwargs
        Extra arguments meant for a registered overload; ignored here.

    Raises
    ------
    TypeError
        Always, because no overload is registered for ``type(arg)``.
    """
    # Accept *args (like the registered overloads do) so an unsupported
    # ``arg`` always produces this informative message, rather than a generic
    # "takes 1 positional argument" error when extra positionals are passed.
    raise TypeError(f"Cannot infer distribution from {arg!r}.")
@find_distribution.register
def _(name: str, *args, **kwargs):
    """Resolve a distribution from its string name (case-insensitive).

    Parameters
    ----------
    name : str
        Distribution name, compared after lower-casing. ``"normal"`` builds a
        ``bayesflow.distributions.DiagonalNormal``; ``"none"`` yields ``None``.
    *args, **kwargs
        Forwarded to the ``DiagonalNormal`` constructor for ``"normal"``;
        ignored for ``"none"``.

    Returns
    -------
    DiagonalNormal or None

    Raises
    ------
    ValueError
        If the lower-cased name is not a supported distribution name.
    """
    key = name.lower()

    if key == "none":
        return None

    if key == "normal":
        # Imported inside the function rather than at module level,
        # presumably to avoid an import cycle with bayesflow.distributions
        # (NOTE(review): confirm).
        from bayesflow.distributions import DiagonalNormal

        return DiagonalNormal(*args, **kwargs)

    raise ValueError(f"Unsupported distribution name '{key}'.")
@find_distribution.register(type(None))
def _(none: None, *args, **kwargs):
    """Pass an explicit ``None`` specification through as ``None``.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    # Registered explicitly on NoneType; equivalent to dispatching on the
    # ``None`` annotation, which typing.get_type_hints maps to NoneType.
    result = None
    return result