Source code for bayesflow.links.ordered_quantiles

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.utils import keras_kwargs, logging

from collections.abc import Sequence

from .ordered import Ordered


[docs] @serializable(package="links.ordered_quantiles") class OrderedQuantiles(Ordered): """Activation function to link to monotonously increasing quantile estimates.""" def __init__(self, q: Sequence[float] = None, axis: int = None, **kwargs): super().__init__(axis, None, **keras_kwargs(kwargs)) self.q = q self.config = { "q": q, "axis": axis, }
[docs] def get_config(self): base_config = super().get_config() return base_config | self.config
[docs] def build(self, input_shape): if self.axis is None and 1 < len(input_shape) <= 3: self.axis = -2 elif self.axis is None: raise AssertionError( f"Cannot resolve which axis should be ordered automatically from input shape {input_shape}." ) num_quantile_levels = input_shape[self.axis] if self.q is None: # choose the middle of the specified axis as anchor index self.anchor_index = num_quantile_levels // 2 logging.info( f"`OrderedQuantiles` was not provided with argument `q`. Using index {self.anchor_index} as anchor." ) else: # choose quantile level closest to median as anchor index self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5)) if len(self.q) != num_quantile_levels: raise RuntimeError( f"Length of `q` does not coincide with input shape: len(q)={len(self.q)}, " f"position {self.axis} of shape={input_shape}" ) if self.anchor_index in [0, -1, num_quantile_levels - 1]: raise RuntimeError( f"The link function `OrderedQuantiles` expects at least 3 quantile levels, " f"but only {num_quantile_levels} were given." ) super().build(input_shape)