optimal_transport
- bayesflow.utils.optimal_transport(x1: Tensor, x2: Tensor, *aux: Tensor, method: str = 'sinkhorn_knopp', **kwargs)
Matches elements from x2 onto x1, such that the transport cost between them is minimized, according to the method and cost matrix used.
Depending on the method used, elements in either tensor may be permuted, dropped, duplicated, or otherwise modified, such that the assignment is optimal.
Note: this is just a dispatch function that calls the appropriate optimal transport method. See the documentation of the respective method for more details.
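The dispatch pattern described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not BayesFlow's implementation. `sinkhorn_knopp_plan` and `optimal_transport_sketch` are hypothetical names. The cost is assumed to be squared Euclidean distance, both marginals are assumed uniform, and n == m is assumed so that the reordered x1 keeps shape (n, …). The entropic Sinkhorn-Knopp plan is turned into an assignment by taking the column-wise argmax. That is one possible choice, and it shows how entries of x1 can end up duplicated or dropped.

```python
import numpy as np


def sinkhorn_knopp_plan(cost, regularization=0.1, max_steps=1000, atol=1e-9):
    """Entropy-regularized transport plan with uniform marginals (hypothetical helper)."""
    n, m = cost.shape
    kernel = np.exp(-cost / regularization)  # Gibbs kernel of the cost matrix
    a = np.full(n, 1.0 / n)  # assumption: uniform mass on the x1 samples
    b = np.full(m, 1.0 / m)  # assumption: uniform mass on the x2 samples
    u, v = np.ones(n), np.ones(m)
    for _ in range(max_steps):
        # Alternately rescale rows and columns toward the target marginals.
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
        # Columns now match b exactly; stop once the rows match a as well.
        if np.allclose(u * (kernel @ v), a, atol=atol):
            break
    return u[:, None] * kernel * v[None, :]


def optimal_transport_sketch(x1, x2, *aux, method="sinkhorn_knopp", **kwargs):
    """Hypothetical dispatcher mirroring the documented call pattern; not BayesFlow code."""
    solvers = {"sinkhorn_knopp": sinkhorn_knopp_plan}
    if method not in solvers:
        raise ValueError(f"unknown method: {method!r}")
    n, m = len(x1), len(x2)
    if n != m:
        raise ValueError("this sketch assumes n == m")
    # Assumption: squared Euclidean cost between flattened samples, shape (n, m).
    flat1, flat2 = x1.reshape(n, -1), x2.reshape(m, -1)
    cost = ((flat1[:, None, :] - flat2[None, :, :]) ** 2).sum(axis=-1)
    plan = solvers[method](cost, **kwargs)
    # For each x2[j], pick the x1 sample carrying the most mass in column j.
    # Several columns may pick the same row, so x1 entries can repeat or vanish.
    idx = plan.argmax(axis=0)
    # x1 and every aux tensor are reindexed together; x2 is returned unchanged.
    return (x1[idx], x2, *(t[idx] for t in aux))


# Usage: x2 is a shuffled copy of x1, so the matching should undo the shuffle.
x1 = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])
x2 = x1[[3, 0, 4, 1, 2]]
labels = np.array(["a", "b", "c", "d", "e"])
x1_out, x2_out, labels_out = optimal_transport_sketch(x1, x2, labels, regularization=0.1)
# x1_out now equals x2 row by row; labels_out is ['d', 'a', 'e', 'b', 'c'].
```

Smaller values of `regularization` push the plan toward a hard permutation, but they also make `np.exp(-cost / regularization)` more prone to underflow when the costs are large.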
- Parameters:
x1 – Tensor of shape (n, …). Samples from the first distribution.
x2 – Tensor of shape (m, …). Samples from the second distribution.
aux – Tensors of shape (n, …). Auxiliary tensors to be permuted along with x1. None of the currently available methods permute x2.
method – Method used to solve the optimal transport problem. Default: ‘sinkhorn_knopp’
kwargs – Additional keyword arguments passed to the optimization method.
- Returns:
x1 and x2, of shapes (n, …) and (m, …) respectively, in optimal transport permutation order.
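To make "optimal transport permutation order" concrete, the following sketch computes an exact assignment by brute force for a tiny batch. It is an illustration only. `exact_permutation_sketch` is a hypothetical name, not a BayesFlow method. The sketch assumes n == m and a squared Euclidean cost. Trying all n! orderings is only feasible for a handful of samples; exact solvers such as the Hungarian algorithm avoid this blow-up. After the call, x1[j] is paired with x2[j], x2 is untouched, and aux follows x1.

```python
import itertools

import numpy as np


def exact_permutation_sketch(x1, x2, *aux):
    """Brute-force optimal assignment for tiny, equal-sized batches (illustration only)."""
    n = len(x1)
    if len(x2) != n:
        raise ValueError("this sketch assumes n == m")
    # Assumption: squared Euclidean cost between flattened samples, shape (n, n).
    diff = x1.reshape(n, -1)[:, None, :] - x2.reshape(n, -1)[None, :, :]
    cost = (diff ** 2).sum(axis=-1)
    # perm[j] is the index of the x1 sample paired with x2[j]; all n! orderings are tried.
    best = min(
        itertools.permutations(range(n)),
        key=lambda perm: sum(cost[perm[j], j] for j in range(n)),
    )
    idx = np.array(best)
    # Reorder x1 and aux together so that x1_out[j] is matched with x2[j].
    return (x1[idx], x2, *(t[idx] for t in aux))


x1 = np.array([[0.0], [1.0], [2.0], [3.0]])
x2 = np.array([[2.1], [-0.1], [3.2], [0.9]])
ids = np.array([10, 11, 12, 13])  # auxiliary tensor that must follow x1
x1_out, x2_out, ids_out = exact_permutation_sketch(x1, x2, ids)
# Each x2 entry is paired with its nearest x1 entry: x1_out is [[2], [0], [3], [1]].
```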