Source code for bayesflow.metrics.maximum_mean_discrepancy

import keras

from bayesflow.utils.serialization import deserialize, serializable, serialize
from .functional import maximum_mean_discrepancy


[docs] @serializable("bayesflow.metrics") class MaximumMeanDiscrepancy(keras.Metric): def __init__( self, name: str = "maximum_mean_discrepancy", kernel: str = "inverse_multiquadratic", unbiased: bool = False, **kwargs, ): super().__init__(name=name, **kwargs) self.mmd = self.add_variable(shape=(), initializer="zeros", name="mmd") self.kernel = kernel self.unbiased = unbiased
[docs] def update_state(self, x, y): self.mmd.assign( keras.ops.cast(maximum_mean_discrepancy(x, y, kernel=self.kernel, unbiased=self.unbiased), self.dtype) )
[docs] def result(self): return self.mmd.value
[docs] def get_config(self): base_config = super().get_config() config = {"kernel": self.kernel, "unbiased": self.unbiased} return base_config | serialize(config)
[docs] @classmethod def from_config(cls, config, custom_objects=None): return cls(**deserialize(config, custom_objects=custom_objects))