Source code for bayesflow.utils.dispatch.find_pooling
import keras
from functools import singledispatch
[docs]
@singledispatch
def find_pooling(arg, *args, **kwargs):
    """Resolve a pooling specification, dispatching on the type of ``arg``.

    This is the generic fallback of the single-dispatch function: it runs
    only when no overload has been registered for ``type(arg)``.

    Parameters
    ----------
    arg
        The pooling specification whose type selects the overload.
    *args, **kwargs
        Accepted for signature compatibility with the registered overloads;
        unused by this fallback.

    Raises
    ------
    TypeError
        Always, because reaching this function means ``arg`` has an
        unsupported type.
    """
    message = f"Cannot infer pooling from {arg!r}."
    raise TypeError(message)
@find_pooling.register
def _(name: str, *args, **kwargs):
    """Build a pooling layer from its string name (matched case-insensitively).

    Parameters
    ----------
    name : str
        One of:

        * ``"mean"``, ``"avg"``, ``"average"`` -- mean over axis ``-2``
        * ``"max"`` -- maximum over axis ``-2``
        * ``"min"`` -- minimum over axis ``-2``
        * ``"learnable"``, ``"pma"``, ``"attention"`` --
          ``PoolingByMultiHeadAttention``
    *args, **kwargs
        Forwarded to ``PoolingByMultiHeadAttention`` for the attention-based
        names. The mean/max/min branches do not use them.

    Returns
    -------
    The constructed pooling layer.

    Raises
    ------
    ValueError
        If the lower-cased ``name`` is not one of the names listed above.
    """
    other = name.lower()

    if other in ("mean", "avg", "average"):
        return keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2))

    if other == "max":
        return keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2))

    if other == "min":
        return keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))

    if other in ("learnable", "pma", "attention"):
        # Imported only when this branch is taken.
        # NOTE(review): presumably deferred to avoid a circular import -- confirm.
        from bayesflow.networks.transformers.pma import PoolingByMultiHeadAttention

        return PoolingByMultiHeadAttention(*args, **kwargs)

    # The message reports the lower-cased name, not the caller's original spelling.
    raise ValueError(f"Unsupported pooling name: '{other}'.")
@find_pooling.register
def _(constructor: type, *args, **kwargs):
    """Instantiate a pooling object from a class.

    Parameters
    ----------
    constructor : type
        The class to call.
    *args, **kwargs
        Passed unchanged to ``constructor``.

    Returns
    -------
    The object produced by ``constructor(*args, **kwargs)``.
    """
    instance = constructor(*args, **kwargs)
    return instance
@find_pooling.register
def _(pooling: keras.Layer, *args, **kwargs):
    """Pass an already-constructed layer through unchanged.

    Parameters
    ----------
    pooling : keras.Layer
        An existing layer instance.
    *args, **kwargs
        Accepted for signature compatibility with the other overloads; ignored.

    Returns
    -------
    The same ``pooling`` object that was passed in (no copy is made).
    """
    return pooling