Pytorch Implementation of Wasserstein Distance

INFO

This is a PyTorch implementation of SciPy's scipy.stats.wasserstein_distance.

This function is fully compatible with backpropagation!

The Wasserstein distance, also called the Earth mover’s distance or the optimal transport distance, is a similarity metric between two probability distributions [1]. In the discrete case, the Wasserstein distance can be understood as the cost of an optimal transport plan to convert one distribution into the other. The cost is calculated as the product of the amount of probability mass being moved and the distance it is being moved. A brief and intuitive introduction can be found in [2].
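
For intuition, here are two degenerate cases checked against SciPy (the inputs are illustrative):

from scipy.stats import wasserstein_distance

# All of u's mass sits at 0 and all of v's mass sits at 1:
# one unit of mass moved a distance of 1 costs exactly 1.
print(wasserstein_distance([0.0], [1.0]))  # 1.0

# Moving the same mass twice as far doubles the cost.
print(wasserstein_distance([0.0], [2.0]))  # 2.0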

Parameters

u_values : 2-D tensor of shape (batch_size, n). A sample from a probability distribution or the support (set of all possible values) of a probability distribution, one per batch row. Each element is an observation or possible value.

v_values : 2-D tensor of shape (batch_size, m). A sample from, or the support of, a second distribution, one per batch row.

u_weights, v_weights : 2-D tensor, optional. Weights or counts corresponding to the sample, or probability masses corresponding to the support values; each must have the same shape as the matching values tensor. The sum of elements in each row must be positive and finite. If unspecified, each value is assigned the same weight.

Returns

distance : 1-D tensor of shape (batch_size,). The computed distance between each pair of distributions in the batch.


Notes

Given two 1D probability mass functions, \(u\) and \(v\), the first Wasserstein distance between the distributions is:

\[l_1 (u, v) = \inf_{\pi \in \Gamma (u, v)} \int_{\mathbb{R} \times \mathbb{R}} |x-y| \mathrm{d} \pi (x, y)\]

where \(\Gamma (u, v)\) is the set of (probability) distributions on \(\mathbb{R} \times \mathbb{R}\) whose marginals are \(u\) and \(v\) on the first and second factors respectively. For a given value \(x\), \(u(x)\) gives the probability of \(u\) at position \(x\), and the same for \(v(x)\).
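
For example, if \(u\) places half of its mass at 0 and half at 1, while \(v\) places all of its mass at 1, the optimal plan only has to move the mass sitting at 0:

\[l_1\left(\tfrac{1}{2}\delta_0 + \tfrac{1}{2}\delta_1,\; \delta_1\right) = \tfrac{1}{2} \cdot |0 - 1| = \tfrac{1}{2}\]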

If \(U\) and \(V\) are the respective CDFs of \(u\) and \(v\), this distance also equals:

\[l_1(u, v) = \int_{-\infty}^{+\infty} |U(x) - V(x)| \, \mathrm{d}x\]

See [3] for a proof of the equivalence of both definitions.

The input distributions can be empirical, coming from samples whose values are the inputs of the function; or they can be seen as generalized functions, in which case they are weighted sums of Dirac delta functions located at the specified values.
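
To make the CDF formulation concrete, here is a minimal NumPy sketch for the unweighted 1-D case (the function name and structure are mine; it mirrors the approach taken by SciPy and by the PyTorch code below):

import numpy as np

def wasserstein_1d(u_values, v_values):
    # Pool all values to obtain the endpoints of the integration intervals.
    u_sorted = np.sort(u_values)
    v_sorted = np.sort(v_values)
    all_values = np.sort(np.concatenate([u_values, v_values]))
    deltas = np.diff(all_values)  # widths of the intervals
    # Empirical CDFs evaluated at the left endpoint of each interval.
    u_cdf = np.searchsorted(u_sorted, all_values[:-1], side='right') / len(u_values)
    v_cdf = np.searchsorted(v_sorted, all_values[:-1], side='right') / len(v_values)
    # Integrate |U - V|: the CDFs are piecewise constant between endpoints.
    return np.sum(np.abs(u_cdf - v_cdf) * deltas)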


Code

import torch

def torch_wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None):
    # Ensure that the input tensors are batched
    assert u_values.dim() == 2 and v_values.dim() == 2, "Input tensors must be 2-dimensional (batch_size, num_values)"

    batch_size, u_size = u_values.shape
    _, v_size = v_values.shape

    # Sort the values of each distribution
    u_sorter = torch.argsort(u_values, dim=1)
    v_sorter = torch.argsort(v_values, dim=1)

    # Concatenate and sort all values for each batch
    all_values = torch.cat((u_values, v_values), dim=1)
    all_values, _ = torch.sort(all_values, dim=1)
    # Compute differences between successive values (widths of the integration intervals)
    deltas = torch.diff(all_values, dim=1)

    # Get the respective positions of the values of u and v among the values of both distributions.
    # The left endpoints must be contiguous for torch.searchsorted.
    all_values_left = all_values[:, :-1].contiguous()
    u_cdf_indices = torch.searchsorted(u_values.gather(1, u_sorter).contiguous(), all_values_left, right=True)
    v_cdf_indices = torch.searchsorted(v_values.gather(1, v_sorter).contiguous(), all_values_left, right=True)

    # Calculate the CDFs of u and v using their weights, if specified
    if u_weights is None:
        u_cdf = u_cdf_indices.float() / u_size
    else:
        zeros = torch.zeros((batch_size, 1), dtype=u_weights.dtype, device=u_weights.device)
        u_sorted_cumweights = torch.cat((zeros, torch.cumsum(u_weights.gather(1, u_sorter), dim=1)), dim=1)
        u_cdf = u_sorted_cumweights.gather(1, u_cdf_indices) / u_sorted_cumweights[:, -1].unsqueeze(1)

    if v_weights is None:
        v_cdf = v_cdf_indices.float() / v_size
    else:
        zeros = torch.zeros((batch_size, 1), dtype=v_weights.dtype, device=v_weights.device)
        v_sorted_cumweights = torch.cat((zeros, torch.cumsum(v_weights.gather(1, v_sorter), dim=1)), dim=1)
        v_cdf = v_sorted_cumweights.gather(1, v_cdf_indices) / v_sorted_cumweights[:, -1].unsqueeze(1)

    # Integrate |U - V| over the pooled support, one distance per batch row
    return torch.sum(torch.abs(u_cdf - v_cdf) * deltas, dim=1)
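
The function supports batches in which u and v contain different numbers of values, and optional per-row weights that do not need to be normalized (the code divides by the total weight). A small illustrative sketch with made-up tensors:

u = torch.rand(4, 5)          # batch of 4 distributions, 5 values each
v = torch.rand(4, 8)          # the second distribution may have a different size
u_weights = torch.rand(4, 5)  # optional; normalization is handled internally
v_weights = torch.rand(4, 8)
d = torch_wasserstein_distance(u, v, u_weights, v_weights)
print(d.shape)  # torch.Size([4]) -- one distance per batch row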

Test

import numpy as np
from scipy.stats import wasserstein_distance

a = np.random.rand(1, 5)
b = np.random.rand(1, 5)
sci_wd = wasserstein_distance(a[0], b[0])
torch_a = torch.tensor(a, dtype=torch.float32, requires_grad=True)
torch_b = torch.tensor(b, dtype=torch.float32, requires_grad=True)
torch_wd = torch_wasserstein_distance(torch_a, torch_b)
print(f"a: {a}")
print(f"b: {b}")
print(f"SciPy Wasserstein distance: {sci_wd}")
print(f"Torch Wasserstein distance: {torch_wd}")
print(f"Error between SciPy and Torch: {np.abs(sci_wd - torch_wd.item())}")

output:

a: [[0.26129606 0.23081598 0.78036461 0.33665042 0.11398108]]
b: [[0.37317387 0.907402   0.36485701 0.63183274 0.40284714]]
SciPy Wasserstein distance: 0.19140092551231483
Torch Wasserstein distance: tensor([0.1914], grad_fn=<SumBackward1>)
Error between SciPy and Torch: 2.5029183420288703e-08
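
Since the selling point is backpropagation compatibility, it is worth checking that gradients actually reach the inputs. Reusing torch_a, torch_b, and torch_wd from the test above:

torch_wd.sum().backward()
print(torch_a.grad)  # populated: gradients flow through the sorted-values/diff path
print(torch_b.grad)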

References

[1] “Wasserstein metric”, https://en.wikipedia.org/wiki/Wasserstein_metric

[2] Lilian Weng, “What is Wasserstein distance?”, Lil’Log, https://lilianweng.github.io/posts/2017-08-20-gan/#what-is-wasserstein-distance.

[3] Ramdas, Garcia, Cuturi, “On Wasserstein Two Sample Testing and Related Families of Nonparametric Tests” (2015). arXiv:1509.02237.




If you found this useful, please cite this as:
Xiao, Zhihua (Jul 2024). Pytorch Implementation of Wasserstein Distance. https://forkxz.github.io/blog/2024/Wasserstein/.

or as a BibTeX entry:

@article{xiao2024pytorch-implementation-of-wasserstein-distance,
  title   = {Pytorch Implementation of Wasserstein Distance},
  author  = {Xiao, Zhihua},
  year    = {2024},
  month   = {Jul},
  url     = {https://forkxz.github.io/blog/2024/Wasserstein/}
}


