# Source code for bayesflow.utils.optimal_transport.optimal_transport

import keras

from bayesflow.types import Tensor

from .log_sinkhorn import log_sinkhorn
from .sinkhorn import sinkhorn

# Registry of available optimal transport solvers, keyed by lower-case name.
# Each solver is reachable under its short name and a "_knopp" alias.
methods = dict(
    sinkhorn=sinkhorn,
    sinkhorn_knopp=sinkhorn,
    log_sinkhorn=log_sinkhorn,
    log_sinkhorn_knopp=log_sinkhorn,
)


def optimal_transport(
    x1: Tensor,
    x2: Tensor,
    conditions: Tensor | None = None,
    method="sinkhorn",
    return_assignments=False,
    **kwargs,
) -> tuple[Tensor, Tensor, Tensor | None, Tensor] | tuple[Tensor, Tensor, Tensor | None]:
    """
    Match elements from ``x2`` onto ``x1`` by minimizing the transport cost.

    This function dispatches to a specific optimal transport method according to the
    selected ``method``. The method returns assignment indices, which are used to
    gather rows of ``x2`` (and of ``conditions``, if given) along the first axis.
    ``x1`` is returned unchanged. Depending on the assignment indices produced by the
    method, rows of ``x2`` may be permuted, duplicated, or dropped.

    Note
    ----
    This is a dispatch function that calls the appropriate optimal transport
    implementation. See the documentation of the selected method for details on the
    exact optimization procedure and assumptions.

    Parameters
    ----------
    x1 : Tensor
        Tensor of shape ``(n, ...)`` containing samples from the first distribution.
        Returned as-is.
    x2 : Tensor
        Tensor of shape ``(m, ...)`` containing samples from the second distribution.
    conditions : Tensor, optional
        Conditioning information paired row-wise with ``x2``. Its first axis must
        index the same rows as ``x2``, since it is gathered with the same assignment
        indices. It is also forwarded to the selected method. If ``None``,
        unconditional optimal transport is performed. Default is ``None``.
    method : str, optional
        Name of the optimal transport method (case-insensitive). One of
        ``'sinkhorn'``, ``'sinkhorn_knopp'``, ``'log_sinkhorn'`` or
        ``'log_sinkhorn_knopp'``. Default is ``'sinkhorn'``.
    return_assignments : bool
        If ``True``, also return the assignment indices produced by the transport
        method. Default is ``False``.
    **kwargs
        Additional keyword arguments passed to the selected optimal transport method.

    Returns
    -------
    Tuple of tensors
        If ``return_assignments`` is ``False``, returns ``(x1, x2, conditions)``, where
        ``x1`` is unchanged. ``x2`` and ``conditions`` are the inputs gathered along
        axis 0 with the assignment indices, so their leading axis follows the
        assignments rather than ``m``. ``conditions`` is ``None`` if it was not
        given. If ``return_assignments`` is ``True``, the assignment indices are
        returned as a fourth element.

    Raises
    ------
    KeyError
        If ``method`` does not name a known optimal transport method.
    """
    try:
        solver = methods[method.lower()]
    except KeyError:
        available = ", ".join(sorted(methods))
        # keep KeyError for backward compatibility, but say what is valid
        raise KeyError(
            f"Unknown optimal transport method {method!r}. Available methods: {available}."
        ) from None

    assignments = solver(x1, x2, conditions, **kwargs)
    x2 = keras.ops.take(x2, assignments, axis=0)

    if conditions is not None:
        # conditions are paired row-wise with x2, so they must be resampled with
        # the same assignment indices to stay aligned
        conditions = keras.ops.take(conditions, assignments, axis=0)

    if return_assignments:
        return x1, x2, conditions, assignments

    return x1, x2, conditions