Source code for bayesflow.adapters.transforms.take

from collections.abc import Sequence
import numpy as np

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class Take(ElementwiseTransform):
    """
    Select entries at given positions along one axis of an array (a thin wrapper around ``np.take``).

    With a sequence of indices the array keeps its rank; only the size of ``axis``
    changes. A typical use is keeping only some of the features produced by a
    summary network.

    Parameters
    ----------
    indices : Sequence[int]
        Positions to select along ``axis``.
    axis : int, optional
        Axis to select along. Default is -1 (the last axis).

    Example: adapter.take("x", np.arange(0,3), axis=-1)
    """

    def __init__(self, indices: Sequence[int], axis: int = -1):
        super().__init__()
        self.axis = axis
        self.indices = indices

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return ``data`` restricted to ``self.indices`` along ``self.axis``; extra kwargs are ignored."""
        return np.take(data, self.indices, axis=self.axis)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return ``data`` unchanged; extra kwargs are ignored."""
        # Selecting a subset generally discards entries, so no true inverse
        # exists; the data is passed through as-is instead.
        return data

    def get_config(self) -> dict:
        """Return the constructor arguments (``indices``, ``axis``) passed through ``serialize``."""
        return serialize({"indices": self.indices, "axis": self.axis})