Source code for bayesflow.utils.dispatch.find_summary_network
from functools import singledispatch
import keras
[docs]
@singledispatch
def find_summary_network(arg, *args, **kwargs):
    """Resolve ``arg`` into a summary network instance.

    This is a ``functools.singledispatch`` entry point: the behavior is chosen
    by the type of ``arg``, and type-specific implementations are attached with
    ``find_summary_network.register``. The body here is the fallback that runs
    only when no registered implementation matches the type of ``arg``.

    Args:
        arg: Object to resolve; its type selects the implementation.
        *args: Forwarded to the selected implementation.
        **kwargs: Forwarded to the selected implementation.

    Raises:
        TypeError: Always, because ``arg`` has a type with no registered
            implementation.
    """
    # Message previously said "inference network" (copy-paste from the
    # inference-network dispatcher); this dispatcher resolves summary networks.
    raise TypeError(f"Cannot infer summary network from {arg!r}.")
@find_summary_network.register
def _(name: str, *args, **kwargs):
    """Construct a built-in summary network from its string name.

    The lookup is case-insensitive. The remaining positional and keyword
    arguments are passed to the network's constructor.

    Args:
        name: One of ``"deep_set"``, ``"set_transformer"``,
            ``"fusion_transformer"``, ``"time_series_transformer"`` or
            ``"time_series_network"`` (any letter case).
        *args: Positional constructor arguments for the selected network.
        **kwargs: Keyword constructor arguments for the selected network.

    Returns:
        A new instance of the matching class from ``bayesflow.networks``.

    Raises:
        ValueError: If ``name`` does not match a known summary network.
    """
    # Lower-case name -> class name exported by bayesflow.networks.
    summary_networks = {
        "deep_set": "DeepSet",
        "set_transformer": "SetTransformer",
        "fusion_transformer": "FusionTransformer",
        "time_series_transformer": "TimeSeriesTransformer",
        "time_series_network": "TimeSeriesNetwork",
    }

    try:
        class_name = summary_networks[name.lower()]
    except KeyError:
        supported = ", ".join(f"'{key}'" for key in summary_networks)
        # Report the caller's original spelling and list the valid choices.
        raise ValueError(
            f"Unknown summary network: '{name}'. Supported names: {supported}."
        ) from None

    # Deferred import, as in the original: bayesflow.networks is only loaded once
    # a valid name is requested (presumably to avoid an import cycle — confirm).
    from bayesflow import networks

    return getattr(networks, class_name)(*args, **kwargs)
@find_summary_network.register(keras.Layer)
def _(layer, *args, **kwargs):
    """Return an already-built Keras layer unchanged.

    Extra positional and keyword arguments are accepted so the call signature
    matches the other implementations. They are ignored here.
    """
    return layer
@find_summary_network.register(keras.Model)
def _(model, *args, **kwargs):
    """Return an already-built Keras model unchanged.

    Extra positional and keyword arguments are accepted for signature
    compatibility and are ignored.
    """
    # NOTE(review): in Keras 3, keras.Model subclasses keras.Layer, so this
    # overload is presumably redundant with the Layer one. Kept for explicitness;
    # confirm before removing.
    return model