# Source code for bayesflow.adapters.transforms.one_hot

import operator

import numpy as np
from keras.saving import (
    register_keras_serializable as serializable,
)

from bayesflow.utils.numpy_utils import one_hot
from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class OneHot(ElementwiseTransform):
    """
    Changes data to be one-hot encoded.

    Parameters
    ----------
    num_classes : int
        Number of classes for the encoding. Must be a positive integer.
        Integer-like values (e.g. NumPy integer scalars) are accepted and
        stored as a builtin ``int`` so that :meth:`get_config` stays
        JSON-serializable.

    Raises
    ------
    TypeError
        If ``num_classes`` is not integer-like (e.g. a float).
    ValueError
        If ``num_classes`` is smaller than 1.
    """

    def __init__(self, num_classes: int):
        # operator.index accepts Python ints and NumPy integer scalars but rejects
        # floats, and always returns a builtin int. Storing e.g. np.int64 here would
        # leak into get_config(), which the json module cannot serialize.
        try:
            normalized = operator.index(num_classes)
        except TypeError as err:
            raise TypeError(
                f"num_classes must be an integer, got {type(num_classes).__name__}."
            ) from err
        if normalized < 1:
            raise ValueError(f"num_classes must be a positive integer, got {normalized}.")

        super().__init__()
        self.num_classes = normalized

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "OneHot":
        """Rebuild a transform from the dict produced by :meth:`get_config`.

        ``custom_objects`` is accepted for Keras API compatibility and ignored.
        """
        return cls(num_classes=config["num_classes"])

    def get_config(self) -> dict:
        """Return the constructor arguments needed to recreate this transform."""
        return {"num_classes": self.num_classes}

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """One-hot encode ``data`` using ``num_classes`` classes.

        Delegates to ``bayesflow.utils.numpy_utils.one_hot``; output dtype and
        shape are determined by that helper. NOTE(review): ``inverse`` decodes
        along ``axis=-1``, so the class axis is presumably appended last — confirm.
        Extra keyword arguments are accepted and ignored.
        """
        return one_hot(data, self.num_classes)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Decode one-hot (or score/probability) vectors back to class indices.

        Takes ``np.argmax`` over the last axis, so non-exact inputs map to their
        highest-scoring class. The result holds integer indices; the dtype of the
        data originally passed to ``forward`` is not restored. Extra keyword
        arguments are accepted and ignored.
        """
        return np.argmax(data, axis=-1)