Source code for bayesflow.adapters.transforms.sqrt

import numpy as np

from bayesflow.utils.serialization import serializable

from .elementwise_transform import ElementwiseTransform


@serializable("bayesflow.adapters")
class Sqrt(ElementwiseTransform):
    """Elementwise square-root transform for a variable.

    The forward direction applies ``np.sqrt``; the inverse direction applies
    ``np.square``. The transform has no parameters, so its configuration is
    an empty dictionary.

    Examples
    --------
    >>> adapter = bf.Adapter().sqrt(["x"])
    """

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return the elementwise square root of ``data``.

        Extra keyword arguments are accepted and ignored.
        """
        rooted = np.sqrt(data)
        return rooted

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return the elementwise square of ``data``.

        Extra keyword arguments are accepted and ignored.
        """
        squared = np.square(data)
        return squared

    def get_config(self) -> dict:
        """Return the (empty) configuration of this transform."""
        return dict()

    def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
        """Return the summed log-determinant of the Jacobian.

        Parameters
        ----------
        data : np.ndarray
            Values at which the elementwise term ``-0.5 * log(data) - log(2)``
            is evaluated.
        inverse : bool, optional
            When true, the sign of the elementwise term is flipped before
            summation. Default is False.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        np.ndarray
            The elementwise terms summed over every axis except the first.
        """
        # log|d sqrt(x) / dx| = log(1 / (2 * sqrt(x))) = -0.5 * log(x) - log(2)
        forward_term = -0.5 * np.log(data) - np.log(2)

        # NOTE(review): the inverse case only negates the forward term, which
        # presumes `data` holds the untransformed values -- confirm with caller.
        signed_term = -forward_term if inverse else forward_term

        # Collapse all axes after the leading one (an empty tuple for inputs
        # with fewer than two dimensions, which leaves the values unsummed).
        trailing_axes = tuple(range(1, signed_term.ndim))
        return np.sum(signed_term, axis=trailing_axes)