Source code for bayesflow.adapters.transforms.one_hot
import numpy as np
from keras.saving import (
register_keras_serializable as serializable,
)
from bayesflow.utils.numpy_utils import one_hot
from .elementwise_transform import ElementwiseTransform
[docs]
@serializable(package="bayesflow.adapters")
class OneHot(ElementwiseTransform):
    """
    Elementwise transform that one-hot encodes its input.

    The forward direction delegates to ``bayesflow.utils.numpy_utils.one_hot``
    with the configured number of classes; the inverse direction takes the
    arg-max along the last axis of the data it is given.

    Parameters
    ----------
    num_classes : int
        Number of classes for the encoding.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "OneHot":
        """Rebuild a transform from a configuration dictionary.

        Only the ``"num_classes"`` entry is read; ``custom_objects`` is
        accepted but not used. A missing key raises ``KeyError``.
        """
        stored_num_classes = config["num_classes"]
        return cls(num_classes=stored_num_classes)

    def get_config(self) -> dict:
        """Return the constructor arguments as a dictionary."""
        config = dict(num_classes=self.num_classes)
        return config

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """One-hot encode ``data``; extra keyword arguments are ignored."""
        encoded = one_hot(data, self.num_classes)
        return encoded

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return the arg-max of ``data`` over its last axis.

        Extra keyword arguments are ignored.
        """
        indices = np.argmax(data, axis=-1)
        return indices