Source code for bayesflow.links.ordered_quantiles
import keras
from keras.saving import register_keras_serializable as serializable
from bayesflow.utils import keras_kwargs, logging
from collections.abc import Sequence
from .ordered import Ordered
[docs]
@serializable(package="links.ordered_quantiles")
class OrderedQuantiles(Ordered):
    """Activation function to link to monotonously increasing quantile estimates.

    The ordered axis and the anchor index are resolved lazily in :meth:`build`.
    The base class is then built with them.

    Parameters
    ----------
    q : Sequence[float], optional
        Quantile levels represented by the entries along ``axis``. When given,
        ``len(q)`` must equal the size of ``axis``, and the level closest to 0.5
        is used as the anchor. When ``None``, the middle index of ``axis`` is
        used as the anchor.
    axis : int, optional
        Axis to be ordered. When ``None``, :meth:`build` resolves it to ``-2``
        for inputs of rank 2 or 3.
    **kwargs
        Forwarded to the base class after filtering with ``keras_kwargs``.
    """

    def __init__(self, q: Sequence[float] = None, axis: int = None, **kwargs):
        # The anchor index is not known until the input shape is seen, so None is
        # passed as the base class' second positional argument (presumably its
        # anchor index -- confirm against `Ordered`).
        super().__init__(axis, None, **keras_kwargs(kwargs))
        self.q = q
        # Record the constructor arguments as passed. `build` may later resolve
        # `self.axis`, but this dict keeps the user's original value.
        self.config = {
            "q": q,
            "axis": axis,
        }

    def get_config(self):
        """Return the base config merged with this layer's ``q`` and ``axis``."""
        base_config = super().get_config()
        return base_config | self.config

    def build(self, input_shape):
        """Resolve the ordered axis and anchor index, then build the base layer.

        Raises
        ------
        AssertionError
            If ``axis`` is ``None`` and cannot be inferred from the input rank.
        RuntimeError
            If ``len(q)`` does not match the size of the ordered axis, if that
            axis holds fewer than 3 levels, or if the anchor falls on the first
            or last position of the axis.
        """
        if self.axis is None:
            if not 1 < len(input_shape) <= 3:
                raise AssertionError(
                    f"Cannot resolve which axis should be ordered automatically from input shape {input_shape}."
                )
            # NOTE(review): for rank-2 inputs, -2 is the leading axis -- confirm intended.
            self.axis = -2

        num_quantile_levels = input_shape[self.axis]

        # Validate `q` against the axis size before using it, so that an empty or
        # mismatched `q` produces this message instead of an opaque argmin error.
        if self.q is not None and len(self.q) != num_quantile_levels:
            raise RuntimeError(
                f"Length of `q` does not coincide with input shape: len(q)={len(self.q)}, "
                f"position {self.axis} of shape={input_shape}"
            )

        # With fewer than 3 levels, any anchor lies on a boundary.
        if num_quantile_levels < 3:
            raise RuntimeError(
                f"The link function `OrderedQuantiles` expects at least 3 quantile levels, "
                f"but only {num_quantile_levels} were given."
            )

        if self.q is None:
            # choose the middle of the specified axis as anchor index
            self.anchor_index = num_quantile_levels // 2
            logging.info(
                f"`OrderedQuantiles` was not provided with argument `q`. Using index {self.anchor_index} as anchor."
            )
        else:
            # Choose the quantile level closest to the median as anchor index.
            # argmin returns a backend scalar tensor; convert it to a Python int
            # so the comparison below and later indexing do not depend on
            # backend tensor semantics.
            distances = keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5)
            self.anchor_index = int(keras.ops.convert_to_numpy(keras.ops.argmin(distances)))

        # The anchor needs at least one level on each side. With >= 3 levels this
        # can only fail when the level closest to 0.5 is the first or last entry of `q`.
        if self.anchor_index in (0, -1, num_quantile_levels - 1):
            raise RuntimeError(
                f"The link function `OrderedQuantiles` requires at least one quantile level below and "
                f"one above the anchor (the level closest to 0.5), but the anchor is at index "
                f"{self.anchor_index} of {num_quantile_levels} levels (q={self.q})."
            )

        super().build(input_shape)