Source code for bayesflow.links.ordered
import keras
from keras.saving import register_keras_serializable as serializable
from bayesflow.utils import keras_kwargs
[docs]
@serializable(package="links.ordered")
class Ordered(keras.Layer):
    """Activation function to link to a tensor which is monotonically increasing along a specified axis.

    One entry along ``axis`` (the *anchor*) is passed through unchanged. The
    entries before it are placed below the anchor and the entries after it
    above the anchor. The gaps between neighbours come from cumulative sums of
    ``softplus``-transformed inputs, which yields an increasing sequence.

    Parameters
    ----------
    axis : int
        Axis along which the output is ordered.
    anchor_index : int
        Position along ``axis`` of the entry that is left unchanged. Negative
        values count from the end, as in Python indexing. It must refer to
        neither the first nor the last position along ``axis``.
    **kwargs
        Passed through ``keras_kwargs`` before being forwarded to
        ``keras.Layer.__init__``. All keyword arguments are also stored in
        the layer config returned by ``get_config``.
    """

    def __init__(self, axis: int, anchor_index: int, **kwargs):
        super().__init__(**keras_kwargs(kwargs))
        self.axis = axis
        self.anchor_index = anchor_index
        # Both are filled in by ``build`` once the size of ``axis`` is known.
        self.group_indices = None
        self._resolved_anchor_index = None
        self.config = {"axis": axis, "anchor_index": anchor_index, **kwargs}

    def get_config(self):
        """Return the base layer config merged with this layer's constructor arguments."""
        base_config = super().get_config()
        return base_config | self.config

    def build(self, input_shape):
        """Validate ``anchor_index`` and precompute the index groups below/above the anchor.

        Raises
        ------
        ValueError
            If the size of ``axis`` is unknown, or ``anchor_index`` is outside
            ``[-size, size)``.
        RuntimeError
            If ``anchor_index`` refers to the first or last position along ``axis``.
        """
        super().build(input_shape)
        size = input_shape[self.axis]
        if size is None:
            raise ValueError(f"The size of axis {self.axis} must be known to build an Ordered layer.")
        if not -size <= self.anchor_index < size:
            raise ValueError(
                f"anchor_index={self.anchor_index} is out of range for axis {self.axis} of size {size}."
            )

        # Normalize negative indices so that first/last detection and the
        # ``range`` calls below all work with a position in [0, size).
        anchor = self.anchor_index % size
        if anchor == 0 or anchor == size - 1:
            raise RuntimeError("Anchor should not be first or last index.")

        self._resolved_anchor_index = anchor
        self.group_indices = dict(
            below=list(range(0, anchor)),
            above=list(range(anchor + 1, size)),
        )

    def call(self, inputs):
        """Map ``inputs`` to a tensor that is monotonically increasing along ``axis``."""
        # Divide into the entries below the anchor, the anchor itself, and the entries above it.
        below_inputs = keras.ops.take(inputs, self.group_indices["below"], axis=self.axis)
        anchor_input = keras.ops.take(inputs, self._resolved_anchor_index, axis=self.axis)
        # ``take`` with a scalar index drops ``axis``; restore it for broadcasting and concatenation.
        anchor_input = keras.ops.expand_dims(anchor_input, axis=self.axis)
        above_inputs = keras.ops.take(inputs, self.group_indices["above"], axis=self.axis)

        # softplus makes the gaps positive; cumulative sums turn them into offsets
        # that grow with distance from the anchor.
        below = keras.activations.softplus(below_inputs)
        above = keras.activations.softplus(above_inputs)
        # Flip so that the largest offset lands at position 0, i.e. furthest below the anchor.
        below = anchor_input - keras.ops.flip(keras.ops.cumsum(below, axis=self.axis), self.axis)
        above = anchor_input + keras.ops.cumsum(above, axis=self.axis)

        # Reassemble in the original order along ``axis``; the shape matches ``inputs``.
        return keras.ops.concatenate([below, anchor_input, above], self.axis)

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape