Source code for bayesflow.adapters.transforms.rename

from typing import Any

from keras.saving import (
    register_keras_serializable as serializable,
)

from .transform import Transform


@serializable(package="bayesflow.adapters")
class Rename(Transform):
    """
    Transform to rename keys in data dictionary. Useful to rename variables to match those required by
    approximator. This transform can only rename one variable at a time.

    Parameters
    ----------
    from_key : str
        Variable name that should be renamed
    to_key : str
        New variable name

    Examples
    --------
    >>> adapter = (
    ...     bf.adapters.Adapter()
    ...     # rename the variables to match the required approximator inputs
    ...     .rename("theta", "inference_variables")
    ...     .rename("x", "inference_conditions")
    ... )
    """

    def __init__(self, from_key: str, to_key: str):
        super().__init__()
        self.from_key = from_key
        self.to_key = to_key

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "Rename":
        """Rebuild a ``Rename`` from the dictionary produced by :meth:`get_config`.

        ``custom_objects`` is accepted in the signature but not used here.
        """
        return cls(
            from_key=config["from_key"],
            to_key=config["to_key"],
        )

    def get_config(self) -> dict:
        """Return the configuration needed to reconstruct this transform (both key names)."""
        return {"from_key": self.from_key, "to_key": self.to_key}

    @staticmethod
    def _rename_key(data: dict[str, Any], source: str, target: str, strict: bool) -> dict[str, Any]:
        """Return a shallow copy of ``data`` with the entry under ``source`` moved to ``target``.

        Shared by :meth:`forward` and :meth:`inverse`, which differ only in key direction.
        """
        # Copy first so the caller's dict is never mutated, even on the early-return path.
        data = data.copy()

        if source not in data:
            if strict:
                raise KeyError(f"Missing key: {source!r}")
            # Non-strict: a missing key is tolerated and the (unchanged) copy is returned.
            return data

        # Any existing entry under ``target`` is overwritten.
        data[target] = data.pop(source)
        return data

    def forward(self, data: dict[str, Any], *, strict: bool = True, **kwargs) -> dict[str, Any]:
        """Rename ``from_key`` to ``to_key``.

        Parameters
        ----------
        data : dict[str, Any]
            Input dictionary. It is not modified; a shallow copy is returned.
        strict : bool, optional
            If True (the default), a missing ``from_key`` raises ``KeyError``.
            If False, the copy is returned unchanged instead.
        **kwargs
            Ignored by this transform.

        Returns
        -------
        dict[str, Any]
            Shallow copy of ``data`` with the value moved from ``from_key`` to ``to_key``.
            An existing ``to_key`` entry is overwritten.

        Raises
        ------
        KeyError
            If ``strict`` is True and ``from_key`` is not in ``data``.
        """
        return self._rename_key(data, self.from_key, self.to_key, strict)

    def inverse(self, data: dict[str, Any], *, strict: bool = False, **kwargs) -> dict[str, Any]:
        """Undo :meth:`forward`: rename ``to_key`` back to ``from_key``.

        Parameters
        ----------
        data : dict[str, Any]
            Input dictionary. It is not modified; a shallow copy is returned.
        strict : bool, optional
            If True, a missing ``to_key`` raises ``KeyError``. Unlike :meth:`forward`, this
            defaults to False, so by default a missing ``to_key`` just returns the copy unchanged.
        **kwargs
            Ignored by this transform.

        Returns
        -------
        dict[str, Any]
            Shallow copy of ``data`` with the value moved from ``to_key`` to ``from_key``.
            An existing ``from_key`` entry is overwritten.

        Raises
        ------
        KeyError
            If ``strict`` is True and ``to_key`` is not in ``data``.
        """
        return self._rename_key(data, self.to_key, self.from_key, strict)

    def extra_repr(self) -> str:
        """Describe the renaming as ``'<from_key>' -> '<to_key>'`` (keys shown via ``repr``)."""
        return f"{self.from_key!r} -> {self.to_key!r}"