Wasserstein distance 구현하기
Contents
Wasserstein distance 구현하기#
이름부터 무시무시한 Wassertstein distance. 이름만 무서우면 다행이지만 이 녀석의 정의 또한 무섭다. Wasserstein distance는 두 확률 분포 사이의 거리를 측정할 때 사용된다. 그런데 우리가 실제로 갖고 있는 것은 어느 확률 분포에서 샘플링된지 모르는 데이터셋이며, 두 데이터셋 사이의 Wasserstein distance를 구하는 것이 목표이다. 이번 포스팅에서는 확률 분포가 아닌 두 데이터셋이 주어졌을 때 Wasserstein distance를 계산하는 방법에 대해서 알아본다.
포스팅은 다음과 같이 구성되어 있다. 먼저, Wasserstein distance의 정의를 살펴본다. 다음으로 1차원 데이터에 한정하여 정의에 해당하는 값을 해석적으로 구해볼 것이다. 그리고 그 값을 코드로 구현하여 거리를 직접 계산해볼 것이다.
-Wasserstein distance의 정의#
먼저, Wasserstein distance의 정의에 대해서 알아보자. 위키피디아의 정의를 그대로 따왔다.
Let
where
위의 정의를 온전히 이해하기 위해서는 수학과 대학원 확률론 지식 또는 최소 대학원 해석학 지식이 필요하다. 물론 나는 없다. 그래도 공부한 것을 바탕으로 정의를 읽어보자면 다음과 같다 (어지러우면 다음으로 문단으로 넘어가도 좋다).
Wasserstein distance는 두 probability measure 뮤
와 누 에 대해서 정의되는 것이다. Probability measure란 각 사건에 확률을 부여하는 함수 또는 규칙이며, 그냥 probability distribution이라고 생각해도 좋다. 조금 더 과감히 말하면 probability density function 또는 probability mass function으로 생각해도 좋다.두 Probability measure의 sample space에 있는 원소들을 짝짓는 모든 coupling
중에 의 최소값을 찾으면 그것의 제곱이 -Wasserstein distance이다. 여기서 coupling을 joint distribution으로 이해해도 좋다.Wasserstein distance의 정의를 Earth mover’s distance (흙 옮기는 기계)나 optimal transport problem (최적 운송 문제) 관점으로 해석할 수 있다. 한 확률 분포에 있는 mass를 최소의 비용으로 다른 확률 분포로 운송하는 것이다. 한 확률 분포의
위치에 있는 mass를 다른 확률 분포의 위치로 옮길 때 발생하는 비용 의 기댓값이 최소로 되도록 coupling을 찾는 것이다.
1차원 sample space 가정 등 여러 가지를 가정하면, 다음과 같이 우리에게 익숙한 용어로 정의를 다시 적어줄 수 있다.
For any two real-valued random variables
where
확률 용어에 익숙한 분들이라면 위 단순화된 정의는 쉽게 이해될 것이다.
확률 변수
1차원 데이터일 때 정의에 대한 analytic solution#
정의에 따라 Wasserstein distance를 계산한다면 이 세상의 모든 joint distribution
For any two real-valued random variables
where
더 이상 infimum을 구할 필요 없이 각 확률 변수의 CDF의 역함수
Note
참고로 CDF
The general quantile function
쉽게 말하면, CDF 값이 유지되는 구간에서는 가장 맨처음 값을 역함수 값으로 설정하겠다는 것이다.
Wasserstein distance 정의의 analytic solution이 식 (3)이 되는 것에 대한 증명은 이 글 가장 마지막에 남겨 놓을 예정이다.
식 (3) 덕분에 1차원 데이터에 대해서는 CDF를 찾아서 적분만 해주면 되게 된다…! 문제가 훨씬 쉬워졌지만, 여전히 실제 데이터에 대해 Wasserstein distance를 계산하기에는 다음 2가지 문제점이 있다.
우리가 갖고 있는 것은 확률 분포가 아니라 데이터인데 어떻게 CDF를 찾아야 할까?
CDF를 찾았다고 해도 적분은 어떻게 할까?
Wasserstein distance 구현하기#
먼저, 우리가 갖고 있는 데이터의 CDF를 찾는 것은 쉽다. 우리는 흔히 데이터의 분포를 보기 위해 histogram을 그려본다. 이 histogram을 데이터의 확률 분포로 보는 것이다. 데이터가 있는 부분에 상대빈도수만큼 확률을 부여하는 measure를 empirical measure라고 부른다. 직관적으로도 이해가 되기 때문에 정의를 알 필요는 없지만, 굳이 적어보면 다음과 같다.
Empirical measure
Let
where
그럼, 주어진 1차원 데이터들의 CDF는 어떻게 계산할까? 그냥 데이터를 오름차순으로 정렬하고, 순서대로 상대빈도수를 계속 더해나가면 된다. 이쯤에서 코드 구현에 사용할 예제 데이터를 보자. x는 0, 2, 4, 6, 10에 노이즈를 추가한 것이고, y는 1, 3, 5, 7, 9에 노이즈를 추가한 것이다. 그리고 순서를 뒤죽박죽 섞어주었다.
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)
np.set_printoptions(precision=2)
x = np.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0]) + np.random.rand(6)
y = np.array([1.0, 3.0, 5.0, 7.0, 9.0]) + np.random.rand(5)
np.random.shuffle(x)
np.random.shuffle(y)
print(x)
print(y)
[10.65 4.6 6.54 8.42 2.72 0.55]
[5.96 7.38 9.79 3.89 1.44]
x
와 y
에 대한 CDF를 구하는 방법은 굉장히 쉽다. 먼저 정렬하고, 순서대로
x_sorted = np.sort(x)
cum_x = (np.arange(len(x)) + 1) / len(x)
print(x_sorted)
print(cum_x)
[ 0.55 2.72 4.6 6.54 8.42 10.65]
[0.17 0.33 0.5 0.67 0.83 1. ]
y_sorted = np.sort(y)
cum_y = (np.arange(len(y)) + 1) / len(y)
print(y_sorted)
print(cum_y)
[1.44 3.89 5.96 7.38 9.79]
[0.2 0.4 0.6 0.8 1. ]
아직 끝난 것은 아니다. 우리가 갖고 있는 것은 여전히
plt.figure()
plt.plot(x_sorted, cum_x, 'bo')
plt.plot(y_sorted, cum_y, 'rx')
plt.grid()
plt.xlim(-0.5, 10.9)
plt.ylim(-0.03, 1.03)
plt.show()

하지만 우리가 원하는 CDF는 다음 그림과 같다.

우리의 목표는 두 CDF 사이의 차이가 발생하는 영역의 넓이를 계산하는 것이다. 그런데 위 그림을 보니, 영역들이 모두 사각형인 것을 확인할 수 있다.
데이터가 주어질 경우, CDF 사이의 차이가 발생하는 영역은 사각형이다.
사각형의 넓이를 더 쉽게 구하기 위해서는
x
데이터와y
데이터를 모두 합쳐서 보는 것이 좋다.사각형의 가로 길이는
all_values[i+1] - all_values[i]
로 쉽게 구할 수 있다.
먼저 x
와 y
를 합쳐놓자.
all_values = np.concatenate((x, y))
all_values.sort()
print(all_values)
[ 0.55 1.44 2.72 3.89 4.6 5.96 6.54 7.38 8.42 9.79 10.65]
각 사각형의 가로 길이는 np.diff()
함수를 통해서 쉽게 구할 수 있다.
deltas = np.diff(all_values)
deltas
array([0.89, 1.28, 1.18, 0.71, 1.36, 0.58, 0.84, 1.04, 1.37, 0.85])
하지만, 이렇게 되면 어떤 점이 x
에서 왔는지 y
에서 왔는지 모른다. a.searchsorted(v)
메서드는 v
배열의 각 원소가 정렬되어 있는 a
배열에서 크기 순서상 어느 인덱스에 위치해야 하는지 알려준다. 즉, 아래 코드에서 x_cdf_indices
는 all_values
배열의 각 값이 x_sorted
배열에서 어디에 위치해야 하는지 적어놓은 배열이다.
all_values
에서 x
배열에서 온 데이터를 마주칠 때마다 +1을 해준 배열이 된다.
x_cdf_indices = x_sorted.searchsorted(all_values[:-1], 'right')
x_cdf_indices
array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
y_cdf_indices = y_sorted.searchsorted(all_values[:-1], 'right')
y_cdf_indices
array([0, 1, 1, 2, 2, 3, 3, 4, 4, 5])
CDF는 다음과 같이 구할 수 있다.
x_cdf = x_cdf_indices / len(x)
y_cdf = y_cdf_indices / len(y)
print(all_values[:-1])
print(x_cdf)
print(y_cdf)
[0.55 1.44 2.72 3.89 4.6 5.96 6.54 7.38 8.42 9.79]
[0.17 0.17 0.33 0.33 0.5 0.5 0.67 0.67 0.83 0.83]
[0. 0.2 0.2 0.4 0.4 0.6 0.6 0.8 0.8 1. ]
계산한 가로 길이 (deltas
)와 CDFs (x_cdf
, y_cdf
)를 사용해서
p = 1
if p == 1:
d = np.sum(np.multiply(np.abs(x_cdf - y_cdf), deltas))
elif p == 2:
d = np.sqrt(np.sum(np.multiply(np.square(x_cdf - y_cdf), deltas)))
else:
d = np.power(np.sum(np.multiply(np.power(np.abs(x_cdf - y_cdf), p),
deltas)), 1/p)
print(d)
0.9717676936051363
지금까지의 과정을 모두 합쳐서 함수로 만들어 보면 다음과 같다. 이는 실제로 SciPy에 있는 wasserstein_distance
함수의 간략화된 버전이다.
def _cdf_distance(p, x, y):
"""
From https://github.com/scipy/scipy/blob/v1.10.1/scipy/stats/_stats_py.py#L9165
"""
x_sorted = np.sort(x)
y_sorted = np.sort(y)
all_values = np.concatenate((x, y))
all_values.sort(kind='mergesort')
# Compute the differences between pairs of successive values of u and v.
deltas = np.diff(all_values)
# Get the respective positions of the values of u and v among the values of
# both distributions.
x_cdf_indices = x_sorted.searchsorted(all_values[:-1], 'right')
y_cdf_indices = y_sorted.searchsorted(all_values[:-1], 'right')
# Calculate the CDFs of u and v using their weights, if specified.
x_cdf = x_cdf_indices / x.size
y_cdf = y_cdf_indices / y.size
# Compute the value of the integral based on the CDFs.
# If p = 1 or p = 2, we avoid using np.power, which introduces an overhead
# of about 15%.
if p == 1:
return np.sum(np.multiply(np.abs(x_cdf - y_cdf), deltas))
if p == 2:
return np.sqrt(np.sum(np.multiply(np.square(x_cdf - y_cdf), deltas)))
return np.power(np.sum(np.multiply(np.power(np.abs(x_cdf - y_cdf), p),
deltas)), 1/p)
from scipy.stats import wasserstein_distance
print(wasserstein_distance(x, y))
0.9717676936051363
이번 포스팅에서는 1차원 데이터에 대해서 두 데이터셋의
1차원을 넘어 고차원에 대해서도 어떻게 구현되는지 알아보았는데, 실제로 최적화 문제를 풀게 된다.
한 데이터셋
식 (3) 증명#
Coming soon!
참고문헌#
[1] https://en.wikipedia.org/wiki/Wasserstein_metric
[2] Marc G. Bellemare and Will Dabney and Mark Rowland, Distributional Reinforcement Learning , MIT Press, 2022. https://www.distributional-rl.org/contents/chapter4.html
[3] Ramdas A, Trillos NG, Cuturi M. On Wasserstein Two-Sample Testing and Related Families of Nonparametric Tests. Entropy. 2017; 19(2):47. https://doi.org/10.3390/e19020047