Source code for bayesflow.utils.dispatch.find_recurrent_net
import keras
from functools import singledispatch
[docs]
@singledispatch
def find_recurrent_net(arg, *args, **kwargs):
    """Resolve a recurrent network specification.

    Generic entry point that dispatches on the type of ``arg`` via
    ``functools.singledispatch``. This base implementation is used for
    any type that has no registered overload, and it always fails.

    Parameters
    ----------
    arg
        The network specification to resolve; its type selects the overload.
    *args, **kwargs
        Passed through to whichever overload is selected; unused here.

    Raises
    ------
    TypeError
        Always, because no overload is registered for ``type(arg)``.
    """
    # Reaching this body means dispatch found no handler for type(arg).
    message = f"Cannot infer network from {arg!r}."
    raise TypeError(message)
@find_recurrent_net.register
def _(name: str, *args, **kwargs):
match name.lower():
case "lstm":
constructor = keras.layers.LSTM
case "gru":
constructor = keras.layers.GRU
case other:
raise ValueError(f"Unsupported network name: '{other}'.")
return constructor
@find_recurrent_net.register
def _(network: keras.Layer, *args, **kwargs):
    """Pass an already-constructed Keras layer through unchanged.

    Parameters
    ----------
    network : keras.Layer
        A layer instance supplied directly by the caller.
    *args, **kwargs
        Accepted for dispatch compatibility; unused here.

    Returns
    -------
    keras.Layer
        The same object that was passed in (no copy, no wrapping).
    """
    # Nothing to resolve: the caller already supplied a concrete layer.
    resolved = network
    return resolved