Source code for bayesflow.metrics.functional.maximum_mean_discrepancy
import keras
from bayesflow.types import Tensor
from bayesflow.utils import issue_url
from .kernels import gaussian, inverse_multiquadratic
[docs]
def maximum_mean_discrepancy(
x: Tensor, y: Tensor, kernel: str = "inverse_multiquadratic", unbiased: bool = False, **kwargs
) -> Tensor:
"""Computes a mixture of Gaussian radial basis functions (RBFs) between the samples of x and y.
See the original paper below for details and different estimators:
Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012).
A kernel two-sample test. The Journal of Machine Learning Research, 13(1), 723-773.
https://jmlr.csail.mit.edu/papers/v13/gretton12a.html
Parameters
----------
x : Tensor of shape (num_draws_x, num_features)
Comprises `num_draws_x` Random draws from the "source" distribution `P`.
y : Tensor of shape (num_draws_y, num_features)
Comprises `num_draws_y` Random draws from the "source" distribution `Q`.
kernel : str, optional (default - "inverse_multiquadratic")
The (mixture of) kernels to be used for the MMD computation.
unbiased : bool, optional (default - False)
Whether to use the unbiased MMD estimator. Default is False.
Returns
-------
mmd : Tensor of shape (1, )
The biased or unbiased empirical maximum mean discrepancy (MMD) estimator.
"""
if kernel == "gaussian":
kernel_fn = gaussian
elif kernel == "inverse_multiquadratic":
kernel_fn = inverse_multiquadratic
else:
raise ValueError(
"For now, we only support a gaussian and an inverse_multiquadratic kernel."
f"If you need a different kernel, please open an issue at {issue_url}"
)
if keras.ops.shape(x)[1:] != keras.ops.shape(y)[1:]:
raise ValueError(
f"Expected x and y to live in the same feature space, "
f"but got {keras.ops.shape(x)[1:]} != {keras.ops.shape(y)[1:]}."
)
if unbiased:
m, n = keras.ops.shape(x)[0], keras.ops.shape(y)[0]
xx = (1.0 / (m * (m + 1))) * keras.ops.sum(kernel_fn(x, x, **kwargs))
yy = (1.0 / (n * (n + 1))) * keras.ops.sum(kernel_fn(y, y, **kwargs))
xy = (2.0 / (m * n)) * keras.ops.sum(kernel_fn(x, y, **kwargs))
else:
xx = keras.ops.mean(kernel_fn(x, x, **kwargs))
yy = keras.ops.mean(kernel_fn(y, y, **kwargs))
xy = keras.ops.mean(kernel_fn(x, y, **kwargs))
return xx + yy - 2.0 * xy