Source code for bayesflow.adapters.transforms.expand_dims
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
from .elementwise_transform import ElementwiseTransform
[docs]
@serializable(package="bayesflow.adapters")
class ExpandDims(ElementwiseTransform):
"""
Expand the shape of an array.
Parameters
----------
axis : int or tuple
The axis to expand.
Examples
--------
shape (3,) array:
>>> a = np.array([1, 2, 3])
shape (2, 3) array:
>>> b = np.array([[1, 2, 3], [4, 5, 6]])
>>> dat = dict(a=a, b=b)
>>> ed = bf.adapters.transforms.ExpandDims("a", axis=0)
>>> new_dat = ed.forward(dat)
>>> new_dat["a"].shape
(1, 3)
>>> ed = bf.adapters.transforms.ExpandDims("a", axis=1)
>>> new_dat = ed.forward(dat)
>>> new_dat["a"].shape
(3, 1)
>>> ed = bf.adapters.transforms.ExpandDims("b", axis=1)
>>> new_dat = ed.forward(dat)
>>> new_dat["b"].shape
(2, 1, 3)
It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
"""
def __init__(self, *, axis: int | tuple):
super().__init__()
self.axis = axis
[docs]
@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
return cls(
axis=deserialize(config["axis"], custom_objects),
)
[docs]
def get_config(self) -> dict:
return {
"axis": serialize(self.axis),
}
[docs]
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.expand_dims(data, axis=self.axis)
[docs]
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=self.axis)