Optimal Transport

Sinkhorn Algorithm

Computation
P(X)
Author

司馬 博文

Published

3/13/2024

The goal of this practical session is to introduce computational optimal transport (OT) in Python. You will familiarize yourself with OT by: 1. Computing “exact” unregularized optimal transport, using the Python library POT (Python Optimal Transport). 2. Computing entropic optimal transport, first with your own version of the Sinkhorn algorithm, then with the Python library OTT-JAX.

To keep the notebook readable, the plotting functions that are used several times are gathered in the section below.

Code
def plot_weighted_points(
    ax,
    x, a,
    y, b,
    title=None, x_label=None, y_label=None
):
  ax.scatter(x[:,0], x[:,1], s=5000*a, c='r', edgecolors='k', label=x_label)
  ax.scatter(y[:,0], y[:,1], s=5000*b, c='b', edgecolors='k', label=y_label)
  for i in range(np.shape(x)[0]):
      ax.annotate(str(i+1), (x[i,0], x[i,1]),fontsize=30,color='black')
  for i in range(np.shape(y)[0]):
      ax.annotate(str(i+1), (y[i,0], y[i,1]),fontsize=30,color='black')
  if x_label is not None or y_label is not None:
    ax.legend(fontsize=20)
  ax.axis('off')
  ax.set_title(title, fontsize=25)

def plot_assignement(
    ax,
    x, a,
    y, b,
    optimal_plan,
    title=None, x_label=None, y_label=None
):
  plot_weighted_points(
    ax=ax,
    x=x, a=a,
    y=y, b=b,
    title=None,
    x_label=x_label, y_label=y_label
  )
  for i in range(optimal_plan.shape[0]):
      for j in range(optimal_plan.shape[1]):
          ax.plot([x[i,0], y[j,0]], [x[i,1], y[j,1]], c='k', lw=30*optimal_plan[i,j], alpha=0.8)
  ax.axis('off')
  ax.set_title(title, fontsize=30)

def plot_assignement_1D(
    ax,
    x, y,
    title=None
):
  plot_points_1D(
    ax,
    x, y,
    title=None
  )
  x_sorted = np.sort(x)
  y_sorted = np.sort(y)
  assert len(x) == len(y), "x and y must have the same shape."
  for i in range(len(x)):
    ax.hlines(
        y=0,
        xmin=min(x_sorted[i], y_sorted[i]),
        xmax=max(x_sorted[i], y_sorted[i]),
        color='k',
        lw=10
    )
  ax.axis('off')
  ax.set_title(title, fontsize=30)

def plot_points_1D(
    ax,
    x, y,
    title=None
):
  n = len(x)
  a = np.ones(n) / n
  ax.scatter(x, np.zeros(n), s=1000*a, c='r')
  ax.scatter(y, np.zeros(n), s=1000*a, c='b')  # y shares the uniform weights a
  min_val = min(np.min(x), np.min(y))
  max_val = max(np.max(x), np.max(y))
  for i in range(n):
      ax.annotate(str(i+1), xy=(x[i], 0.005), size=30, color='r', ha='center')
  for j in range(n):
      ax.annotate(str(j+1), xy=(y[j], 0.005), size=30, color='b', ha='center')
  ax.axis('off')
  ax.plot(np.linspace(min_val, max_val, 10), np.zeros(10))
  ax.set_title(title, fontsize=30)

def plot_consistency(
    ax,
    reg_strengths,
    plan_diff, distance_diff
):
  ax[0].loglog(reg_strengths, plan_diff, lw=4)
  ax[0].set_ylabel(r'$||P^* - P_\epsilon^*||_F$', fontsize=25)
  ax[0].tick_params(which='both', size=20)
  ax[0].grid(ls='--')
  ax[1].loglog(reg_strengths, distance_diff, lw=4)
  ax[1].set_xlabel(r'Regularization Strength $\epsilon$', fontsize=25)
  ax[1].set_ylabel(r'$ 100 \cdot \frac{\langle C, P^*_\epsilon \rangle - \langle C, P^* \rangle}{\langle C, P^* \rangle} $', fontsize=25)
  ax[1].tick_params(which='both', size=20)
  ax[1].grid(ls='--')

1 I: Exact Optimal Transport with POT

1.1 I.1 Reminders on Discrete Optimal Transport

Optimal Transport is a theory that allows us to compare two (weighted) point clouds \((x, a)\) and \((y, b)\), where \(x \in \mathbb{R}^{n \times d}\) and \(y \in \mathbb{R}^{m \times d}\) are the locations of the \(n\) (resp. \(m\)) points in dimension \(d\), and \(a \in \mathbb{R}^n\), \(b \in \mathbb{R}^m\) are the weights. We ask that the total weights sum to one, i.e. \(\sum_{i=1}^n a_i = \sum_{j=1}^m b_j = 1\).

The basic idea of Optimal Transport is to “transport” the mass located at points \(x\) to the mass located at points \(y\).

Let us denote by \(U(a,b) := \left\{ P \in \mathbb{R}^{n \times m} \,|\, P \geq 0, \sum_{j=1}^m P_{ij} = a_i, \sum_{i=1}^n P_{ij} = b_j\right\}\) the set of admissible transport plans.

If \(P \in U(a,b)\), the quantity \(P_{ij} \geq 0\) should be regarded as the mass transported from point \(x_i\) to point \(y_j\). For this reason, it is called a transport plan.

We will also consider a cost function \(c : \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}\) and the associated cost matrix \(C = [c(x_i, y_j)]_{1 \leq i \leq n, \, 1 \leq j \leq m} \in \mathbb{R}^{n \times m}\), containing the pairwise costs between the points of each point cloud \(x\) and \(y\). The quantity \(C_{ij}\) should be regarded as the cost paid for transporting one unit of mass from \(x_i\) to \(y_j\). This cost is usually computed using the positions \(x_i\) and \(y_j\), for example \(C_{ij} = \|x_i - y_j\|_2\) or \(C_{ij} = \|x_i - y_j\|_2^2\), but may be more exotic in some cases.

Then transporting mass according to \(P \in U(a,b)\) has a total cost of \(\sum_{i=1}^n \sum_{j=1}^m P_{ij} C_{ij}\).
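
As a small illustration of these definitions, the sketch below builds an admissible plan and evaluates its total cost. The toy points and weights, and the helper name `is_admissible`, are assumptions made up for this example. It relies on the fact that the independent coupling \(a b^T\) always belongs to \(U(a,b)\).

```python
import numpy as np

# Hypothetical toy data: 3 source points and 2 target points in the plane.
x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y = np.array([[1.0, 1.0], [2.0, 0.0]])
a = np.array([0.5, 0.25, 0.25])   # source weights, sum to 1
b = np.array([0.6, 0.4])          # target weights, sum to 1

# Euclidean cost matrix C[i, j] = ||x_i - y_j||_2, of shape (3, 2).
C = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

# The independent coupling P = a b^T is always an admissible plan.
P = np.outer(a, b)

def is_admissible(P, a, b, tol=1e-12):
    """Check P >= 0 and that the row/column sums match a and b."""
    return bool(
        np.all(P >= 0)
        and np.allclose(P.sum(axis=1), a, atol=tol)
        and np.allclose(P.sum(axis=0), b, atol=tol)
    )

total_cost = float(np.sum(P * C))  # total cost <C, P> of this plan
print(is_admissible(P, a, b), total_cost)
```

This plan is admissible but in general not optimal: it only gives an upper bound on the optimal transport cost.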

In “Optimal Transport”, there is the word Optimal. Indeed, we want to find a transport plan \(P \in U(a,b)\) that minimizes the total cost. In other words, we want to solve \[ \min_{P \in U(a,b)} \sum_{i=1}^n \sum_{j=1}^m C_{ij} P_{ij} = \min_{P \in U(a,b)} \langle C, P \rangle. \]

This problem is a Linear Program: the objective function is linear in the variable \(P\), and the constraints are linear in \(P\). We can thus solve this problem using classical Linear Programming algorithms, such as the simplex algorithm.
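
To make the linear-programming viewpoint concrete, here is a minimal sketch for the special case \(n = m\) with uniform weights \(1/n\). In that case the vertices of \(U(a,b)\) are the permutation matrices scaled by \(1/n\) (Birkhoff-von Neumann theorem), and a linear program attains its minimum at a vertex, so for very small \(n\) one can simply enumerate all permutations. The function name `brute_force_assignment` and the cost matrix are hypothetical, and this enumeration is only an illustration, not the simplex algorithm.

```python
import itertools
import numpy as np

def brute_force_assignment(C):
    """Return (best permutation, optimal cost) for uniform weights 1/n."""
    n = C.shape[0]
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(n)):
        # Cost of the plan that sends all the mass of x_i to y_{perm[i]}.
        cost = sum(C[i, perm[i]] for i in range(n)) / n
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return best_perm, best_cost

# Hypothetical 3 x 3 cost matrix.
C = np.array([[4.0, 1.0, 3.0],
              [2.0, 0.0, 5.0],
              [3.0, 2.0, 2.0]])
perm, cost = brute_force_assignment(C)
print(perm, cost)
```

The enumeration visits \(n!\) candidates, which is why dedicated solvers are used as soon as \(n\) grows.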

If \(P^*\) is a solution to the Optimal Transport problem, we will say that \(P^*\) is an optimal transport plan between \((x, a)\) and \((y, b)\), and that \(\sum_{ij} P^*_{ij} C_{ij}\) is the optimal transport distance between \((x, a)\) and \((y, b)\): it is the minimal amount of “energy” that is necessary to transport the initial mass located at points \(x\) to the target mass located at points \(y\).

Usually, we represent the weighted point clouds by probability measures \(\mu = \sum_{i=1}^n a_i \delta_{x_i}\) and \(\nu = \sum_{j=1}^m b_j \delta_{y_j}\). Solving the above problem, we then say that we solve the optimal transport problem between the measures \(\mu\) and \(\nu\). Moreover, we note: \[ W_c(\mu, \nu) = \min_{P \in U(a,b)} \langle C, P \rangle. \]

1.2 I.2 Computing Optimal “Croissant” Transport

1.2.1 Install

First, you need to install a few packages:

Code
%pip install POT
%pip install cloudpickle
Defaulting to user installation because normal site-packages is not writeable
Requirement already satisfied: POT in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (0.9.3)
Requirement already satisfied: numpy>=1.16 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from POT) (1.26.1)
Requirement already satisfied: scipy>=1.6 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from POT) (1.11.3)
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Requirement already satisfied: cloudpickle in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (3.0.0)
Note: you may need to restart the kernel to use updated packages.

Then, load the required packages.

Code
import ot
import numpy as np
import os
from typing import Callable
import matplotlib.pyplot as plt

Finally, connect the notebook to your drive to load some data that will be used for the experiments.

1.2.2 Formalization of the problem

We will solve the Bakeries/Cafés problem of transporting croissants from a number of Bakeries to Cafés.

We use fictional positions, production and sale numbers. We impose that the total croissant production is equal to the number of croissants sold, so that Bakeries and Cafés can be represented as measures with the same total mass. Then, up to normalization, they can be processed as probability measures.

Mathematically, we have access to the positions of the \(n\) Bakeries as points in \(\mathbb{R}^2\) via \(x \in \mathbb{R}^{n \times 2}\) and their respective production via \(a \in \mathbb{R}^n\), which together describe the source point cloud. The Cafés where the croissants are sold are likewise defined by their positions \(y \in \mathbb{R}^{m \times 2}\) and the quantity of croissants sold \(b \in \mathbb{R}^{m}\).

Afterwards, the Bakeries are represented by the probability measure \(\mu = \sum_{i=1}^n a_i \delta_{x_i}\) and the Cafés by \(\nu = \sum_{j=1}^m b_j \delta_{y_j}\). Calculating the optimal assignment of the croissants delivered by the Bakeries to the Cafés amounts to calculating the optimal transport between the probability measures \(\mu\) and \(\nu\).

Let’s download the data and check that the total croissant production is equal to the number of croissants sold.

Code
# Load the data
import pickle
from urllib.request import urlopen
import cloudpickle as cp

croissants = cp.load(urlopen('https://marcocuturi.net/data/croissants.pickle'))

bakery_pos = croissants['bakery_pos']
bakery_prod = croissants['bakery_prod']
cafe_pos = croissants['cafe_pos']
cafe_prod = croissants['cafe_prod']

print('Bakery productions =', bakery_prod)
print('Total number of croissants =', bakery_prod.sum())
print("")
print('Café sales =', cafe_prod)
print('Total number of croissants sold =', cafe_prod.sum())
Bakery productions = [31. 48. 82. 30. 40. 48. 89. 73.]
Total number of croissants = 441.0

Café sales = [82. 88. 92. 88. 91.]
Total number of croissants sold = 441.0

We now normalize the weight vectors \(a\) and \(b\), i.e. the production and the sales, to deal with probability measures.

Code
bakery_prod = bakery_prod / bakery_prod.sum()
cafe_prod = cafe_prod / cafe_prod.sum()

Then, we plot the probability measures (the weighted point clouds) in \(\mathbb{R}^2\).

Code
fig, ax = plt.subplots(figsize=(10, 8))
plot_weighted_points(
    ax,
    x=bakery_pos,
    a=bakery_prod,
    x_label="Bakeries",
    y=cafe_pos,
    y_label="Cafés",
    b=cafe_prod,
    title="Bakeries and Cafés"
)
plt.show()

1.2.3 Solving the problem

To compute the optimal transport, we will consider three different costs:

  • \(\ell_1\): \(c(x, y) = \|x - y\|_1\) , (Manhattan distance)
  • \(\ell_2\): \(c(x, y) = \|x - y\|_2\), (Euclidean distance)
  • \(\ell_2^2\): \(c(x, y) = \|x - y\|_2^2\) (Squared-Euclidean distance)

Note that we expect different optimal transport plans for different costs.


Question:

  • Complete the following function that computes a cost matrix \(C\) from two sets of points \(x, y\) and a cost function \(c\). Compute the three cost matrices \(C_{\ell_1}, C_{\ell_2}, C_{\ell_2^2}\in \mathbb{R}^{n \times m}\) using that function.
  • What cost should be used to minimize the total distance traveled by the driver that delivers croissants from Bakeries to Cafés?

Answer:

Code
bakery_pos
array([[184.86464733, 201.8163543 ],
       [449.3486663 , 168.40784664],
       [245.41756746, 288.12166576],
       [273.95400109, 364.68282915],
       [494.58935376, 336.8424061 ],
       [738.19305545, 238.70491485],
       [736.10502372, 375.12298779],
       [537.74200949, 482.30861653]])
Code
cafe_pos
array([[302.08410452, 442.78633642],
       [345.1162221 , 368.52123027],
       [449.226184  , 201.94529124],
       [454.08464888, 387.95508982],
       [627.60125204, 408.7770822 ]])
Code
def get_cost_matrix(
    x: np.ndarray,
    y: np.ndarray,
    cost_fn: Callable
) -> np.ndarray:
  """
  Compute the pairwise cost matrix between the n points in ``x`` and the m points in ``y``.
  It should output a matrix of size n x m.
  """
  return np.array([cost_fn(x_,y_) for x_ in x for y_ in y]).reshape(x.shape[0],y.shape[0])


# compute cost matrices for different costs
C_l1 = get_cost_matrix(
    x=bakery_pos, y=cafe_pos,
    cost_fn= lambda x,y : sum(np.abs(x-y))
  )

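# Euclidean distance: square root of the sum of squared differences
C_l2 = get_cost_matrix(
    x=bakery_pos, y=cafe_pos,
    cost_fn= lambda x,y : np.sqrt(sum((x-y)**2))
)
# squared Euclidean distance: sum of squared differences
C_l2_sq = get_cost_matrix(
    x=bakery_pos, y=cafe_pos,
    cost_fn= lambda x,y : sum((x-y)**2)
)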

# print shapes of cost matrices
print(
    f"Shape of C_l1: {C_l1.shape}\n"
    f"Shape of C_l2: {C_l2.shape}\n"
    f"Shape of C_l2_sq: {C_l2_sq.shape}"
)
Shape of C_l1: (8, 5)
Shape of C_l2: (8, 5)
Shape of C_l2_sq: (8, 5)
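
As a possible alternative to the double loop, the three cost matrices can be obtained with NumPy broadcasting. This is only a sketch: the function name `cost_matrices` and the toy points are assumptions introduced for this example, and the function covers just the three costs listed above rather than an arbitrary `cost_fn`.

```python
import numpy as np

def cost_matrices(x, y):
    """Return the l1, l2 and squared-l2 cost matrices, each of shape (n, m)."""
    diff = x[:, None, :] - y[None, :, :]      # pairwise differences, shape (n, m, d)
    c_l1 = np.abs(diff).sum(axis=-1)          # Manhattan distance
    c_l2_sq = (diff ** 2).sum(axis=-1)        # squared Euclidean distance
    c_l2 = np.sqrt(c_l2_sq)                   # Euclidean distance
    return c_l1, c_l2, c_l2_sq

# Hypothetical toy points: 2 sources and 3 targets in the plane.
toy_x = np.array([[0.0, 0.0], [3.0, 4.0]])
toy_y = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 0.0]])
toy_l1, toy_l2, toy_l2_sq = cost_matrices(toy_x, toy_y)
print(toy_l1.shape, toy_l2[1, 0], toy_l2_sq[1, 0])
```

A useful sanity check is that the \(\ell_2\) cost never exceeds the \(\ell_1\) cost, and that squaring the \(\ell_2\) matrix entrywise gives the \(\ell_2^2\) matrix.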

We can now compute the Optimal Transport plan to transport the croissants from the bakeries to the cafés, for the three different costs.


Question:

  • Complete the following function that takes as input the cost matrix \(C\) and the weight vectors \(a\) and \(b\) and outputs the optimal transport plan and the optimal transport cost using the ot.emd function. It has an option to display the results.
  • Use that function to compute and display the optimal plan and the optimal cost for \(\ell_1, \ell_2\) and \(\ell_2^2\) geometries.

Remark: See https://pythonot.github.io/ for information on the ot.emd function.

Answer:

Code
def compute_transport(
    C: np.ndarray,
    a: np.ndarray,
    b: np.ndarray,
    verbose: bool = False,
):
  """
  Compute the optimal transport plan and the optimal transport cost
  for cost matrix ``C`` and weight vectors $a$ and $b$.
  If ``verbose`` is set to True, it displays the results.
  """
  optimal_plan = ot.emd(a,b,C)
  optimal_cost = np.sum(optimal_plan * C)
  if verbose:
    print(
        f"optimal transport plan: \n{optimal_plan}"
    )
    print(
        f"transport cost: {optimal_cost}"
    )
  return optimal_plan, optimal_cost
Code
# l1 geometry
print("l1 geometry:")
optimal_plan_l1_croissant, optimal_cost_l1_croissant = compute_transport(
    C=C_l1,
    a=bakery_prod,
    b=cafe_prod,
    verbose=True
)
l1 geometry:
optimal transport plan: 
[[0.07029478 0.         0.         0.         0.        ]
 [0.         0.         0.10884354 0.         0.        ]
 [0.05442177 0.13151927 0.         0.         0.        ]
 [0.         0.06802721 0.         0.         0.        ]
 [0.         0.         0.         0.09070295 0.        ]
 [0.         0.         0.09977324 0.00453515 0.00453515]
 [0.         0.         0.         0.         0.20181406]
 [0.06122449 0.         0.         0.10430839 0.        ]]
transport cost: 177.28420815406028
Code
# l2 geometry
print("l2 geometry:")
optimal_plan_l2_croissant, optimal_cost_l2_croissant = compute_transport(
    C=C_l2,
    a=bakery_prod,
    b=cafe_prod,
    verbose=True
)
l2 geometry:
optimal transport plan: 
[[0.         0.07029478 0.         0.         0.        ]
 [0.         0.         0.10884354 0.         0.        ]
 [0.11791383 0.06802721 0.         0.         0.        ]
 [0.06802721 0.         0.         0.         0.        ]
 [0.         0.06122449 0.         0.02947846 0.        ]
 [0.         0.         0.09977324 0.00453515 0.00453515]
 [0.         0.         0.         0.         0.20181406]
 [0.         0.         0.         0.16553288 0.        ]]
transport cost: 24576.370543882178
Code
# squared l2 geometry
print("squared l2 geometry:")
optimal_plan_l2_sq_croissant, optimal_cost_l2_sq_croissant = compute_transport(
    C=C_l2_sq,
    a=bakery_prod,
    b=cafe_prod,
    verbose=True
)
squared l2 geometry:
optimal transport plan: 
[[0.07029478 0.         0.         0.         0.        ]
 [0.         0.         0.10884354 0.         0.        ]
 [0.05442177 0.13151927 0.         0.         0.        ]
 [0.         0.06802721 0.         0.         0.        ]
 [0.         0.         0.         0.09070295 0.        ]
 [0.         0.         0.09977324 0.00453515 0.00453515]
 [0.         0.         0.         0.         0.20181406]
 [0.06122449 0.         0.         0.10430839 0.        ]]
transport cost: 177.28420815406028

Now, we can visualize the assignment induced by each geometry.

Code
fig, ax = plt.subplots(
    1, 3, figsize=(9*3, 7)
)
plans = [optimal_plan_l1_croissant,
         optimal_plan_l2_croissant,
         optimal_plan_l2_sq_croissant]
titles = [r"$\ell_1$ geometry", r"$\ell_2$ geometry", r"$\ell_2^2$ geometry"]

for axes, plan, title in zip(ax, plans, titles):
  plot_assignement(
      ax=axes,
      x=bakery_pos, a=bakery_prod, x_label="Bakeries",
      y=cafe_pos, b=cafe_prod, y_label="Cafés",
      optimal_plan=plan,
      title=title
  )
plt.show()

1.2.4 In dimension \(d = 1\)

Let us assume in this subsection that the cost is of the form \(c(x, y) = \|x - y\|_p^q\) with \(p, q \geq 1\), which covers the costs we considered in the previous examples, that the points are in \(\mathbb{R}\), i.e. \(x_1, ..., x_n, y_1, ... , y_n \in \mathbb{R}\), and that all weights are equal to \(1/n\). Then, computing OT boils down to sorting the points. Indeed, for all costs of the above form, the optimal permutation between \(x\) and \(y\) is \(\sigma^* = \sigma_y \circ \sigma_x^{-1}\) where \(\sigma_x\) is the permutation sorting the \(x_i\) (so that \(x_{\sigma_x(1)} \leq \dots \leq x_{\sigma_x(n)}\)) and \(\sigma_y\) the one sorting the \(y_i\). In particular, one has:

\[ W_c(\mu, \nu) = \frac{1}{n} \sum_{i=1}^n c(x_i, y_{\sigma_y \circ \sigma_x^{-1}(i)}) = \frac{1}{n} \sum_{i=1}^n c(x_{\sigma_x(i)}, y_{\sigma_y(i)}) \]

Thus, to compute the optimal transport cost, it is sufficient to sort \(x\) and \(y\).
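
The sorting rule can be checked numerically on a tiny instance. The sketch below assumes uniform weights and a cost \(|x - y|^q\) with \(q \geq 1\); the function names `ot_cost_1d` and `ot_cost_bruteforce` are hypothetical, and the brute-force enumeration is only practical for very small \(n\).

```python
import itertools
import numpy as np

def ot_cost_1d(x, y, q=2):
    """Transport cost for c(x, y) = |x - y|^q, uniform weights, by sorting."""
    return float(np.mean(np.abs(np.sort(x) - np.sort(y)) ** q))

def ot_cost_bruteforce(x, y, q=2):
    """Minimal cost over all n! permutations (tiny n only)."""
    n = len(x)
    return min(
        float(np.mean(np.abs(x - y[list(perm)]) ** q))
        for perm in itertools.permutations(range(n))
    )

rng = np.random.default_rng(0)
x = rng.normal(size=5)
y = rng.normal(size=5)
for q in (1, 2):
    print(q, ot_cost_1d(x, y, q), ot_cost_bruteforce(x, y, q))
```

Sorting costs \(O(n \log n)\) operations, against \(n!\) candidates for the enumeration.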

Let’s check this fact on an example, by comparing the transport cost obtained by sorting the points to the one obtained with the function ot.emd. To simplify, we generate points \(x,y \subset \mathbb{R}\) s.t. \(x\) is sorted, i.e. \(\sigma_x = \mathrm{Id}\), and then \(\sigma^*=\sigma_y\). Therefore, computing the optimal assignment amounts to sorting \(y\).

Code
# generate points
n = 5
x = np.arange(0, 2*n, 2) + .25 * np.random.normal(size=(n,))
a = np.ones(n) / n
y = np.arange(1, 2*n+1, 2) + .25 * np.random.normal(size=(n,))
np.random.shuffle(y)
b = np.ones(n) / n

# plot points
fig, ax = plt.subplots(figsize=(12, 6))
plot_points_1D(
    ax,
    x, y,
    title="1D points"
)


Question:

  • For \(\ell_1\) and \(\ell_2^2\) geometries (\(\ell_2\) and \(\ell_1\) coincide on \(\mathbb{R}\)), compute the optimal assignment and optimal transport cost by sorting \(y\). Put the assignment into a vector \(s \in \mathbb{R}^n\), s.t. \(x_i\) is mapped to \(y_{s_i}\), i.e. \(s_i = \sigma^*(i)\). Is it different according to the geometry?
  • Now put the assignment you obtained by sorting the points in the form of a transport plan \(P^* \in \mathbb{R}^{n \times n}\). Check that you obtain the same results with ot.emd.

Answer:

Code
# sort the points
y_sorted = np.sort(y)

# get optimal assignment as a vector
assignment = np.argsort(y)

# transform it to a transport plan
optimal_plan = np.zeros((n,n))
for i, idx in enumerate(assignment):
    optimal_plan[i, idx] = 1 / n
print(
    f"optimal transport plan obtained by sorting the points:\n {optimal_plan}"
)

# The result doesn't match the lecturer's
optimal transport plan obtained by sorting the points:
 [[0.  0.  0.2 0.  0. ]
 [0.  0.  0.  0.  0.2]
 [0.2 0.  0.  0.  0. ]
 [0.  0.2 0.  0.  0. ]
 [0.  0.  0.  0.2 0. ]]
Code
# l1 geometry
print("l1 geometry:")
C_l1 = get_cost_matrix(
    x=x, y=y,
    cost_fn=lambda x,y: np.sum(np.abs(x - y))
)
optimal_plan_l1, optimal_cost_l1 = compute_transport(
    C=C_l1,
    a=a,
    b=b,
    verbose=True
)
print(
    f"is it equal to the one obtained by sorting the points? "
    f"{np.array_equal(optimal_plan_l1, optimal_plan)}"
)
l1 geometry:
optimal transport plan: 
[[0.  0.  0.2 0.  0. ]
 [0.  0.  0.  0.  0.2]
 [0.2 0.  0.  0.  0. ]
 [0.  0.2 0.  0.  0. ]
 [0.  0.  0.  0.2 0. ]]
transport cost: 1.150354389571881
is it equal to the one obtained by sorting the points? True
Code
# squared l2 geometry

def is_permutation(matrix):
    """
    Check if a given matrix is a permutation matrix.
    """
    n, m = matrix.shape
    if n != m:
        return False
    
    row_sum = np.sum(matrix, axis=1)
    col_sum = np.sum(matrix, axis=0)
    
    return np.all(row_sum == 1) and np.all(col_sum == 1) and np.all((matrix == 0) | (matrix == 1))

C_l2_sq = get_cost_matrix(
    x=x, y=y,
    cost_fn=lambda x,y: np.sum((x - y) ** 2)
)
optimal_plan_l2_sq, optimal_cost_l2_sq = compute_transport(
    C=C_l2_sq,
    a=a,
    b=b,
    verbose=True
)
print(
    f"is permutation matrix? {is_permutation(optimal_plan_l2_sq)}"
)
print(
    f"is it equal to the one obtained by sorting the points? "
    f"{np.array_equal(optimal_plan_l2_sq, optimal_plan)}"
)
optimal transport plan: 
[[0.  0.  0.2 0.  0. ]
 [0.  0.  0.  0.  0.2]
 [0.2 0.  0.  0.  0. ]
 [0.  0.2 0.  0.  0. ]
 [0.  0.  0.  0.2 0. ]]
transport cost: 1.5572261500204791
is permutation matrix? False
is it equal to the one obtained by sorting the points? True

Finally, one can plot the assignment.

Code
fig, ax = plt.subplots(figsize=(12, 6))
plot_assignement_1D(
    ax,
    x, y,
    title="1D assignement"
)
plt.show()

2 II. Entropy Regularized Optimal Transport

2.1 II.1 Reminders on Sinkhorn Algorithm

2.1.1 Adding negative entropy as a regularizer

In real ML applications, we often deal with large numbers of points. In this case, linear programming algorithms, whose complexity is cubic, are too costly. This motivates (among other reasons) the regularized approach \[ \min_{P \in U(a,b)} \langle C, P \rangle + \epsilon \sum_{ij} P_{ij} [ \log(P_{ij}) - 1]. \] When \(\epsilon\) is sufficiently small, one expects to recover an approximation of the original optimal transport plan.
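
To get a feeling for the trade-off introduced by the regularizer, the sketch below evaluates the regularized objective for two admissible plans of a toy \(2 \times 2\) problem. The function name `entropic_objective` and the data are assumptions made for this illustration, and the convention \(0 \log 0 = 0\) is used.

```python
import numpy as np

def entropic_objective(P, C, epsilon):
    """Return <C, P> + epsilon * sum_ij P_ij (log P_ij - 1), with 0 log 0 = 0."""
    positive = P > 0
    entropy_term = np.sum(P[positive] * (np.log(P[positive]) - 1.0))
    return float(np.sum(C * P) + epsilon * entropy_term)

C = np.array([[0.0, 1.0], [1.0, 0.0]])
a = b = np.array([0.5, 0.5])
P_exact = np.diag(a)          # optimal for the unregularized problem
P_indep = np.outer(a, b)      # independent coupling a b^T

for epsilon in (0.01, 10.0):
    print(epsilon,
          entropic_objective(P_exact, C, epsilon),
          entropic_objective(P_indep, C, epsilon))
```

With a small \(\epsilon\) the sparse plan has the lower objective, while with a large \(\epsilon\) the spread-out plan is preferred.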

2.1.2 The Sinkhorn iterates

In order to solve this problem, one can remark that the optimality conditions imply that a solution \(P_\epsilon^*\) necessarily is of the form \(P_\epsilon^* = \text{diag}(u) \, K \, \text{diag}(v)\), where \(K = \exp(-C/\epsilon)\) and \(u,v\) are two non-negative vectors.
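
One way to see where this form comes from is to write the first-order condition of the Lagrangian. The dual variables \(f \in \mathbb{R}^n\) and \(g \in \mathbb{R}^m\) below are notation introduced here for illustration only. With \[ \mathcal{L}(P, f, g) = \langle C, P \rangle + \epsilon \sum_{ij} P_{ij} [ \log(P_{ij}) - 1] - \sum_{i=1}^n f_i \Big( \sum_{j=1}^m P_{ij} - a_i \Big) - \sum_{j=1}^m g_j \Big( \sum_{i=1}^n P_{ij} - b_j \Big), \] setting \(\partial \mathcal{L} / \partial P_{ij} = 0\) gives \[ C_{ij} + \epsilon \log(P_{ij}) - f_i - g_j = 0 \quad \Longrightarrow \quad P_{ij} = e^{f_i / \epsilon} \, e^{-C_{ij} / \epsilon} \, e^{g_j / \epsilon}, \] which is \(\text{diag}(u) \, K \, \text{diag}(v)\) with \(u_i = e^{f_i / \epsilon}\) and \(v_j = e^{g_j / \epsilon}\).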

\(P_\epsilon^*\) should verify the constraints, i.e. \(P_\epsilon^* \in U(a,b)\), so that \[ P_\epsilon^* 1_m = a \text{ and } (P_\epsilon^*)^T 1_n = b \] which can be rewritten as \[ u \odot (Kv) = a \text{ and } v \odot (K^T u) = b \]

Then Sinkhorn’s algorithm alternates between the resolution of these two equations, and reads at iteration \(t\): \[ u^{t+1} \leftarrow \frac{a}{Kv^t} \text{ and } v^{t+1} \leftarrow \frac{b}{K^T u^{t+1}} \]

2.1.3 Initialization and convergence

Usually, one starts from \(v^{0} = \mathrm{1}_m\) and alternates the above updates until \(\|u^{t+1} \odot (Kv^{t+1}) - a\|_1 + \|v^{t+1} \odot (K^T u^{t+1}) - b\|_1 \leq \tau\), where \(\tau > 0\) is a fixed convergence threshold. Actually, since at the end of each iteration, one exactly has \(v^{t+1} \odot (K^T u^{t+1}) = b\), it just remains to test if \(\|u^{t+1} \odot (Kv^{t+1}) - a\|_1 \leq \tau\).
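
The claim that only the row constraint needs to be tested can be observed on a single iteration. The toy weights, cost matrix and value of \(\epsilon\) below are assumptions chosen for this illustration.

```python
import numpy as np

# Hypothetical small problem with n = 2 sources and m = 3 targets.
a = np.array([0.3, 0.7])
b = np.array([0.2, 0.5, 0.3])
C = np.array([[0.0, 1.0, 2.0],
              [2.0, 1.0, 0.0]])
epsilon = 0.5
K = np.exp(-C / epsilon)

v = np.ones_like(b)          # v^0 = 1_m
u = a / (K @ v)              # u-update: row sums match a for the current v
v = b / (K.T @ u)            # v-update: column sums match b, rows are perturbed

P = u[:, None] * K * v[None, :]
row_error = np.abs(P.sum(axis=1) - a).sum()   # generally nonzero before convergence
col_error = np.abs(P.sum(axis=0) - b).sum()   # zero up to rounding
print(row_error, col_error)
```

After the \(v\)-update the column marginals are exact by construction, so the row error alone measures how far the iterates are from convergence.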

From an entropic optimal transport plan \(P^*_\epsilon\), we can approximate the optimal transport cost by \(\sum_{i=1}^n \sum_{j=1}^m (P^*_\epsilon)_{ij} C_{ij} = \langle C, P^*_\epsilon \rangle\). For the rest of the section, we call this quantity the entropic optimal transport cost.

2.2 II.2 Using your own Sinkhorn

2.2.1 Sinkhorn Implementation

In this section, you will implement your own version of the Sinkhorn Algorithm.


Question: Complete the following Sinkhorn algorithm, by:

  • Computing the kernel matrix \(K = \exp(-C / \epsilon)\),
  • Starting from \(v^{0} = \mathrm{1}_m\),
  • Alternating the updates \(u^{t+1} \odot (Kv^t) = a\) and \(v^{t+1} \odot (K^T u^{t+1}) = b\),
  • Declaring convergence when \(\|u^t \odot (Kv^t) - a\|_1 + \|v^t \odot (K^T u^t) - b\|_1 \leq \tau\).

Remark: you should also use a maximum number of iterations max_iters, to stop the algorithm after a fixed number of iterations if convergence is not reached.

Answer:

Code
def sinkhorn(
    a: np.ndarray,
    b: np.ndarray,
    C: np.ndarray,
    epsilon: float,
    max_iters: int = 100,
    tau: float = 1e-4
) -> np.ndarray:
    """
    Sinkhorn's algorithm. It should output the optimal transport plan.
    """

    K = np.exp( -C / epsilon )
    n, m = a.shape[0], b.shape[0]
    v = np.ones((m,))
    for _ in range(max_iters):
        u = a / K.dot(v)
        v = b / K.transpose().dot(u)
    return u[:,None] * v[None,:] * K  # u_i, v_j, K_ij
Code
P = sinkhorn(a, b, C_l2_sq, epsilon=1)
print(P.sum(axis=0))
print(P.sum(axis=1))
[0.2 0.2 0.2 0.2 0.2]
[0.19804499 0.19825831 0.19923847 0.20184922 0.20260901]
Code
P = sinkhorn(a, b, C_l2_sq, epsilon=1, max_iters=1000)
print(P.sum(axis=0))
print(P.sum(axis=1))
[0.2 0.2 0.2 0.2 0.2]
[0.19994571 0.19994658 0.19995105 0.20007655 0.20008012]
Code
def sinkhorn(
    a: np.ndarray,
    b: np.ndarray,
    C: np.ndarray,
    epsilon: float,
    max_iters: int = 100,
    tau: float = 1e-4
) -> np.ndarray:
    """
    Sinkhorn's algorithm. It should output the optimal transport plan.
    """

    K = np.exp( -C / epsilon )
    n, m = a.shape[0], b.shape[0]
    v = np.ones((m,))
    for i in range(max_iters):
        u = a / K.dot(v)
        v = b / K.transpose().dot(u)
        if i % 10 == 0:
            # compute row sum D(u) K D(v) = u * Kv
            if np.sum(np.abs(u * K.dot(v) - a)) < tau:
                print('early termination: ' + str(i))
                break
    return u[:,None] * v[None,:] * K  # u_i, v_j, K_ij
Code
P = sinkhorn(a, b, C_l2_sq, epsilon=1, max_iters=1000)
print(P.sum(axis=0))
print(P.sum(axis=1))
[0.2 0.2 0.2 0.2 0.2]
[0.19994571 0.19994658 0.19995105 0.20007655 0.20008012]
Code
P = sinkhorn(a, b, C_l2_sq, epsilon=0.1, max_iters=10000)
print(P.sum(axis=0))
print(P.sum(axis=1))
early termination: 9970
[0.2 0.2 0.2 0.2 0.2]
[0.19996656 0.19998346 0.20000018 0.20001672 0.20003307]

Now, we can test the Sinkhorn algorithm on the “croissant” transport example.


Question:

  • Complete the following function that takes as input the cost matrix \(C\) and the weight vectors \(a\) and \(b\) and outputs the entropic optimal transport plan and the entropic optimal transport cost using the sinkhorn function. As for the exact transport, it has an option to display the results.
  • Use that function on the croissant transport to compute and display the optimal plan and the optimal cost for the \(\ell_1, \ell_2\) and \(\ell_2^2\) geometries.
  • Each time you run the Sinkhorn algorithm, you should use \(\epsilon = 0.1 \cdot \bar{C}\), where \(\bar{C} = \frac{1}{nm} \sum_{i=1}^n \sum_{j=1}^m C_{ij}\) is the mean of the cost matrix. This adapts the \(\epsilon\) value to the cost matrix, to control the magnitude of the entries of \(C / \epsilon\). Why this strategy? What will happen if \(\epsilon\) is too small compared to the entries of \(C\)?

Answer:

Code
def compute_transport_sinkhorn(
    C: np.ndarray,
    a: np.ndarray,
    b: np.ndarray,
    epsilon: float,
    max_iters: int = 10_000,
    tau: float = 1e-4,
    verbose: bool = False,
):
  """
  Compute the entropic optimal transport plan and the entropic optimal transport cost
  for cost matrix ``C`` and weight vectors $a$ and $b$.
  If ``verbose`` is set to True, it displays the results.
  """
  optimal_plan_sinkhorn = sinkhorn(a, b, C, epsilon, max_iters, tau)
  optimal_cost_sinkhorn = np.sum(optimal_plan_sinkhorn * C)
  if verbose:
    print(
        f"entropic optimal transport plan: \n{optimal_plan_sinkhorn}"
    )
    print(
        f"entropic transport cost: {optimal_cost_sinkhorn}"
    )
  return optimal_plan_sinkhorn, optimal_cost_sinkhorn
Code
# l1 geometry
print("l1 geometry:")
C_l1 = get_cost_matrix(
    x=bakery_pos, y=cafe_pos,
    cost_fn=lambda x,y: np.sum(np.abs(x - y))
)
epsilon = 1
optimal_plan_sinkhorn_l1_croissant, optimal_cost_sinkhorn_l1_croissant = compute_transport_sinkhorn(
    C=C_l1,
    a=bakery_prod,
    b=cafe_prod,
    epsilon=epsilon,
    verbose=True,
)
l1 geometry:
early termination: 5970
entropic optimal transport plan: 
[[2.70428936e-002 4.32583290e-002 1.34214268e-047 1.83509051e-086
  1.62880160e-235]
 [3.42773260e-084 1.30691077e-046 1.08827539e-001 1.90100995e-040
  1.68731081e-189]
 [7.15328153e-002 1.14425257e-001 4.99347632e-122 4.85411038e-086
  4.30844293e-235]
 [2.61705422e-002 4.18628990e-002 5.77469196e-189 1.77589404e-086
  1.57625961e-235]
 [1.25908361e-049 4.80057848e-012 2.70172554e-084 9.07098043e-002
  1.22406524e-115]
 [2.66904088e-052 1.01764013e-014 9.97892408e-002 1.92289196e-004
  8.84651301e-003]
 [5.95876320e-051 4.18968529e-019 7.18872709e-119 4.29294956e-003
  1.97502693e-001]
 [6.11947921e-002 7.27950530e-029 1.24902883e-128 1.04351442e-001
  5.20381800e-060]]
entropic transport cost: 177.27648952346257
Code
plt.imshow(optimal_plan_sinkhorn_l1_croissant)

Code
# l2 geometry
print("l2 geometry:")
C_l2 = get_cost_matrix(
    x=bakery_pos, y=cafe_pos,
    cost_fn=lambda x,y: np.linalg.norm(x - y, ord=2)
)
epsilon = np.mean(C_l2_sq) * 0.05 # compute the optimal value to avoid underflow
optimal_plan_sinkhorn_l2_croissant, optimal_cost_sinkhorn_l2_croissant = compute_transport_sinkhorn(
    C=C_l2,
    a=bakery_prod,
    b=cafe_prod,
    epsilon=epsilon,
    verbose=True
)
l2 geometry:
early termination: 6260
entropic optimal transport plan: 
[[1.47475398e-002 5.55537055e-002 3.08774169e-064 9.30661064e-081
  4.69761757e-208]
 [6.15011232e-065 5.04201195e-042 1.08835700e-001 2.71365095e-074
  2.40874849e-164]
 [4.19665468e-002 1.43991570e-001 3.44511021e-091 1.29763971e-083
  1.42700380e-214]
 [6.80331674e-002 1.07065710e-012 2.39243900e-136 5.50823177e-098
  2.33302996e-231]
 [5.57193420e-021 1.20973940e-006 4.85598564e-049 9.07086080e-002
  1.03291636e-094]
 [4.80426784e-025 1.95723545e-010 9.97810800e-002 4.48552562e-003
  4.56931226e-003]
 [8.01821168e-045 5.16962071e-039 1.07457685e-060 1.67831560e-024
  2.01779894e-001]
 [6.11937890e-002 3.69083217e-011 2.43448067e-090 1.04352352e-001
  6.30899246e-051]]
entropic transport cost: 139.50260030035182
Code
# squared l2 geometry
print("squared l2 geometry:")
C_l2_sq = get_cost_matrix(
    x=bakery_pos, y=cafe_pos,
    cost_fn=lambda x,y: np.sum((x - y) ** 2)
)
epsilon = np.mean(C_l2_sq) * 0.05 # compute the optimal value to avoid underflow
optimal_plan_sinkhorn_l2_sq_croissant, optimal_cost_sinkhorn_l2_sq_croissant = compute_transport_sinkhorn(
    C=C_l2_sq,
    a=bakery_prod,
    b=cafe_prod,
    epsilon=epsilon,
    verbose=True
)
squared l2 geometry:
early termination: 390
entropic optimal transport plan: 
[[9.15185856e-03 6.11459020e-02 2.53416007e-06 9.36830019e-11
  5.88710931e-33]
 [1.98312263e-09 2.86298466e-05 1.08801507e-01 2.62245737e-07
  1.22762517e-18]
 [1.02592189e-01 8.33635462e-02 3.95058637e-08 1.25018168e-08
  7.16059834e-28]
 [6.36539082e-02 4.37874407e-03 9.20213996e-12 8.38044497e-09
  1.80226054e-26]
 [1.15178824e-03 4.78786088e-02 4.40333811e-04 4.12385828e-02
  1.04144097e-10]
 [1.27912352e-15 1.00116432e-09 9.93719135e-02 7.28353201e-04
  8.73056135e-03]
 [5.51097456e-13 1.50024181e-09 4.51186568e-07 4.17827621e-03
  1.97618514e-01]
 [9.39129734e-03 2.75105187e-03 4.51869851e-10 1.53400990e-01
  1.31200779e-07]]
entropic transport cost: 24883.330683517215
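
The very small entries in the plans above hint at what happens when \(\epsilon\) is small compared to the entries of \(C\): the kernel \(K = \exp(-C / \epsilon)\) underflows to zero and the updates divide by zero. A common remedy is to run the same iterations on log-scaled potentials. The sketch below is one possible way to do this, not the implementation used in this notebook: the names `logsumexp` and `sinkhorn_log` and the toy data are assumptions, and the weights are assumed to be strictly positive.

```python
import numpy as np

def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along ``axis``."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.sum(np.exp(z - z_max), axis=axis))

def sinkhorn_log(a, b, C, epsilon, max_iters=1000, tau=1e-9):
    """Log-domain Sinkhorn sketch; returns the plan exp((f_i + g_j - C_ij) / epsilon)."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)          # g = 0 corresponds to v^0 = 1_m
    for _ in range(max_iters):
        # Same updates as u = a / (K v) and v = b / (K^T u), written in logs.
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - C) / epsilon, axis=0))
        P = np.exp((f[:, None] + g[None, :] - C) / epsilon)
        if np.abs(P.sum(axis=1) - a).sum() < tau:
            break
    return P

a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
C = np.array([[1000.0, 3000.0],
              [3000.0, 1000.0]])
K = np.exp(-C / 1.0)              # every entry underflows to 0.0
P = sinkhorn_log(a, b, C, epsilon=1.0)
print(K)
print(P)
```

Here the potentials relate to the scalings through \(u = e^{f / \epsilon}\) and \(v = e^{g / \epsilon}\). Scaling \(\epsilon\) with the mean of the cost matrix, as suggested in the question, addresses the same issue by keeping the entries of \(C / \epsilon\) moderate.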

2.2.2 The effect of \(\epsilon\)

Now we can display the transportation plans obtained with Sinkhorn’s algorithm, as we did for the exact OT.

Code
fig, ax = plt.subplots(
    1, 3, figsize=(9*3, 7)
)
plans = [optimal_plan_sinkhorn_l1_croissant,
         optimal_plan_sinkhorn_l2_croissant,
         optimal_plan_sinkhorn_l2_sq_croissant]

titles = [r"$\ell_1$ geometry", r"$\ell_2$ geometry", r"$\ell_2^2$ geometry"]

for axes, plan, title in zip(ax, plans, titles):
  plot_assignement(
      ax=axes,
      x=bakery_pos, a=bakery_prod, x_label="Bakeries",
      y=cafe_pos, b=cafe_prod, y_label="Cafés",
      optimal_plan=plan,
      title=title
  )
plt.show()

Note: Every entry of the plan returned by Sinkhorn’s algorithm is positive (although possibly tiny), so some mass is transported along every edge.

Code
fig, ax = plt.subplots(
    1, 3, figsize=(9*3, 7)
)
plans = [optimal_plan_l1_croissant,
         optimal_plan_l2_croissant,
         optimal_plan_l2_sq_croissant]

titles = [r"$\ell_1$ geometry", r"$\ell_2$ geometry", r"$\ell_2^2$ geometry"]

for axes, plan, title in zip(ax, plans, titles):
  plot_assignement(
      ax=axes,
      x=bakery_pos, a=bakery_prod, x_label="Bakeries",
      y=cafe_pos, b=cafe_prod, y_label="Cafés",
      optimal_plan=plan,
      title=title
  )
plt.show()

The transport plans above were obtained with a small regularization strength. Let’s increase epsilon to \(\epsilon = 10 \cdot \bar{C}\) and replot the optimal transport plans to visualize the effect of epsilon.

Code
# l1 geometry
epsilon = 10 * np.mean(C_l1)
optimal_plan_sinkhorn_l1_croissant, optimal_cost_sinkhorn_l1_croissant = compute_transport_sinkhorn(
    C=C_l1,
    a=bakery_prod,
    b=cafe_prod,
    epsilon=epsilon,
    verbose=False,
)

# l2 geometry
epsilon = 10 * np.mean(C_l2)
optimal_plan_sinkhorn_l2_croissant, optimal_cost_sinkhorn_l2_croissant = compute_transport_sinkhorn(
    C=C_l2,
    a=bakery_prod,
    b=cafe_prod,
    epsilon=epsilon,
    verbose=False
)

# squared l2 geometry
epsilon = 10 * np.mean(C_l2_sq)
optimal_plan_sinkhorn_l2_sq_croissant, optimal_cost_sinkhorn_l2_sq_croissant = compute_transport_sinkhorn(
    C=C_l2_sq,
    a=bakery_prod,
    b=cafe_prod,
    epsilon=epsilon,
    verbose=False
)

fig, ax = plt.subplots(
    1, 3, figsize=(9*3, 7)
)
# plot the Sinkhorn plans recomputed above with the larger epsilon
plans = [optimal_plan_sinkhorn_l1_croissant,
         optimal_plan_sinkhorn_l2_croissant,
         optimal_plan_sinkhorn_l2_sq_croissant]
titles = [r"$\ell_1$ geometry", r"$\ell_2$ geometry", r"$\ell_2^2$ geometry"]

for axes, plan, title in zip(ax, plans, titles):
  plot_assignement(
      ax=axes,
      x=bakery_pos, a=bakery_prod, x_label="Bakeries",
      y=cafe_pos, b=cafe_prod, y_label="Cafés",
      optimal_plan=plan,
      title=title
  )
plt.show()
early termination: 10
early termination: 10
early termination: 10

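Note: When \(\epsilon\) is large, the entropic plan spreads mass over all edges and approaches the independent coupling \(a b^\top\), which is uniform only when both weight vectors are uniform.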
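
The large-\(\epsilon\) regime can also be checked numerically. The sketch below measures the distance between a Sinkhorn plan and the product \(a b^\top\) of the weight vectors as \(\epsilon'\) grows. It runs on made-up synthetic points and weights, not on the croissant data, and `sinkhorn_plan` is a hypothetical helper written for this illustration, not the compute_transport_sinkhorn function used above.

```python
import numpy as np

def sinkhorn_plan(C, a, b, epsilon, n_iter=500):
    """Plain Sinkhorn iterations (hypothetical helper for this illustration)."""
    K = np.exp(-C / epsilon)           # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)              # rescale to match the column marginals
        u = a / (K @ v)                # rescale to match the row marginals
    return u[:, None] * K * v[None, :]

# synthetic data (assumption: any small weighted point clouds would do)
rng = np.random.default_rng(0)
x, y = rng.normal(size=(4, 2)), rng.normal(size=(5, 2))
C = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)   # l2 cost matrix
a = np.array([0.1, 0.2, 0.3, 0.4])
b = np.full(5, 0.2)

independent = np.outer(a, b)           # the coupling a b^T
gaps = []
for eps_prime in [0.1, 1.0, 10.0, 100.0]:
    P = sinkhorn_plan(C, a, b, eps_prime * C.mean())
    gaps.append(np.abs(P - independent).sum())
    print(f"epsilon' = {eps_prime:6.1f}   sum |P - a b^T| = {gaps[-1]:.4f}")
```

The printed distance should shrink towards zero as \(\epsilon'\) increases, while for small \(\epsilon'\) the plan stays far from \(a b^\top\).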


Question: What do you observe in relation to the transport plans obtained for the exact optimal transport?

Answer:


2.2.3 Sinkhorn consistency

We now show that this Sinkhorn algorithm is consistent with classical optimal transport, using the “croissant” transport example and focusing on the \(\ell_2\) cost.


Question: Complete the following code to compute, for various values of \(\epsilon'\) on a regular grid:
* the regularization strength \(\epsilon = \epsilon' \cdot \bar{C}\),
* the deviation of the entropic optimal plan \(P^*_\epsilon\) from the exact optimal plan \(P^*\), namely \(\|P^*_\epsilon - P^*\|_2\),
* the deviation of the entropic optimal cost \(\langle C, P^*_\epsilon \rangle\) from the exact optimal cost \(\langle C, P^* \rangle\), namely \(\langle C, P^*_\epsilon \rangle - \langle C, P^* \rangle\).

Recall that the exact optimal transport plan for the \(\ell_2\) cost is stored in the variable optimal_plan_l2_croissant.

Answer:

Code
plan_diff = []
distance_diff = []
grid = np.linspace(0.01, 5, 100)
for epsilon_prime in grid:
  epsilon = epsilon_prime * np.mean(C_l2)
  optimal_plan_sinkhorn_l2_croissant, optimal_cost_sinkhorn_l2_croissant = compute_transport_sinkhorn(
      C=C_l2,
      a=bakery_prod,
      b=cafe_prod,
      epsilon=epsilon,
      verbose=False
  )
  # use np.isnan: a comparison `!= np.nan` is always True and would never fail
  assert not np.isnan(optimal_cost_sinkhorn_l2_croissant), (
      "Optimal cost is nan due to numerical instabilities."
  )
  # Frobenius (entrywise l2) norm of the difference, as asked in the question
  plan_diff.append(
      np.linalg.norm(optimal_plan_sinkhorn_l2_croissant - optimal_plan_l2_croissant)
  )
  distance_diff.append(
      optimal_cost_sinkhorn_l2_croissant - optimal_cost_l2_croissant
  )
early termination: 2460
early termination: 220
early termination: 50
early termination: 30
early termination: 20
early termination: 20
early termination: 10
(the line above repeats for each of the remaining grid values, 94 occurrences in total)
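
A note on NaN guards such as the assert in the loop above: NaN compares unequal to every value, including itself, so an inequality test against np.nan can never fail, and np.isnan is the reliable check. A short illustration, where the variable name `cost` is hypothetical:

```python
import numpy as np

cost = np.nan                # stand-in for a cost that became NaN

print(cost != np.nan)        # True: the inequality holds even though cost is NaN
print(cost == np.nan)        # False: NaN is not equal to itself
print(np.isnan(cost))        # True: the reliable test
```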

Now, let’s plot the results.

Code
fig, ax = plt.subplots(2, 1, figsize=(16, 5*2))
reg_strengths = np.mean(C_l2) * grid
plot_consistency(
    ax,
    reg_strengths,
    plan_diff,
    distance_diff
)

plt.show()
/Users/hirofumi48/Library/Python/3.9/lib/python/site-packages/IPython/core/pylabtools.py:152: UserWarning: Data has no positive values, and therefore cannot be log-scaled.
  fig.canvas.print_figure(bytes_io, **kw)

Note: The result is different from the lecturer’s.
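
As an independent sanity check, the sketch below reruns the consistency experiment on a small synthetic problem rather than on the croissant data. It assumes uniform weights and equally many source and target points, so that an exact optimal plan is a permutation matrix scaled by 1/n and its cost can be found by brute force. `logsumexp` and `sinkhorn_log` are hypothetical helpers written for this illustration (a log-domain variant, chosen for stability at small \(\epsilon\)), not the compute_transport_sinkhorn function used above.

```python
import itertools
import numpy as np

def logsumexp(M, axis):
    """Numerically stable log(sum(exp(M))) along one axis."""
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def sinkhorn_log(C, a, b, epsilon, n_iter=2000):
    """Log-domain Sinkhorn (hypothetical helper); returns the transport plan."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        # update g so that the column marginals equal b
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - C) / epsilon, axis=0))
        # update f so that the row marginals equal a
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - C) / epsilon, axis=1))
    return np.exp((f[:, None] + g[None, :] - C) / epsilon)

# synthetic problem (assumption: n = m and uniform weights)
n = 4
rng = np.random.default_rng(0)
x, y = rng.normal(size=(n, 2)), rng.normal(size=(n, 2)) + 1
C = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)   # l2 cost matrix
a = b = np.full(n, 1 / n)

# exact OT cost: best permutation, each edge carrying mass 1/n
exact_cost = min(
    C[np.arange(n), list(p)].mean() for p in itertools.permutations(range(n))
)

gaps = []
for eps_prime in [2.0, 0.5, 0.1]:
    P = sinkhorn_log(C, a, b, eps_prime * C.mean())
    gaps.append(np.sum(P * C) - exact_cost)
    print(f"epsilon' = {eps_prime:4.1f}   cost gap = {gaps[-1]:.6f}")
```

If the iterations have converged, the gap should be nonnegative and shrink as \(\epsilon'\) decreases.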

2.3 II.3 Using OTT

2.3.1 Install OTT

First, you need to install OTT.

Code
%pip install ott-jax
Defaulting to user installation because normal site-packages is not writeable
Requirement already satisfied: ott-jax in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (0.4.5)
Requirement already satisfied: jax>=0.4.0 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from ott-jax) (0.4.30)
Requirement already satisfied: jaxopt>=0.8 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from ott-jax) (0.8.3)
Requirement already satisfied: numpy>=1.20.0 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from ott-jax) (1.26.1)
Requirement already satisfied: lineax>=0.0.1 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from ott-jax) (0.0.4)
Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from jax>=0.4.0->ott-jax) (0.4.30)
Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from jax>=0.4.0->ott-jax) (0.3.2)
Requirement already satisfied: opt-einsum in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from jax>=0.4.0->ott-jax) (3.3.0)
Requirement already satisfied: scipy>=1.9 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from jax>=0.4.0->ott-jax) (1.11.3)
Requirement already satisfied: importlib-metadata>=4.6 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from jax>=0.4.0->ott-jax) (6.8.0)
Requirement already satisfied: equinox>=0.11.0 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from lineax>=0.0.1->ott-jax) (0.11.3)
Requirement already satisfied: jaxtyping>=0.2.20 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from lineax>=0.0.1->ott-jax) (0.2.28)
Requirement already satisfied: typing-extensions>=4.5.0 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from lineax>=0.0.1->ott-jax) (4.8.0)
Requirement already satisfied: zipp>=0.5 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from importlib-metadata>=4.6->jax>=0.4.0->ott-jax) (3.17.0)
Requirement already satisfied: typeguard==2.13.3 in /Users/hirofumi48/Library/Python/3.9/lib/python/site-packages (from jaxtyping>=0.2.20->lineax>=0.0.1->ott-jax) (2.13.3)
Note: you may need to restart the kernel to use updated packages.

Then we load the required packages.

Code
import jax
import jax.numpy as jnp
import jax.random as random
import ott
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

2.3.2 A word about OTT and JAX

OTT is a Python library for computing and differentiating entropic optimal transport. In this lab session, we focus on computing entropic optimal transport, not on differentiating it; differentiation will be tackled later.

OTT is based on JAX, a package similar to PyTorch or TensorFlow, which provides automatic differentiation and GPU programming. It also provides useful primitives for efficient computation, such as just-in-time compilation (jit) and automatic vectorization (vmap). For more information on JAX, see the tutorial https://jax.readthedocs.io/en/latest/notebooks/quickstart.html.

Unlike PyTorch or TensorFlow, JAX is very close to numpy thanks to the jax.numpy package, which implements most of the numpy features, but for the JAX data structures. For this lab session, you only need to know how to manipulate jax.numpy Arrays and generate random numbers with jax.random.

First, let’s have a look at jax.numpy and see that it works (almost) exactly like numpy. Usually, one imports jax.numpy as jnp, as done in the cell above, and develops as with numpy, simply replacing np with jnp. Note that jax.numpy arrays were called DeviceArray in older JAX releases; in recent releases they are instances of jax.Array, which is why the output below reports the type ArrayImpl. For more information on jax.numpy, see https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html.

Code
d = 5
u = 5 * jnp.ones(d)
Id = jnp.eye(d)
print(type(u))
print(f"u = {u}")
print(f"Id = {Id}")
print(f"Id @ u = {jnp.dot(Id, u)}")
print(f"sum(u) = {jnp.sum(u)}")
print(f"var(u) = {jnp.var(u)}")
<class 'jaxlib.xla_extension.ArrayImpl'>
u = [5. 5. 5. 5. 5.]
Id = [[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]]
Id @ u = [5. 5. 5. 5. 5.]
sum(u) = 25.0
var(u) = 0.0
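
One place where the “almost” shows up is in-place modification: JAX arrays are immutable, so item assignment is rejected and updates go through the functional `.at[...].set(...)` syntax, which returns a new array. A minimal sketch, assuming a working JAX installation; the variable names are illustrative only:

```python
import jax.numpy as jnp

u = 5 * jnp.ones(5)

try:
    u[0] = 1.0                      # in-place assignment works in numpy, not in JAX
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

v = u.at[0].set(1.0)                # functional update: returns a new array
print(f"u = {u}")                   # u is left unchanged
print(f"v = {v}")
```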

With numpy.random, you can generate random numbers on the fly without giving a seed. For example, np.random.rand() generates a random number \(x \sim U([0, 1])\). Indeed, numpy.random uses an internal seed which is updated each time a random number generating function is called. With jax.random, on the other hand, the source of randomness must be passed explicitly each time we generate random numbers, so that randomness is always under our control. Moreover, we do not pass a seed directly but a jax.random.PRNGKey key, which is itself instantiated from a seed. Let’s see this in an example.

Code
rng = jax.random.PRNGKey(0)
n, d = 13, 2
x = jax.random.normal(rng, (n, d))
print(f"x = {x}")
x = [[ 2.516351   -1.3947194 ]
 [-0.8633262   0.6413567 ]
 [-0.37789643 -0.6044598 ]
 [ 1.9069     -0.17918469]
 [-0.7583423  -0.5160155 ]
 [ 1.2666148  -0.12342127]
 [ 0.28430256 -0.17251171]
 [ 1.0661486   1.5814103 ]
 [-2.0284636  -0.13168257]
 [-0.14515765  0.21532312]
 [-0.69525063 -0.9314128 ]
 [-0.89809936 -0.25272107]
 [-0.34937173  1.8394127 ]]

Then, to obtain new keys for generating new random numbers, we split the key via jax.random.split, which generates \(n \geq 2\) new keys from a key.

Code
rng1, rng2, rng3 = jax.random.split(rng, 3)
a = jax.random.normal(rng1, (n, d))
b = jax.random.normal(rng2, (n, d))
# c reuses rng2 (rng3 is left unused), so c is identical to b: same key, same samples
c = jax.random.normal(rng2, (n, d))
print(f"a = {a}")
print(f"b = {b}")
print(f"c = {c}")
a = [[-0.38696066 -0.96707183]
 [ 1.0078175  -0.6096286 ]
 [-1.153353    1.0749092 ]
 [-1.2452031  -0.63885343]
 [ 0.01121208  0.2842425 ]
 [ 0.5296049   0.26609063]
 [ 0.8728492   1.0844501 ]
 [ 1.4472795  -0.82503337]
 [-0.41826957  0.21321987]
 [ 1.9602116   0.17687395]
 [-0.9978761  -2.0551765 ]
 [-0.4094941  -1.4577458 ]
 [-1.0969195  -0.66684234]]
b = [[ 0.10911155 -0.45371595]
 [ 0.12062439 -0.06927001]
 [ 0.00600028  2.3732579 ]
 [-0.17656058  1.7653493 ]
 [-0.06429235  0.487175  ]
 [-1.1079016  -1.0277865 ]
 [-0.0553451  -0.28271845]
 [-0.9633478  -0.05370665]
 [ 0.20281292 -0.16658288]
 [ 0.8015828  -0.61697495]
 [-0.30176872 -1.1862007 ]
 [-3.106658   -0.03262986]
 [ 0.53711027  0.21359496]]
c = [[ 0.10911155 -0.45371595]
 [ 0.12062439 -0.06927001]
 [ 0.00600028  2.3732579 ]
 [-0.17656058  1.7653493 ]
 [-0.06429235  0.487175  ]
 [-1.1079016  -1.0277865 ]
 [-0.0553451  -0.28271845]
 [-0.9633478  -0.05370665]
 [ 0.20281292 -0.16658288]
 [ 0.8015828  -0.61697495]
 [-0.30176872 -1.1862007 ]
 [-3.106658   -0.03262986]
 [ 0.53711027  0.21359496]]

You now know everything you need for the moment!

2.3.3 Entropic optimal transport with OTT

Now let’s use the implementation of the OTT Sinkhorn algorithm, on some random weighted point clouds. Then you will, by yourself, use it on the “croissant” transport example.

Let’s first generate the data.

Code
# generate data
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng, 2)
n, m, d = 13, 17, 2
x = jax.random.normal(rng1, (n, d))
y = jax.random.normal(rng2, (m, d)) + 1
a = jnp.ones(n) / n
b = jnp.ones(m) / m

Then, we have to define a PointCloud geometry which contains:
* the point clouds x and y,
* the cost function cost_fn,
* the entropic regularization strength epsilon.

Note that the geometry does not contain the weight vectors a and b, these are passed later.

The cost_fn should be an instance of ott.geometry.costs.CostFn. Most of the usual costs are implemented, including the three costs \(\ell_1, \ell_2\) and \(\ell_2^2\) used above. Here, we will focus on the \(\ell_2\) cost, implemented by ott.geometry.costs.Euclidean. See https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.geometry.costs.CostFn.html#ott.geometry.costs.CostFn for more information on the provided cost functions.

We still choose epsilon to be \(0.1 \cdot \bar{C}\). To do this, we set relative_epsilon=True when instantiating the geometry: epsilon is then interpreted relative to the mean of the cost matrix, so passing epsilon=0.1 makes Sinkhorn use \(0.1 \cdot \bar{C}\).
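
To make the relative scaling concrete, here is a minimal sketch in plain NumPy of the value \(0.1 \cdot \bar{C}\) for an \(\ell_2\) cost. It does not call OTT; it only assumes, as stated above, that the relative scaling uses the mean of the cost matrix. The data and the name `relative_eps` are made up for this illustration.

```python
import numpy as np

# synthetic point clouds with the same shapes as in the cell above
rng = np.random.default_rng(0)
x = rng.normal(size=(13, 2))
y = rng.normal(size=(17, 2)) + 1

# l2 cost matrix: C[i, j] = ||x_i - y_j||_2
C = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

relative_eps = 0.1                       # the value one would pass as epsilon
epsilon = relative_eps * C.mean()        # the regularization actually applied
print(f"mean cost = {C.mean():.4f}, epsilon = {epsilon:.4f}")
```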

Code
# define geometry
geom = pointcloud.PointCloud(
    x=x, y=y,
    cost_fn=costs.Euclidean(),
    epsilon=1e-1,
    relative_epsilon=True
)

We then define an optimization problem from this geometry, which is the problem we will solve with the Sinkhorn algorithm. We instantiate this optimization problem as an object of the class linear_problem.LinearProblem. We pass the weight vectors a and b because they define the constraints of the linear problem. Then, we instantiate a Sinkhorn solver, object of the class sinkhorn.Sinkhorn, which we will use to solve this optimization problem.

The OTT library is designed in this way because it also handles other optimal transport problems, which do not necessarily have a linear problem structure and which use solvers other than Sinkhorn.

Code
# create optimization problem
ot_prob = linear_problem.LinearProblem(geom, a=a, b=b)

# create sinkhorn solver (it takes only its own hyperparameters;
# the problem is passed when the solver is called below)
solver = sinkhorn.Sinkhorn()

# solve the OT problem
ot_sol = solver(ot_prob)

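The output object ot_sol exposes several callables and properties, notably a boolean indicating whether Sinkhorn converged, the marginal errors throughout the iterations, and the optimal transport plan. Entries of ot_sol.errors equal to -1 are padding for error evaluations that were never performed, hence the filter errors > -1 in the code below.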

Code
print(
    " Sinkhorn has converged: ",
    ot_sol.converged,
    "\n",
    "Error upon last iteration: ",
    ot_sol.errors[(ot_sol.errors > -1)][-1],
    "\n",
    "Sinkhorn required ",
    jnp.sum(ot_sol.errors > -1),
    " iterations to converge. \n",
    "entropic OT cost: ",
    jnp.sum(ot_sol.matrix * ot_sol.geom.cost_matrix),
)
 Sinkhorn has converged:  True 
 Error upon last iteration:  0.00019063428 
 Sinkhorn required  5  iterations to converge. 
 entropic OT cost:  29.436863

Question: Compute the entropic optimal transport plan and cost for the “croissant” transport problem, with \(\ell_2\) cost and \(\epsilon = 0.1 \cdot \bar{C}\). Then, plot the optimal transport plan.

Answer: