from keras.saving import (
register_keras_serializable as serializable,
)
import numpy as np
from bayesflow.utils.numpy_utils import (
inverse_sigmoid,
inverse_softplus,
sigmoid,
softplus,
)
from .elementwise_transform import ElementwiseTransform
@serializable(package="bayesflow.adapters")
class Constrain(ElementwiseTransform):
"""
Constrains neural network predictions of a data variable to specified bounds.

Parameters
----------
* : str
    String containing the name of the data variable to be transformed, e.g., "sigma". See examples below.
lower : int or float or np.ndarray, optional
    Lower bound for the named data variable.
upper : int or float or np.ndarray, optional
    Upper bound for the named data variable.
method : str, optional
    Method used to map the unconstrained network outputs into the specified bounds. Choose from:
    - Double bounded methods: "sigmoid", "expit", "logit" (default: "sigmoid")
    - Lower bound only methods: "softplus", "exp", "log" (default: "softplus")
    - Upper bound only methods: "softplus", "exp", "log" (default: "softplus")
inclusive : {'both', 'lower', 'upper', 'none'}, optional
    Indicates which bounds are inclusive (or exclusive).
    - "both" (default): Both lower and upper bounds are inclusive.
    - "lower": Lower bound is inclusive, upper bound is exclusive.
    - "upper": Lower bound is exclusive, upper bound is inclusive.
    - "none": Both lower and upper bounds are exclusive.
epsilon : float, optional
    Small value used to ensure that inclusive bounds are not violated.
    The default of 1e-15 ensures finite outcomes when the default
    transformations are applied to data lying exactly on the boundaries.
Examples
--------
1) Let sigma be the standard deviation of a normal distribution; sigma must
then always be greater than zero.

>>> import bayesflow as bf
>>> adapter = bf.Adapter().constrain("sigma", lower=0)

2) Suppose p is the parameter of a binomial distribution, so p must lie in
[0, 1]. We then constrain the neural network estimate of p as follows.
>>> import bayesflow as bf
>>> adapter = bf.Adapter()
>>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
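3) Suppose a variable must be at most zero, e.g., a log-probability (named
"log_p" here purely for illustration). An upper bound alone selects the
softplus method by default.

>>> adapter = bf.Adapter().constrain("log_p", upper=0)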
"""
def __init__(
self,
*,
lower: int | float | np.ndarray | None = None,
upper: int | float | np.ndarray | None = None,
method: str = "default",
inclusive: str = "both",
epsilon: float = 1e-15,
):
super().__init__()
if lower is None and upper is None:
raise ValueError("At least one of 'lower' or 'upper' must be provided.")
if lower is not None and upper is not None:
# double bounded case
if np.any(lower >= upper):
raise ValueError("The lower bound must be strictly less than the upper bound.")
match method:
case "default" | "sigmoid" | "expit" | "logit":
def constrain(x):
return (upper - lower) * sigmoid(x) + lower
def unconstrain(x):
return inverse_sigmoid((x - lower) / (upper - lower))
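# Worked example (sanity check only, not part of the transform): with
# lower=0 and upper=2, constrain(0.0) == 1.0, since sigmoid(0) == 0.5
# maps to the interval midpoint, and unconstrain(1.0) == inverse_sigmoid(0.5) == 0.0.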
case str() as name:
raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.")
case other:
raise TypeError(f"Expected a method name, got {other!r}.")
elif lower is not None:
# lower bounded case
match method:
case "default" | "softplus":
def constrain(x):
return softplus(x) + lower
def unconstrain(x):
return inverse_softplus(x - lower)
case "exp" | "log":
def constrain(x):
return np.exp(x) + lower
def unconstrain(x):
return np.log(x - lower)
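# Worked example (sanity check only): with lower=1 and method="exp",
# constrain(0.0) == np.exp(0.0) + 1 == 2.0 and unconstrain(2.0) == np.log(1.0) == 0.0.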
case str() as name:
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
case other:
raise TypeError(f"Expected a method name, got {other!r}.")
else:
# upper bounded case
match method:
case "default" | "softplus":
def constrain(x):
return -softplus(-x) + upper
def unconstrain(x):
return -inverse_softplus(-(x - upper))
case "exp" | "log":
def constrain(x):
return -np.exp(-x) + upper
def unconstrain(x):
return -np.log(-x + upper)
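# Worked example (sanity check only): with upper=0 and method="exp",
# constrain(0.0) == -np.exp(-0.0) == -1.0 and unconstrain(-1.0) == -np.log(1.0) == 0.0.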
case str() as name:
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
case other:
raise TypeError(f"Expected a method name, got {other!r}.")
self.lower = lower
self.upper = upper
self.method = method
self.inclusive = inclusive
self.epsilon = epsilon
self.constrain = constrain
self.unconstrain = unconstrain
# Do this last: self.lower/self.upper must keep the user-supplied bounds for
# get_config, while constrain/unconstrain close over the local variables and
# therefore pick up the epsilon-widened bounds rebound below.
match inclusive:
case "lower":
if lower is not None:
lower = lower - epsilon
case "upper":
if upper is not None:
upper = upper + epsilon
case True | "both":
if lower is not None:
lower = lower - epsilon
if upper is not None:
upper = upper + epsilon
case False | None | "none":
pass
case other:
raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")
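# Worked example of the widening above: with lower=0, inclusive="both", and the
# default epsilon, the local lower becomes -1e-15, so unconstrain(0.0) evaluates
# inverse_softplus(1e-15) (finite) rather than inverse_softplus(0.0) == -inf.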
@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
return cls(**config)
def get_config(self) -> dict:
return {
"lower": self.lower,
"upper": self.upper,
"method": self.method,
"inclusive": self.inclusive,
"epsilon": self.epsilon,
}
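# Note: get_config serializes only the constructor arguments; the constrain and
# unconstrain closures are rebuilt by __init__ when from_config is called, so
# Constrain.from_config(transform.get_config()) yields an equivalent transform.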
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
# forward means data space -> network space, so unconstrain the data
return self.unconstrain(data)
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
# inverse means network space -> data space, so constrain the data
return self.constrain(data)
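# Minimal round-trip sketch for this transform in isolation (illustrative only;
# assumes the bayesflow package is importable so the relative import resolves):
#
#     t = Constrain(lower=0)
#     x = np.array([0.5, 1.0, 2.0])
#     z = t.forward(x)                     # data space -> unconstrained space
#     assert np.allclose(t.inverse(z), x)  # constrain back to data space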