# Source code for bayesflow.adapters.transforms.ungroup

from typing import Any

from .transform import Transform
from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable("bayesflow.adapters")
class Ungroup(Transform):
    def __init__(self, key: str, prefix: str = ""):
        """
        Ungroups the variables in `key` from a dictionary into individual entries. Most transforms do
        not support nested structures, so this can be used to flatten a nested structure.
        It can later on be reassembled using the :py:class:`bayesflow.adapters.transforms.Group` transform.

        Parameters
        ----------
        key : str
            The name of the variable to ungroup. The variable has to be a dictionary.
        prefix : str, optional
            An optional common prefix that will be added to the ungrouped variable names. This can be necessary
            to avoid duplicate names.
        """
        super().__init__()
        self.key = key
        self.prefix = prefix
        # Sorted list of the (prefixed) names produced by `forward` so far; `None` until
        # `forward` has ungrouped something at least once. `inverse` uses it to know
        # which entries to collect back into the group.
        self._ungrouped = None

    def get_config(self) -> dict:
        """Return the serialized configuration, including the recorded ungrouped names."""
        return serialize({"key": self.key, "prefix": self.prefix, "_ungrouped": self._ungrouped})

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        """
        Rebuild the transform from a configuration created by :py:meth:`get_config`.

        A configuration without an ``"_ungrouped"`` entry is accepted and yields a transform
        that has not recorded any ungrouped names yet.
        """
        config = deserialize(config, custom_objects)
        # `_ungrouped` is internal state, not a constructor argument, so it is split off
        # before calling `cls(**config)` and restored afterwards.
        ungrouped = config.pop("_ungrouped", None)
        transform = cls(**config)
        transform._ungrouped = ungrouped
        return transform

    def forward(self, data: dict[str, Any], *, strict: bool = True, **kwargs) -> dict[str, Any]:
        """
        Replace the dictionary stored under `key` by its individual (prefixed) entries.

        Parameters
        ----------
        data : dict
            The data to transform. A shallow copy is modified; the input is left untouched.
        strict : bool, optional
            If True, a missing `key` raises a ``KeyError``. If False, the data is returned
            unchanged when `key` is missing.

        Returns
        -------
        dict
            A copy of `data` in which `key` is replaced by its entries.

        Raises
        ------
        KeyError
            If `key` is missing and `strict` is True.
        ValueError
            If an ungrouped name collides with a name that is already present.
        """
        data = data.copy()

        if self.key not in data:
            if strict:
                raise KeyError(f"Missing key: {self.key!r}")
            return data

        ungrouped = []
        for k, v in data.pop(self.key).items():
            new_key = f"{self.prefix}{k}"
            if new_key in data:
                raise ValueError(
                    f"Encountered duplicate key during ungrouping: '{new_key}'."
                    " Use `prefix` to specify a unique prefix that is added to the key"
                )
            ungrouped.append(new_key)
            data[new_key] = v

        # Accumulate names over repeated calls so that `inverse` knows every name that
        # was ever ungrouped; kept sorted for a deterministic, serializable state.
        self._ungrouped = sorted(set(self._ungrouped or ()) | set(ungrouped))

        return data

    def inverse(self, data: dict[str, Any], *, strict: bool = False, **kwargs) -> dict[str, Any]:
        """
        Collect the previously ungrouped entries back into a dictionary stored under `key`.

        Parameters
        ----------
        data : dict
            The data to transform. A shallow copy is modified; the input is left untouched.
        strict : bool, optional
            If True, every recorded ungrouped name has to be present in `data`.
            If False, missing names are skipped.

        Returns
        -------
        dict
            A copy of `data` where the recorded entries are moved into ``data[key]`` with the
            prefix removed. ``data[key]`` is set (possibly to an empty dictionary) even if no
            entry was found or no names have been recorded yet.

        Raises
        ------
        KeyError
            If a recorded name is missing and `strict` is True.
        """
        data = data.copy()

        # Build the group separately and assign it last, so that a recorded name equal
        # to `self.key` cannot clash with a placeholder entry.
        group = {}
        # `_ungrouped` is still None if `forward` never ungrouped anything (e.g. it was
        # only called non-strictly on data without `key`); treat that as "nothing to collect".
        for name in self._ungrouped or ():
            if name in data:
                group[name[len(self.prefix) :]] = data.pop(name)
            elif strict:
                raise KeyError(f"Missing key: {name!r}")

        data[self.key] = group
        return data

    def log_det_jac(
        self,
        data: dict[str, Any],
        log_det_jac: dict[str, Any],
        inverse: bool = False,
        **kwargs,
    ):
        """
        Apply the same (un)grouping to the dictionary of log-determinant Jacobians.

        The `data` argument is not used; only the structure of `log_det_jac` is changed.
        The forward direction is applied non-strictly, so a `log_det_jac` without `key`
        is returned unchanged.
        """
        if inverse:
            return self.inverse(data=log_det_jac)
        return self.forward(data=log_det_jac, strict=False)