Source code for bayesflow.utils.dispatch.find_inference_network
from functools import singledispatch
import keras
[docs]
@singledispatch
def find_inference_network(arg, *args, **kwargs):
raise TypeError(f"Cannot infer inference network from {arg!r}.")
@find_inference_network.register
def _(name: str, *args, **kwargs):
match name.lower():
case "coupling_flow":
from bayesflow.networks import CouplingFlow
return CouplingFlow(*args, **kwargs)
case "flow_matching":
from bayesflow.networks import FlowMatching
return FlowMatching(*args, **kwargs)
case "consistency_model":
from bayesflow.networks import ConsistencyModel
return ConsistencyModel(*args, **kwargs)
case unknown_network:
raise ValueError(f"Unknown inference network: '{unknown_network}'")
@find_inference_network.register
def _(layer: keras.Layer, *args, **kwargs):
return layer
@find_inference_network.register
def _(model: keras.Model, *args, **kwargs):
return model