Source code for bayesflow.utils.dispatch.find_network
from functools import singledispatch
[docs]
@singledispatch
def find_network(arg, *args, **kwargs):
    """Resolve ``arg`` into a network instance.

    This is the fallback of a :func:`functools.singledispatch` function. It is
    only reached when no implementation has been registered (via
    ``find_network.register``) for the type of ``arg``.

    Raises:
        TypeError: Always, because no handler exists for ``type(arg)``.
    """
    message = f"Cannot infer network from {arg!r}."
    raise TypeError(message)
@find_network.register
def _(name: str, *args, **kwargs):
    """Build a network from its string name.

    The lookup is case-insensitive: ``"mlp"`` and ``"default"`` (in any
    casing) both resolve to ``bayesflow.networks.MLP``, which is constructed
    with the remaining positional and keyword arguments.

    Raises:
        ValueError: If ``name`` is not a supported network name.
    """
    if name.lower() not in ("mlp", "default"):
        # Report the caller's original spelling rather than the lowercased key.
        raise ValueError(f"Unsupported network name: '{name}'.")

    # Imported lazily inside the function; presumably this avoids a circular
    # import between bayesflow.utils and bayesflow.networks -- TODO confirm.
    from bayesflow.networks import MLP

    return MLP(*args, **kwargs)
@find_network.register
def _(cls: type, *args, **kwargs):
    """Instantiate the class ``cls`` and return the new object.

    All remaining positional and keyword arguments are forwarded unchanged
    to the constructor.
    """
    return cls(*args, **kwargs)