# Source code for bayesflow.adapters.transforms.rename
from typing import Any

from keras.saving import (
    register_keras_serializable as serializable,
)

from .transform import Transform
@serializable(package="bayesflow.adapters")
class Rename(Transform):
    """
    Transform to rename a key in a data dictionary.

    Useful to rename variables so that they match the names required by an
    approximator. This transform renames exactly one variable at a time; chain
    several renames (e.g. repeated ``Adapter.rename`` calls) to rename more.

    Parameters
    ----------
    from_key : str
        Variable name that should be renamed.
    to_key : str
        New variable name.

    Notes
    -----
    If the target key (``to_key`` in :meth:`forward`, ``from_key`` in
    :meth:`inverse`) is already present in the data, its value is overwritten
    by the renamed entry.

    Examples
    --------
    >>> adapter = (
    ...     bf.adapters.Adapter()
    ...     # rename the variables to match the required approximator inputs
    ...     .rename("theta", "inference_variables")
    ...     .rename("x", "inference_conditions")
    ... )
    """

    def __init__(self, from_key: str, to_key: str):
        super().__init__()
        self.from_key = from_key
        self.to_key = to_key

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "Rename":
        """
        Reconstruct a ``Rename`` from the dictionary produced by :meth:`get_config`.

        ``custom_objects`` is accepted for interface compatibility but is not used,
        since the config only holds the two key names.
        """
        return cls(
            from_key=config["from_key"],
            to_key=config["to_key"],
        )

    def get_config(self) -> dict:
        """Return the configuration (the two key names) needed to rebuild this transform."""
        return {"from_key": self.from_key, "to_key": self.to_key}

    def forward(self, data: dict[str, Any], *, strict: bool = True, **kwargs) -> dict[str, Any]:
        """
        Rename ``from_key`` to ``to_key``.

        Parameters
        ----------
        data : dict[str, Any]
            Input data. It is not modified; a shallow copy is returned.
        strict : bool, default True
            If True, raise when ``from_key`` is missing. If False, return an
            unchanged copy instead.
        **kwargs
            Ignored; accepted for compatibility with the ``Transform`` interface.

        Returns
        -------
        dict[str, Any]
            A shallow copy of ``data`` with the entry moved to ``to_key``.

        Raises
        ------
        KeyError
            If ``strict`` is True and ``from_key`` is not in ``data``.
        """
        return self._move_key(data, self.from_key, self.to_key, strict=strict)

    def inverse(self, data: dict[str, Any], *, strict: bool = False, **kwargs) -> dict[str, Any]:
        """
        Undo the rename by moving ``to_key`` back to ``from_key``.

        Parameters
        ----------
        data : dict[str, Any]
            Input data. It is not modified; a shallow copy is returned.
        strict : bool, default False
            If True, raise when ``to_key`` is missing. If False (the default),
            return an unchanged copy instead.
        **kwargs
            Ignored; accepted for compatibility with the ``Transform`` interface.

        Returns
        -------
        dict[str, Any]
            A shallow copy of ``data`` with the entry moved back to ``from_key``.

        Raises
        ------
        KeyError
            If ``strict`` is True and ``to_key`` is not in ``data``.
        """
        return self._move_key(data, self.to_key, self.from_key, strict=strict)

    @staticmethod
    def _move_key(data: dict[str, Any], source: str, target: str, *, strict: bool) -> dict[str, Any]:
        """Return a shallow copy of ``data`` with ``source`` renamed to ``target``."""
        # Always work on (and return) a copy so the caller's dict is never mutated,
        # even when nothing is renamed.
        data = data.copy()
        if source not in data:
            if strict:
                raise KeyError(f"Missing key: {source!r}")
            return data
        # pop-then-assign also handles source == target correctly (no-op rename).
        data[target] = data.pop(source)
        return data