Source code for bayesflow.metrics.maximum_mean_discrepancy

from functools import partial

import keras

from .functional import maximum_mean_discrepancy


class MaximumMeanDiscrepancy(keras.Metric):
    """Keras metric that stores the latest maximum mean discrepancy (MMD) value.

    Each call to :meth:`update_state` evaluates ``maximum_mean_discrepancy``
    on the supplied pair of inputs and overwrites the stored scalar with the
    outcome; :meth:`result` reports that scalar. No averaging over successive
    updates is performed by this class.
    """

    def __init__(
        self,
        name: str = "maximum_mean_discrepancy",
        kernel: str = "inverse_multiquadratic",
        unbiased: bool = False,
        **kwargs,
    ):
        """Set up the metric state and the pre-configured MMD function.

        Parameters
        ----------
        name : str, optional
            Name handed to ``keras.Metric``.
        kernel : str, optional
            Forwarded unchanged as the ``kernel`` keyword of
            ``maximum_mean_discrepancy``; the set of accepted names is
            defined by that function.
        unbiased : bool, optional
            Forwarded unchanged as the ``unbiased`` keyword of
            ``maximum_mean_discrepancy`` (presumably selects the unbiased
            estimator; confirm against that function).
        **kwargs
            Extra keyword arguments passed on to ``keras.Metric``.
        """
        super().__init__(name=name, **kwargs)

        # Freeze the configuration once so that update_state only has to
        # supply the two inputs.
        self.mmd_fn = partial(
            maximum_mean_discrepancy,
            kernel=kernel,
            unbiased=unbiased,
        )

        # Scalar state variable holding the most recently computed value.
        self.mmd = self.add_variable(shape=(), initializer="zeros", name="mmd")

    def update_state(self, x, y):
        """Compute the MMD of ``x`` and ``y`` and store it as the metric state.

        Parameters
        ----------
        x, y
            Passed positionally to ``maximum_mean_discrepancy``; presumably
            two sets of samples to compare (shape requirements are defined
            by that function).
        """
        latest = self.mmd_fn(x, y)
        # Match the metric's own dtype before writing into the variable.
        latest = keras.ops.cast(latest, self.dtype)
        # Replace, rather than accumulate into, the stored value.
        self.mmd.assign(latest)

    def result(self):
        """Return the value currently held in the ``mmd`` state variable."""
        return self.mmd.value