Source code for bayesflow.utils.dispatch.find_permutation
import keras
from functools import singledispatch
[docs]
@singledispatch
def find_permutation(arg, *args, **kwargs):
    """Resolve ``arg`` into a permutation (generic fallback).

    This is the base implementation of a ``functools.singledispatch``
    function. It is only reached when no overload has been registered for
    the type of ``arg``, in which case the argument cannot be interpreted.

    Parameters
    ----------
    arg
        The object to interpret as a permutation specification.
    *args, **kwargs
        Accepted for signature compatibility with the registered
        overloads; unused by this fallback.

    Raises
    ------
    TypeError
        Always, since no overload matches the type of ``arg``.
    """
    message = f"Cannot infer permutation from {arg!r}."
    raise TypeError(message)
@find_permutation.register
def _(name: str, *args, **kwargs):
match name.lower():
case "random":
from bayesflow.networks.coupling_flow.permutations import RandomPermutation
return RandomPermutation(*args, **kwargs)
case "swap":
from bayesflow.networks.coupling_flow.permutations import Swap
return Swap(*args, **kwargs)
case "learnable" | "orthogonal":
from bayesflow.networks.coupling_flow.permutations import OrthogonalPermutation
return OrthogonalPermutation(*args, **kwargs)
@find_permutation.register(keras.Layer)
def _(permutation: keras.Layer, *args, **kwargs):
    """Pass an already-constructed layer through unchanged.

    Parameters
    ----------
    permutation : keras.Layer
        A layer instance supplied by the caller.
    *args, **kwargs
        Ignored; accepted only for signature compatibility with the
        other overloads.

    Returns
    -------
    keras.Layer
        The very same ``permutation`` object that was passed in.
    """
    return permutation
@find_permutation.register(type(None))
def _(none: None, *args, **kwargs):
    """Handle an explicit ``None`` specification.

    Parameters
    ----------
    none : None
        The ``None`` value that triggered this overload.
    *args, **kwargs
        Ignored; accepted only for signature compatibility with the
        other overloads.

    Returns
    -------
    None
        Always ``None``.
    """
    return None