Source code for bayesflow.metrics.root_mean_squard_error
import keras
from bayesflow.utils.serialization import deserialize, serializable
from .functional import root_mean_squared_error
[docs]
@serializable("bayesflow.metrics")
class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
super().__init__(root_mean_squared_error, name=name, dtype=dtype, **kwargs)
[docs]
def get_config(self):
base_config = super().get_config()
# fn is fixed and passed directly in the constructor
base_config.pop("fn")
return base_config
[docs]
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**deserialize(config, custom_objects=custom_objects))