# Source code for bayesflow.adapters.transforms.squeeze
import numpy as np
from collections.abc import Sequence
from bayesflow.utils.serialization import serializable, serialize
from .elementwise_transform import ElementwiseTransform
@serializable("bayesflow.adapters")
class Squeeze(ElementwiseTransform):
    """
    Remove singleton dimensions from an array.

    ``forward`` drops the configured length-1 axes with :func:`numpy.squeeze`;
    ``inverse`` puts them back with :func:`numpy.expand_dims`.

    Parameters
    ----------
    axis : int or sequence of int
        Axis (or axes) to remove. Each selected axis must have length 1;
        per the numpy documentation, ``np.squeeze`` raises a ``ValueError``
        otherwise. Because the number of leading batch dimensions may vary,
        negative indices (counting from the last axis) are recommended.

    Examples
    --------
    Squeezing a shape (3, 1) array and restoring it:

    >>> a = np.array([[1], [2], [3]])
    >>> sq = bf.adapters.transforms.Squeeze(axis=-1)
    >>> sq.forward(a).shape
    (3,)
    >>> sq.inverse(sq.forward(a)).shape
    (3, 1)

    Preceding this transform with a :class:`~bayesflow.adapters.transforms.ToArray`
    transform is recommended.
    """

    def __init__(self, *, axis: int | Sequence[int]):
        super().__init__()
        # numpy documents ``axis`` as an int or a tuple of ints, so any other
        # sequence type (e.g. a list) is frozen into a tuple; ints pass through.
        self.axis = tuple(axis) if isinstance(axis, Sequence) else axis

    def get_config(self) -> dict:
        """Return the serialized constructor arguments (only ``axis``)."""
        config = {"axis": self.axis}
        return serialize(config)

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Drop the configured length-1 axes from ``data``; extra kwargs are ignored."""
        squeezed = np.squeeze(data, axis=self.axis)
        return squeezed

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Re-insert length-1 axes at the configured positions; extra kwargs are ignored.

        This undoes ``forward`` when ``data`` has the rank that ``forward`` produced.
        """
        restored = np.expand_dims(data, axis=self.axis)
        return restored