# Source code for bayesflow.adapters.transforms.expand_dims

import numpy as np
from keras.saving import (
    deserialize_keras_object as deserialize,
    register_keras_serializable as serializable,
    serialize_keras_object as serialize,
)

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class ExpandDims(ElementwiseTransform):
    """
    Expand the shape of an array by inserting size-1 axes.

    The forward pass calls :func:`numpy.expand_dims` and the inverse pass calls
    :func:`numpy.squeeze` with the same ``axis``. The inverse therefore removes
    exactly the axes that the forward pass inserted.

    Parameters
    ----------
    axis : int, tuple of int, or list of int
        The axis (or axes) to expand, given as positions in the *expanded*
        array. A list is converted to a tuple, because :func:`numpy.squeeze`
        (used by :meth:`inverse`) accepts only an int or a tuple.

    Examples
    --------
    shape (3,) array:

    >>> a = np.array([1, 2, 3])

    shape (2, 3) array:

    >>> b = np.array([[1, 2, 3], [4, 5, 6]])

    >>> ed = bf.adapters.transforms.ExpandDims(axis=0)
    >>> ed.forward(a).shape
    (1, 3)
    >>> ed = bf.adapters.transforms.ExpandDims(axis=1)
    >>> ed.forward(a).shape
    (3, 1)
    >>> ed.forward(b).shape
    (2, 1, 3)
    >>> ed.inverse(ed.forward(b)).shape
    (2, 3)

    It is recommended to precede this transform with a
    :class:`bayesflow.adapters.transforms.ToArray` transform.
    """

    def __init__(self, *, axis: int | tuple[int, ...] | list[int]):
        super().__init__()
        # np.expand_dims accepts a list of axes but np.squeeze does not (it only
        # takes an int or a tuple), so normalize here to keep forward/inverse
        # symmetric. This also covers lists coming back from a plain-JSON config.
        if isinstance(axis, list):
            axis = tuple(axis)
        self.axis = axis

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
        """
        Rebuild an ``ExpandDims`` from a config created by :meth:`get_config`.

        Parameters
        ----------
        config : dict
            Mapping with an ``"axis"`` entry in Keras-serialized form.
        custom_objects : optional
            Passed through to Keras deserialization.
        """
        return cls(
            axis=deserialize(config["axis"], custom_objects),
        )

    def get_config(self) -> dict:
        """Return a serializable config holding the (Keras-serialized) ``axis``."""
        return {
            "axis": serialize(self.axis),
        }

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Insert size-1 axes into ``data`` at ``self.axis``.

        Extra keyword arguments are accepted for interface compatibility and
        are ignored.
        """
        return np.expand_dims(data, axis=self.axis)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Remove the size-1 axes at ``self.axis`` from ``data``.

        NumPy raises ``ValueError`` if any selected axis does not have length 1.
        Extra keyword arguments are accepted for interface compatibility and
        are ignored.
        """
        return np.squeeze(data, axis=self.axis)