from collections.abc import Sequence
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
from .transform import Transform
@serializable(package="bayesflow.adapters")
class Broadcast(Transform):
    """
    Broadcasts arrays or scalars to the shape of a given other array.

    Parameters
    ----------
    keys : str or sequence of str
        Name(s) of the data variables to broadcast. A single string is treated as one key.
    to : str
        Name of the data variable to broadcast to.
    expand : str or int or tuple, optional
        Where should new dimensions be added to match the number of dimensions in `to`?
        Can be "left", "right", or an integer or tuple containing the indices of the new dimensions.
        The latter is needed if we want to include a dimension in the middle, which will be required
        for more advanced cases. By default we expand left.
    exclude : int or tuple or None, optional
        Which dimensions (of the dimensions after expansion) should retain their size,
        rather than being broadcast to the corresponding dimension size of `to`?
        By default we exclude the last dimension (usually the data dimension) from broadcasting the size.
        Pass None to broadcast every dimension to the size of `to`.
    squeeze : int or tuple or None, optional
        Axis or axes to squeeze after broadcasting. By default, nothing is squeezed.

    Raises
    ------
    ValueError
        If `expand` is a string other than "left" or "right".

    Notes
    -----
    Important: Do not broadcast to variables that are used as inference variables
    (i.e., parameters to be inferred by the networks). The adapter will work during training
    but then fail during inference because the variable being broadcast to is not available.

    Examples
    --------
    shape (1, ) array:

    >>> a = np.array((1,))

    shape (2, 3) array:

    >>> b = np.array([[1, 2, 3], [4, 5, 6]])

    shape (2, 2, 3) array:

    >>> c = np.array([[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [1, 2, 3]]])
    >>> dat = dict(a=a, b=b, c=c)
    >>> bc = bf.adapters.transforms.Broadcast("a", to="b")
    >>> new_dat = bc.forward(dat)
    >>> new_dat["a"].shape
    (2, 1)
    >>> bc = bf.adapters.transforms.Broadcast("a", to="b", exclude=None)
    >>> new_dat = bc.forward(dat)
    >>> new_dat["a"].shape
    (2, 3)
    >>> bc = bf.adapters.transforms.Broadcast("b", to="c", expand=1)
    >>> new_dat = bc.forward(dat)
    >>> new_dat["b"].shape
    (2, 2, 3)

    It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
    """

    def __init__(
        self,
        keys: str | Sequence[str],
        *,
        to: str,
        expand: str | int | tuple = "left",
        exclude: int | tuple | None = -1,
        squeeze: int | tuple | None = None,
    ):
        super().__init__()

        # A bare string is itself a Sequence[str]; without wrapping it, iterating over it
        # in `forward` would yield single characters instead of the variable name.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys
        self.to = to

        # Reject unknown modes early instead of silently skipping the expansion step.
        if isinstance(expand, str) and expand not in ("left", "right"):
            raise ValueError(f"`expand` must be 'left', 'right', an int, or a tuple of ints, got {expand!r}.")
        # `forward` only treats tuples as explicit axis positions, so normalize ints and lists.
        if isinstance(expand, int):
            expand = (expand,)
        elif isinstance(expand, list):
            expand = tuple(expand)
        self.expand = expand

        if isinstance(exclude, int):
            exclude = (exclude,)
        elif isinstance(exclude, list):
            exclude = tuple(exclude)
        self.exclude = exclude

        # numpy `axis` arguments take an int or a tuple of ints, not a list.
        if isinstance(squeeze, list):
            squeeze = tuple(squeeze)
        self.squeeze = squeeze

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
        """Reconstruct a ``Broadcast`` transform from the output of :meth:`get_config`."""
        # Deserialize turns tuples to lists, undo it if necessary
        exclude = deserialize(config["exclude"], custom_objects)
        exclude = tuple(exclude) if isinstance(exclude, list) else exclude
        expand = deserialize(config["expand"], custom_objects)
        expand = tuple(expand) if isinstance(expand, list) else expand
        squeeze = deserialize(config["squeeze"], custom_objects)
        squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze

        return cls(
            keys=deserialize(config["keys"], custom_objects),
            to=deserialize(config["to"], custom_objects),
            expand=expand,
            exclude=exclude,
            squeeze=squeeze,
        )

    def get_config(self) -> dict:
        """Return a serializable dictionary of the constructor arguments."""
        return {
            "keys": serialize(self.keys),
            "to": serialize(self.to),
            "expand": serialize(self.expand),
            "exclude": serialize(self.exclude),
            "squeeze": serialize(self.squeeze),
        }

    def _expand(self, x: np.ndarray, target_ndim: int) -> np.ndarray:
        """Insert new axes into ``x`` as configured by ``self.expand`` to reach ``target_ndim`` dims."""
        len_diff = target_ndim - x.ndim

        if self.expand == "left":
            # axes 0 .. len_diff-1 (empty when len_diff <= 0)
            return np.expand_dims(x, axis=tuple(range(len_diff)))

        if self.expand == "right":
            # axes -1 .. -len_diff (empty when len_diff <= 0)
            return np.expand_dims(x, axis=tuple(range(-1, -len_diff - 1, -1)))

        if isinstance(self.expand, tuple):
            if len(self.expand) != len_diff:
                raise ValueError("Length of `expand` must match the length difference of the involved arrays.")
            return np.expand_dims(x, axis=self.expand)

        return x

    # noinspection PyMethodOverriding
    def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """
        Broadcast every variable in ``self.keys`` against ``data[self.to]``.

        Parameters
        ----------
        data : dict of str to array-like
            Data dictionary; must contain ``self.to`` and every key in ``self.keys``.
        **kwargs
            Unused; accepted for interface compatibility.

        Returns
        -------
        dict of str to np.ndarray
            A shallow copy of ``data`` in which the entries for ``self.keys`` are replaced by
            their expanded, broadcast, and (optionally) squeezed versions. The other entries
            are left unchanged.

        Raises
        ------
        ValueError
            If ``expand`` is a tuple whose length differs from the number of dimensions to add,
            or if an array cannot be broadcast to the requested shape.
        """
        target_shape = data[self.to].shape
        data = data.copy()

        for key in self.keys:
            # ensure that .shape is defined
            value = self._expand(np.asarray(data[key]), len(target_shape))

            new_shape = target_shape
            if self.exclude is not None:
                # excluded axes keep the variable's own size instead of the target's size
                new_shape = np.array(target_shape, dtype=int)
                old_shape = np.array(value.shape, dtype=int)
                exclude = list(self.exclude)
                new_shape[exclude] = old_shape[exclude]
                new_shape = tuple(new_shape)

            # note: np.broadcast_to returns a read-only view of the input
            value = np.broadcast_to(value, new_shape)

            if self.squeeze is not None:
                value = np.squeeze(value, axis=self.squeeze)

            data[key] = value

        return data

    # noinspection PyMethodOverriding
    def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """Return ``data`` unchanged; the inverse of broadcasting is not implemented."""
        # TODO: add inverse
        # we will likely never actually need the inverse broadcasting in practice
        # so adding this method is not high priority
        return data