Source code for bayesflow.metrics.maximum_mean_discrepancy
from functools import partial
import keras
from .functional import maximum_mean_discrepancy
[docs]
class MaximumMeanDiscrepancy(keras.Metric):
    """Keras metric exposing the most recently computed maximum mean discrepancy.

    The metric wraps the functional ``maximum_mean_discrepancy`` with a fixed
    ``kernel`` and ``unbiased`` setting and keeps its output in a single
    scalar variable. Every call to :meth:`update_state` overwrites that
    variable; values are not accumulated or averaged across calls.
    """

    def __init__(
        self,
        name: str = "maximum_mean_discrepancy",
        kernel: str = "inverse_multiquadratic",
        unbiased: bool = False,
        **kwargs,
    ):
        """Create the metric and its scalar state variable.

        Parameters
        ----------
        name : str
            Name of the metric, passed to ``keras.Metric``.
        kernel : str
            Forwarded as the ``kernel`` keyword to ``maximum_mean_discrepancy``.
            NOTE(review): the set of accepted kernel names is defined by the
            functional implementation, not here.
        unbiased : bool
            Forwarded as the ``unbiased`` keyword to
            ``maximum_mean_discrepancy``.
        **kwargs
            Additional keyword arguments passed to ``keras.Metric``.
        """
        super().__init__(name=name, **kwargs)

        # Kept as plain attributes so that get_config() can report them.
        self.kernel = kernel
        self.unbiased = unbiased

        self.mmd = self.add_variable(shape=(), initializer="zeros", name="mmd")
        self.mmd_fn = partial(maximum_mean_discrepancy, kernel=kernel, unbiased=unbiased)

    def update_state(self, x, y):
        """Compute the discrepancy for ``x`` and ``y`` and store it.

        The result of ``maximum_mean_discrepancy(x, y, ...)`` is cast to the
        metric's dtype and assigned to the state variable, replacing whatever
        value was stored before.
        """
        self.mmd.assign(keras.ops.cast(self.mmd_fn(x, y), self.dtype))

    def result(self):
        """Return the value currently held by the state variable."""
        return self.mmd.value

    def get_config(self):
        """Return the base metric config extended with ``kernel`` and ``unbiased``.

        Without these entries a metric rebuilt from its config would fall back
        to the constructor defaults instead of the settings it was created
        with.
        """
        config = super().get_config()
        config.update({"kernel": self.kernel, "unbiased": self.unbiased})
        return config