optimal_transport
- bayesflow.utils.optimal_transport(x1, x2, method='log_sinkhorn', return_assignments=False, **kwargs)
Matches elements of x2 onto x1 so that the total transport cost between them is minimized under the chosen method and cost matrix.
Depending on the method, elements in either tensor may be permuted, dropped, duplicated, or otherwise modified to realize the optimal assignment.
Note: this is a dispatch function that calls the selected optimal transport method. See the documentation of that method for details.
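To make "optimal" concrete for the simplest case, where both tensors have the same number of samples and the match is a pure permutation, the sketch below brute-forces every permutation of x2 under a squared Euclidean cost. It is an illustration only, assuming that cost; `exact_assignment` is a hypothetical helper, not part of bayesflow, and its factorial runtime limits it to a handful of samples.

```python
import itertools


def squared_distance(p, q):
    # Squared Euclidean distance between two equal-length sequences.
    return sum((a - b) ** 2 for a, b in zip(p, q))


def exact_assignment(x1, x2):
    """Hypothetical brute-force matcher: try every permutation of x2 (tiny inputs only)."""
    assert len(x1) == len(x2), "a pure permutation requires equal sample counts"
    best_perm, best_cost = None, float("inf")
    for perm in itertools.permutations(range(len(x2))):
        # perm[i] is the index of the x2 sample matched to x1[i]
        cost = sum(squared_distance(x1[i], x2[j]) for i, j in enumerate(perm))
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return list(best_perm), best_cost


x1 = [(0.0,), (10.0,), (20.0,)]
x2 = [(20.0,), (0.0,), (10.0,)]
assignments, cost = exact_assignment(x1, x2)
# Reorder x2 so that row i is the partner of x1[i].
x2_matched = [x2[j] for j in assignments]
```

Here the optimal permutation lines each x2 sample up with the x1 sample at the same location, so the total cost is zero.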
- Parameters:
x1 – Tensor of shape (n, …). Samples from the first distribution.
x2 – Tensor of shape (m, …). Samples from the second distribution.
method – Name of the optimal transport method to dispatch to. Default: 'log_sinkhorn'
return_assignments – Whether to also return the assignment indices. Default: False
kwargs – Additional keyword arguments passed on to the selected method.
- Returns:
x1 and x2, as tensors of shapes (n, …) and (m, …), reordered according to the optimal transport assignment. If return_assignments is True, the assignment indices are returned as well.
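The default method name suggests entropy-regularized transport solved with Sinkhorn iterations in log space. The following is a minimal NumPy sketch of that technique, assuming uniform sample weights, a squared Euclidean cost, and a row-wise argmax to turn the transport plan into assignments. `log_sinkhorn_sketch` and its `regularization`, `max_steps`, and `atol` parameters are hypothetical and do not describe the bayesflow implementation. Note that this sketch always returns n matched rows of x2 (partners may repeat), so it does not reproduce the documented (m, …) output shape when n ≠ m.

```python
import numpy as np


def _logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along a single axis.
    x_max = x.max(axis=axis, keepdims=True)
    return np.squeeze(x_max + np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True)), axis=axis)


def log_sinkhorn_sketch(x1, x2, regularization=1.0, max_steps=1000, atol=1e-6, return_assignments=False):
    """Hypothetical log-domain Sinkhorn matcher; not the bayesflow implementation."""
    n, m = x1.shape[0], x2.shape[0]
    # Squared Euclidean cost between flattened samples, shape (n, m).
    diff = x1.reshape(n, 1, -1) - x2.reshape(1, m, -1)
    log_kernel = -(diff ** 2).sum(axis=-1) / regularization
    # Uniform marginals, stored as log-weights.
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    log_u, log_v = np.zeros(n), np.zeros(m)
    for _ in range(max_steps):
        # Alternate scalings: first rows sum to a, then columns sum to b.
        log_u = log_a - _logsumexp(log_kernel + log_v[None, :], axis=1)
        log_v = log_b - _logsumexp(log_kernel + log_u[:, None], axis=0)
        log_plan = log_kernel + log_u[:, None] + log_v[None, :]
        # Columns match b exactly after the v-update, so only rows need checking.
        if np.max(np.abs(np.exp(_logsumexp(log_plan, axis=1)) - np.exp(log_a))) < atol:
            break
    plan = np.exp(log_kernel + log_u[:, None] + log_v[None, :])
    # Give each x1 sample its most probable partner in x2.
    assignments = plan.argmax(axis=1)
    if return_assignments:
        return x1, x2[assignments], assignments
    return x1, x2[assignments]


x1 = np.array([[0.0], [10.0], [20.0]])
x2 = np.array([[20.0], [0.0], [10.0]])
x1_out, x2_out, idx = log_sinkhorn_sketch(x1, x2, return_assignments=True)
```

Working with log-scalings instead of the kernel exp(-cost / regularization) directly keeps the iteration stable when the regularization is small and many kernel entries would underflow to zero. With well-separated samples, as here, the plan is close to a permutation matrix and the argmax recovers the same matching as an exact solver.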