The goal of this practical session is to introduce computational optimal transport (OT) in Python. You will familiarize yourself with OT by:
1. Computing “exact” unregularized optimal transport, using the Python library POT (Python Optimal Transport).
2. Computing entropic optimal transport, using first your own version of the Sinkhorn algorithm, then the Python library OTT-JAX.
To keep the notebook readable, the plotting functions that are used several times are gathered in the section below.
Code
import numpy as np
import matplotlib.pyplot as plt


def plot_weighted_points(
    ax, x, a, y, b, title=None, x_label=None, y_label=None
):
    # marker areas are proportional to the weights a and b
    ax.scatter(x[:, 0], x[:, 1], s=5000 * a, c='r', edgecolors='k', label=x_label)
    ax.scatter(y[:, 0], y[:, 1], s=5000 * b, c='b', edgecolors='k', label=y_label)
    for i in range(np.shape(x)[0]):
        ax.annotate(str(i + 1), (x[i, 0], x[i, 1]), fontsize=30, color='black')
    for i in range(np.shape(y)[0]):
        ax.annotate(str(i + 1), (y[i, 0], y[i, 1]), fontsize=30, color='black')
    if x_label is not None or y_label is not None:
        ax.legend(fontsize=20)
    ax.axis('off')
    ax.set_title(title, fontsize=25)


def plot_assignement(
    ax, x, a, y, b, optimal_plan, title=None, x_label=None, y_label=None
):
    plot_weighted_points(
        ax=ax, x=x, a=a, y=y, b=b, title=None, x_label=x_label, y_label=y_label
    )
    # line widths are proportional to the mass sent from x_i to y_j
    for i in range(optimal_plan.shape[0]):
        for j in range(optimal_plan.shape[1]):
            ax.plot([x[i, 0], y[j, 0]], [x[i, 1], y[j, 1]],
                    c='k', lw=30 * optimal_plan[i, j], alpha=0.8)
    ax.axis('off')
    ax.set_title(title, fontsize=30)


def plot_assignement_1D(ax, x, y, title=None):
    plot_points_1D(ax, x, y, title=None)
    x_sorted = np.sort(x)
    y_sorted = np.sort(y)
    assert len(x) == len(y), "x and y must have the same shape."
    # the i-th smallest x is matched with the i-th smallest y
    for i in range(len(x)):
        ax.hlines(
            y=0,
            xmin=min(x_sorted[i], y_sorted[i]),
            xmax=max(x_sorted[i], y_sorted[i]),
            color='k',
            lw=10
        )
    ax.axis('off')
    ax.set_title(title, fontsize=30)


def plot_points_1D(ax, x, y, title=None):
    n = len(x)
    a = np.ones(n) / n  # uniform weights on x
    b = np.ones(n) / n  # uniform weights on y
    ax.scatter(x, np.zeros(n), s=1000 * a, c='r')
    ax.scatter(y, np.zeros(n), s=1000 * b, c='b')
    min_val = min(np.min(x), np.min(y))
    max_val = max(np.max(x), np.max(y))
    for i in range(n):
        ax.annotate(str(i + 1), xy=(x[i], 0.005), size=30, color='r', ha='center')
    for j in range(n):
        ax.annotate(str(j + 1), xy=(y[j], 0.005), size=30, color='b', ha='center')
    ax.axis('off')
    ax.plot(np.linspace(min_val, max_val, 10), np.zeros(10))
    ax.set_title(title, fontsize=30)


def plot_consistency(ax, reg_strengths, plan_diff, distance_diff):
    ax[0].loglog(reg_strengths, plan_diff, lw=4)
    ax[0].set_ylabel(r'$||P^* - P_\epsilon^*||_F$', fontsize=25)
    ax[0].tick_params(which='both', size=20)
    ax[0].grid(ls='--')
    ax[1].loglog(reg_strengths, distance_diff, lw=4)
    ax[1].set_xlabel(r'Regularization Strength $\epsilon$', fontsize=25)
    ax[1].set_ylabel(
        r'$ 100 \cdot \frac{\langle C, P^*_\epsilon \rangle - \langle C, P^* \rangle}{\langle C, P^* \rangle} $',
        fontsize=25
    )
    ax[1].tick_params(which='both', size=20)
    ax[1].grid(ls='--')
1 I. Exact Optimal Transport with POT
1.1 I.1 Reminders on Discrete Optimal Transport
Optimal Transport is a theory that allows us to compare two (weighted) point clouds \((x, a)\) and \((y, b)\), where \(x \in \mathbb{R}^{n \times d}\) and \(y \in \mathbb{R}^{m \times d}\) are the locations of the \(n\) (resp. \(m\)) points in dimension \(d\), and \(a \in \mathbb{R}^n\), \(b \in \mathbb{R}^m\) are the weights. We ask that the total weights sum to one, i.e. \(\sum_{i=1}^n a_i = \sum_{j=1}^m b_j = 1\).
The basic idea of Optimal Transport is to “transport” the mass located at points \(x\) to the mass located at points \(y\).
Let us denote by \(U(a,b) := \left\{ P \in \mathbb{R}^{n \times m} \,|\, P \geq 0, \sum_{j=1}^m P_{ij} = a_i, \sum_{i=1}^n P_{ij} = b_j\right\}\) the set of admissible transport plans.
If \(P \in U(a,b)\), the quantity \(P_{ij} \geq 0\) should be regarded as the mass transported from point \(x_i\) to point \(y_j\). For this reason, it is called a transport plan.
We will also consider a cost function \(c : \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}\) and the associated cost matrix \(C = [c(x_i, y_j)]_{1\leq i \leq n,\, 1 \leq j \leq m} \in \mathbb{R}^{n \times m}\), containing the pairwise costs between the points of the two point clouds \(x\) and \(y\). The quantity \(C_{ij}\) should be regarded as the cost paid for transporting one unit of mass from \(x_i\) to \(y_j\). This cost is usually computed from the positions \(x_i\) and \(y_j\), for example \(C_{ij} = \|x_i - y_j\|_2\) or \(C_{ij} = \|x_i - y_j\|_2^2\), but may be more exotic in some cases.
Then transporting mass according to \(P \in U(a,b)\) has a total cost of \(\sum_{i=1}^n \sum_{j=1}^m P_{ij} C_{ij}\).
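A minimal numpy sketch of these definitions on hypothetical toy data: it checks membership in \(U(a,b)\) and evaluates the total cost \(\langle C, P \rangle\). The helper names is_admissible and total_cost are illustrative, not part of the notebook. The product coupling \(P = a b^\top\) is used because it always satisfies both marginal constraints, which shows that \(U(a,b)\) is never empty.

```python
import numpy as np


def is_admissible(P, a, b, atol=1e-10):
    """Check that P lies in U(a, b): nonnegative, row sums a, column sums b."""
    return (
        np.all(P >= 0)
        and np.allclose(P.sum(axis=1), a, atol=atol)
        and np.allclose(P.sum(axis=0), b, atol=atol)
    )


def total_cost(C, P):
    """Total cost <C, P> = sum_ij C_ij P_ij of moving mass according to P."""
    return float(np.sum(C * P))


# Hypothetical toy data: n = 3 source points, m = 2 target points in R^2
x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y = np.array([[1.0, 1.0], [2.0, 0.0]])
a = np.array([0.5, 0.25, 0.25])
b = np.array([0.5, 0.5])

# Squared Euclidean cost matrix C_ij = ||x_i - y_j||_2^2, shape (n, m)
C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

# The product coupling a b^T always belongs to U(a, b)
P = np.outer(a, b)
print(is_admissible(P, a, b), total_cost(C, P))  # True 2.5
```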
The word “optimal” in “Optimal Transport” refers to the goal: we want to find a transport plan \(P \in U(a,b)\) that minimizes the total cost. In other words, we want to solve \[
\min_{P \in U(a,b)} \sum_{i=1}^n \sum_{j=1}^m C_{ij} P_{ij} = \min_{P \in U(a,b)} ⟨C, P⟩.
\]
This problem is a Linear Program: the objective function is linear in the variable \(P\), and the constraints are linear in \(P\). We can thus solve this problem using classical Linear Programming algorithms, such as the simplex algorithm.
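To make the linear-programming view concrete, here is a minimal sketch that writes the constraints of \(U(a,b)\) as equality constraints on the flattened plan and solves the problem with scipy.optimize.linprog (HiGHS backend). This is only an illustration: POT's ot.emd relies on a dedicated network simplex solver instead, and the helper name ot_linprog and the random data are hypothetical. With uniform weights and \(n = m\), an optimum is attained at a permutation matrix divided by \(n\), which gives a brute-force check for tiny \(n\).

```python
import itertools

import numpy as np
from scipy.optimize import linprog


def ot_linprog(C, a, b):
    """Solve min_{P in U(a, b)} <C, P> with scipy's linprog (HiGHS).
    P is flattened row-major: P[i, j] is variable number i * m + j."""
    n, m = C.shape
    A_rows = np.kron(np.eye(n), np.ones((1, m)))  # sum_j P_ij = a_i  (n rows)
    A_cols = np.kron(np.ones((1, n)), np.eye(m))  # sum_i P_ij = b_j  (m rows)
    res = linprog(
        C.ravel(),
        A_eq=np.vstack([A_rows, A_cols]),
        b_eq=np.concatenate([a, b]),
        bounds=(0, None),  # P_ij >= 0
        method="highs",
    )
    if not res.success:
        raise RuntimeError(res.message)
    return res.x.reshape(n, m), res.fun


rng = np.random.default_rng(0)
n = 4
C = rng.random((n, n))
a = b = np.full(n, 1.0 / n)
P, cost = ot_linprog(C, a, b)

# Brute force over the n! permutation plans (each nonzero entry equals 1/n)
brute = min(
    sum(C[i, s[i]] for i in range(n)) / n for s in itertools.permutations(range(n))
)
print(cost, brute)
```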
If \(P^*\) is a solution to the Optimal Transport problem, we say that \(P^*\) is an optimal transport plan between \((x, a)\) and \((y, b)\), and that \(\sum_{ij} P^*_{ij} C_{ij}\) is the optimal transport distance between \((x, a)\) and \((y, b)\): it is the minimal amount of “energy” necessary to transport the initial mass located at points \(x\) to the target mass located at points \(y\).
Usually, we represent the weighted point clouds by probability measures \(\mu = \sum_{i=1}^n a_i \delta_{x_i}\) and \(\nu = \sum_{j=1}^m b_j \delta_{y_j}\). Solving the above problem, we then say that we solve the optimal transport problem between the measures \(\mu\) and \(\nu\), and we write: \[
W_c(\mu, \nu) = \min_{P \in U(a,b)} ⟨C, P⟩.
\]
1.2 I.2 Computing Optimal “Croissant” Transport
1.2.1 Install
First, you need to install a few packages:
Code
%pip install POT
%pip install cloudpickle
Then, load the required packages.
Code
import os
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import ot
Finally, connect the notebook to your drive to load some data that will be used for the experiments.
1.2.2 Formalization of the problem
We will solve the Bakeries/Cafés problem of transporting croissants from a number of Bakeries to Cafés.
We use fictional positions, production and sale numbers. We impose that the total croissant production is equal to the number of croissants sold, so that Bakeries and Cafés can be represented as measures with the same total mass. Then, up to normalization, they can be processed as probability measures.
Mathematically, we have access to the positions of the \(n\) Bakeries as points in \(\mathbb{R}^2\) via \(x \in \mathbb{R}^{n \times 2}\) and to their respective productions via \(a \in \mathbb{R}^n\), which together describe the source point cloud. The Cafés where the croissants are sold are likewise defined by their positions \(y \in \mathbb{R}^{m \times 2}\) and by the quantities of croissants sold, \(b \in \mathbb{R}^{m}\).
Afterwards, the Bakeries are represented by the probability measure \(\mu = \sum_{i=1}^n a_i \delta_{x_i}\) and the Cafés by \(\nu = \sum_{j=1}^m b_j \delta_{y_j}\). Computing the optimal assignment of the croissants delivered by the Bakeries to the Cafés thus amounts to computing the optimal transport between the probability measures \(\mu\) and \(\nu\).
Let’s download the data and check that the total croissant production is equal to the number of croissants sold.
Code
# Load the data
import pickle
from urllib.request import urlopen

import cloudpickle as cp

croissants = cp.load(urlopen('https://marcocuturi.net/data/croissants.pickle'))

bakery_pos = croissants['bakery_pos']
bakery_prod = croissants['bakery_prod']
cafe_pos = croissants['cafe_pos']
cafe_prod = croissants['cafe_prod']

print('Bakery productions =', bakery_prod)
print('Total number of croissants =', bakery_prod.sum())
print("")
print('Café sales =', cafe_prod)
print('Total number of croissants sold =', cafe_prod.sum())
Bakery productions = [31. 48. 82. 30. 40. 48. 89. 73.]
Total number of croissants = 441.0
Café sales = [82. 88. 92. 88. 91.]
Total number of croissants sold = 441.0
We now normalize the weight vectors \(a\) and \(b\), i.e. the production and the sales, to deal with probability measures.
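The normalization cell itself is not shown; a minimal sketch follows. To stay self-contained, it re-enters the production and sales figures printed above rather than reading them from the croissants pickle, and the names a and b for the normalized vectors are an assumption.

```python
import numpy as np

# Figures copied from the printed output above
bakery_prod = np.array([31., 48., 82., 30., 40., 48., 89., 73.])
cafe_prod = np.array([82., 88., 92., 88., 91.])

# Both totals equal 441, so dividing by the totals gives two probability vectors
a = bakery_prod / bakery_prod.sum()
b = cafe_prod / cafe_prod.sum()
print(a.sum(), b.sum())
```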
Note that we expect different optimal transport plans for different costs.
Question:
Complete the following function that computes a cost matrix \(C\) from two sets of points \(x, y\) and a cost function \(c\). Compute the three cost matrices \(C_{\ell_1}, C_{\ell_2}, C_{\ell_2^2}\in \mathbb{R}^{n \times m}\) using that function.
What cost should be used to minimize the total distance traveled by the driver that delivers croissants from Bakeries to Cafés?
def get_cost_matrix(
    x: np.ndarray,
    y: np.ndarray,
    cost_fn: Callable
) -> np.ndarray:
    """
    Compute the pairwise cost matrix between the n points in ``x``
    and the m points in ``y``. It should output a matrix of size n x m.
    """
    return np.array(
        [cost_fn(x_, y_) for x_ in x for y_ in y]
    ).reshape(x.shape[0], y.shape[0])


# compute cost matrices for different costs
# l1: sum of absolute coordinate differences
C_l1 = get_cost_matrix(
    x=bakery_pos, y=cafe_pos, cost_fn=lambda x, y: np.sum(np.abs(x - y))
)
# l2 (Euclidean distance): square root of the sum of squares
C_l2 = get_cost_matrix(
    x=bakery_pos, y=cafe_pos, cost_fn=lambda x, y: np.sqrt(np.sum((x - y) ** 2))
)
# squared l2: sum of squares, without the square root
C_l2_sq = get_cost_matrix(
    x=bakery_pos, y=cafe_pos, cost_fn=lambda x, y: np.sum((x - y) ** 2)
)

# print shapes of cost matrices
print(f"Shape of C_l1: {C_l1.shape}\n"
      f"Shape of C_l2: {C_l2.shape}\n"
      f"Shape of C_l2_sq: {C_l2_sq.shape}")
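The double loop in get_cost_matrix calls cost_fn \(nm\) times from Python. As an alternative sketch (get_cost_matrix_vectorized is a hypothetical name, not part of the notebook), numpy broadcasting computes all differences \(x_i - y_j\) at once for costs of the form \(\|x - y\|_p^q\); the random points below stand in for bakery_pos and cafe_pos.

```python
import numpy as np


def get_cost_matrix_vectorized(x, y, p=2, q=1):
    """Hypothetical vectorized variant of get_cost_matrix for c(x, y) = ||x - y||_p^q.
    x has shape (n, d), y has shape (m, d); the result has shape (n, m)."""
    diff = x[:, None, :] - y[None, :, :]              # (n, m, d) via broadcasting
    return np.linalg.norm(diff, ord=p, axis=-1) ** q  # vector p-norm over the last axis


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 2))  # stand-in for bakery_pos
y = rng.normal(size=(5, 2))  # stand-in for cafe_pos
C_l1 = get_cost_matrix_vectorized(x, y, p=1, q=1)
C_l2 = get_cost_matrix_vectorized(x, y, p=2, q=1)
C_l2_sq = get_cost_matrix_vectorized(x, y, p=2, q=2)
print(C_l1.shape, C_l2.shape, C_l2_sq.shape)  # (8, 5) (8, 5) (8, 5)
```

Broadcasting trades the Python loop for an intermediate array of shape (n, m, d), which is cheap here but grows with the number of points.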
Shape of C_l1: (8, 5)
Shape of C_l2: (8, 5)
Shape of C_l2_sq: (8, 5)
We can now compute the Optimal Transport plan to transport the croissants from the bakeries to the cafés, for the three different costs.
Question:
Complete the following function that takes as input the cost matrix \(C\) and the weight vectors \(a\) and \(b\) and outputs the optimal transport plan and the optimal transport cost using the ot.emd function. It has an option to display the results.
Use that function to compute and display the optimal plan and the optimal cost for the \(\ell_1, \ell_2\) and \(\ell_2^2\) geometries.
Remark: See https://pythonot.github.io/ for information on the ot.emd function.
Answer:
Code
def compute_transport(
    C: np.ndarray,
    a: np.ndarray,
    b: np.ndarray,
    verbose: bool = False,
):
    """
    Compute the optimal transport plan and the optimal transport cost
    for cost matrix ``C`` and weight vectors $a$ and $b$.
    If ``verbose`` is set to True, it displays the results.
    """
    # exact (unregularized) OT plan
    optimal_plan = ot.emd(a, b, C)
    # total cost <C, P*>
    optimal_cost = np.sum(optimal_plan * C)
    if verbose:
        print(f"optimal transport plan: \n{optimal_plan}")
        print(f"transport cost: {optimal_cost}")
    return optimal_plan, optimal_cost
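Let us assume in this subsection that the cost is of the form \(c(x, y) = \|x - y\|_p^q\) with \(p, q \geq 1\), which covers the costs considered in the previous examples, that the points are in \(\mathbb{R}\), i.e. \(x_1, \dots, x_n, y_1, \dots, y_n \in \mathbb{R}\), and that all weights are uniform, \(a_i = b_j = 1/n\). In dimension one, \(\|x - y\|_p = |x - y|\) for every \(p\), so the cost reads \(c(x, y) = |x - y|^q\). Computing OT then boils down to sorting the points. Indeed, let \(\sigma_x\) be the permutation sorting the \(x_i\), i.e. \(x_{\sigma_x(1)} \leq \dots \leq x_{\sigma_x(n)}\), and \(\sigma_y\) the one sorting the \(y_i\). For all costs of the above form, the \(i\)-th smallest \(x\) is sent to the \(i\)-th smallest \(y\), so the optimal permutation between \(x\) and \(y\) is \(\sigma^* = \sigma_y \circ \sigma_x^{-1}\). In particular, one has: \[
W_c(\mu, \nu) = \frac{1}{n} \sum_{i=1}^n |x_{\sigma_x(i)} - y_{\sigma_y(i)}|^q.
\]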
Thus, to compute the optimal transport cost, it is sufficient to sort \(x\) and \(y\).
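The sorting argument can be checked numerically. The sketch below, with hypothetical helper names and random data, compares the cost of the sorted matching with an exhaustive search over all \(n!\) permutations (only feasible for tiny \(n\)) for \(q = 1, 2\), and builds \(\sigma^* = \sigma_y \circ \sigma_x^{-1}\) from np.argsort.

```python
import itertools

import numpy as np


def ot_1d_sorted(x, y, q=2):
    """OT cost between uniform measures on n points of R for c(x, y) = |x - y|^q:
    match the i-th smallest x with the i-th smallest y."""
    return np.mean(np.abs(np.sort(x) - np.sort(y)) ** q)


def ot_1d_brute_force(x, y, q=2):
    """Minimum over all n! permutations (only feasible for tiny n)."""
    return min(
        np.mean(np.abs(x - y[list(s)]) ** q)
        for s in itertools.permutations(range(len(x)))
    )


rng = np.random.default_rng(1)
x, y = rng.normal(size=6), rng.normal(size=6)

# sigma_x = argsort(x), so x[sigma_x[i]] is the i-th smallest x
sigma_x, sigma_y = np.argsort(x), np.argsort(y)
sigma_star = np.empty(len(x), dtype=int)
sigma_star[sigma_x] = sigma_y  # sigma*(sigma_x(i)) = sigma_y(i), i.e. sigma_y o sigma_x^{-1}

for q in (1, 2):
    print(q, ot_1d_sorted(x, y, q), ot_1d_brute_force(x, y, q),
          np.mean(np.abs(x - y[sigma_star]) ** q))
```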
Let’s check this fact on an example, by comparing the transport cost obtained by sorting the points with the one obtained with the function ot.emd. To simplify, we generate points \(x, y \in \mathbb{R}^n\) such that \(x\) is sorted, i.e. \(\sigma_x = \mathrm{Id}\), and then \(\sigma^* = \sigma_y\). Therefore, computing the optimal assignment amounts to sorting \(y\).
For the \(\ell_1\) and \(\ell_2^2\) geometries (\(\ell_2\) and \(\ell_1\) coincide on \(\mathbb{R}\)), compute the optimal assignment and the optimal transport cost by sorting \(y\). Put the assignment into a vector \(s \in \mathbb{R}^n\) such that \(x_i\) is mapped to \(y_{s_i}\), i.e. \(s_i = \sigma^*(i)\). Does it depend on the geometry?
Now put the assignment obtained by sorting the points in the form of a transport plan \(P^* \in \mathbb{R}^{n \times n}\), and check that you obtain the same result with ot.emd.
Answer:
Code
# sort the points
y_sorted = np.sort(y)
# get optimal assignment as a vector: x_i (the i-th smallest x) -> y_{assignment[i]}
assignment = np.argsort(y)
# transform it to a transport plan with mass 1/n on each matched pair
optimal_plan = np.zeros((n, n))
for i, idx in enumerate(assignment):
    optimal_plan[i, idx] = 1 / n
print(f"optimal transport plan obtained by sorting the points:\n{optimal_plan}")
# The result doesn't match the lecturer's
optimal transport plan obtained by sorting the points:
[[0. 0. 0.2 0. 0. ]
[0. 0. 0. 0. 0.2]
[0.2 0. 0. 0. 0. ]
[0. 0.2 0. 0. 0. ]
[0. 0. 0. 0.2 0. ]]
Code
# l1 geometry
print("l1 geometry:")
C_l1 = get_cost_matrix(
    x=x, y=y, cost_fn=lambda x, y: np.sum(np.abs(x - y))
)
optimal_plan_l1, optimal_cost_l1 = compute_transport(
    C=C_l1, a=a, b=b, verbose=True
)
print(f"is it equal to the one obtained by sorting the points? "
      f"{np.array_equal(optimal_plan_l1, optimal_plan)}")
l1 geometry:
optimal transport plan:
[[0. 0. 0.2 0. 0. ]
[0. 0. 0. 0. 0.2]
[0.2 0. 0. 0. 0. ]
[0. 0.2 0. 0. 0. ]
[0. 0. 0. 0.2 0. ]]
transport cost: 1.150354389571881
is it equal to the one obtained by sorting the points? True
Code
# squared l2 geometry
def is_permutation(matrix):
    """
    Check if a given matrix is a permutation matrix.
    """
    n, m = matrix.shape
    if n != m:
        return False
    row_sum = np.sum(matrix, axis=1)
    col_sum = np.sum(matrix, axis=0)
    return (np.all(row_sum == 1) and np.all(col_sum == 1)
            and np.all((matrix == 0) | (matrix == 1)))


C_l2_sq = get_cost_matrix(
    x=x, y=y, cost_fn=lambda x, y: np.sum((x - y) ** 2)
)
optimal_plan_l2_sq, optimal_cost_l2_sq = compute_transport(
    C=C_l2_sq, a=a, b=b, verbose=True
)
# Note: the entries of the plan are 0 or 1/n (not 0 or 1), so this check
# returns False; is_permutation(n * optimal_plan_l2_sq) tests whether the
# plan is a permutation matrix rescaled by 1/n.
print(f"is permutation matrix? {is_permutation(optimal_plan_l2_sq)}")
print(f"is it equal to the one obtained by sorting the points? "
      f"{np.array_equal(optimal_plan_l2_sq, optimal_plan)}")
optimal transport plan:
[[0. 0. 0.2 0. 0. ]
[0. 0. 0. 0. 0.2]
[0.2 0. 0. 0. 0. ]
[0. 0.2 0. 0. 0. ]
[0. 0. 0. 0.2 0. ]]
transport cost: 1.5572261500204791
is permutation matrix? False
is it equal to the one obtained by sorting the points? True
Finally, one can plot the assignment.
Code
fig, ax = plt.subplots(figsize=(12, 6))
plot_assignement_1D(ax, x, y, title="1D assignement")
plt.show()
2 II. Entropy Regularized Optimal Transport
2.1 II.1 Reminders on Sinkhorn Algorithm
2.1.1 Adding negative entropy as a regularizer
In real ML applications, we often deal with large numbers of points. In this case, linear programming algorithms, whose complexity is roughly cubic in the number of points, are too costly. This motivates (among other reasons) the regularized approach \[
\min_{P \in \mathcal{U}(a,b)} \langle C, P \rangle + \epsilon \sum_{ij} P_{ij} [ \log(P_{ij}) - 1].
\] For \(\epsilon\) sufficiently small, one expects to recover an approximation of the original optimal transport plan.
2.1.2 The Sinkhorn iterates
To solve this problem, one can remark that the optimality conditions imply that a solution \(P_\epsilon^*\) is necessarily of the form \(P_\epsilon^* = \text{diag}(u) \, K \, \text{diag}(v)\), where \(K = \exp(-C/\epsilon)\) (the exponential being taken entrywise) and \(u,v\) are two positive vectors.
\(P_\epsilon^*\) should verify the constraints, i.e. \(P_\epsilon^* \in U(a,b)\), so that \[
P_\epsilon^* 1_m = a \text{ and } (P_\epsilon^*)^T 1_n = b
\] which can be rewritten as \[
u \odot (Kv) = a \text{ and } v \odot (K^T u) = b
\]
Then Sinkhorn’s algorithm alternates between the resolution of these two equations, and reads at iteration \(t\): \[
u^{t+1} \leftarrow \frac{a}{Kv^t} \text{ and } v^{t+1} \leftarrow \frac{b}{K^T u^{t+1}}
\]
2.1.3 Initialization and convergence
Usually, the algorithm starts from \(v^{0} = \mathrm{1}_m\) and alternates the above updates until \(\|u^{t+1} \odot (Kv^{t+1}) - a\|_1 + \|v^{t+1} \odot (K^T u^{t+1}) - b\|_1 \leq \tau\), where \(\tau > 0\) is a fixed convergence threshold. Since at the end of each iteration one exactly has \(v^{t+1} \odot (K^T u^{t+1}) = b\), it only remains to test whether \(\|u^{t+1} \odot (Kv^{t+1}) - a\|_1 \leq \tau\).
From an entropic optimal transport plan \(P^*_\epsilon\), we can approximate the optimal transport cost by \(\sum_{i=1}^n \sum_{j=1}^m (P^*_\epsilon)_{ij} C_{ij} = ⟨C, P^*_\epsilon⟩\). For the rest of the section, we call this quantity the entropic optimal transport cost.
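The stopping rule above can be observed directly. This sketch (hypothetical function name, random toy data, squared Euclidean cost, \(\epsilon = 0.1 \cdot \bar{C}\) as used later in the notebook) records both marginal errors after each iteration: the column error stays at round-off level because the \(v\)-update enforces it exactly, so only the row error needs monitoring. It also evaluates the entropic optimal transport cost \(\langle C, P^*_\epsilon \rangle\).

```python
import numpy as np


def sinkhorn_with_errors(a, b, C, epsilon, n_iters):
    """Sinkhorn sketch recording the L1 row and column marginal errors of
    P = diag(u) K diag(v) after each full iteration."""
    K = np.exp(-C / epsilon)
    v = np.ones(len(b))
    row_err, col_err = [], []
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
        row_err.append(np.abs(u * (K @ v) - a).sum())    # ||u ⊙ Kv - a||_1
        col_err.append(np.abs(v * (K.T @ u) - b).sum())  # ||v ⊙ K^T u - b||_1
    P = u[:, None] * K * v[None, :]
    return P, np.array(row_err), np.array(col_err)


rng = np.random.default_rng(0)
x = rng.normal(size=(7, 2))
y = rng.normal(size=(9, 2)) + 1
a, b = np.full(7, 1 / 7), np.full(9, 1 / 9)
C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

P, row_err, col_err = sinkhorn_with_errors(a, b, C, epsilon=0.1 * C.mean(), n_iters=200)
print("max column error:", col_err.max())
print("row error at iterations 1, 10, 100, 200:", row_err[[0, 9, 99, 199]])
print("entropic OT cost:", np.sum(C * P))
```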
2.2 II.2 Using your own Sinkhorn
2.2.1 Sinkhorn Implementation
In this section, you will implement your own version of the Sinkhorn Algorithm.
Question: Complete the following Sinkhorn algorithm, by:
Computing the kernel matrix \(K = \exp(-C / \epsilon)\),
Starting from \(v^{0} = \mathrm{1}_m\),
Alternating the updates \(u^{t+1} \odot (Kv^t) = a\) and \(v^{t+1} \odot (K^T u^{t+1}) = b\),
Remark: you should also use a maximum number of iterations max_iters, to stop the algorithm after a fixed number of iterations if convergence is not reached.
Answer:
Code
def sinkhorn(
    a: np.ndarray,
    b: np.ndarray,
    C: np.ndarray,
    epsilon: float,
    max_iters: int = 100,
    tau: float = 1e-4
) -> np.ndarray:
    """
    Sinkhorn's algorithm. It should output the optimal transport plan.
    """
    K = np.exp(-C / epsilon)
    n, m = a.shape[0], b.shape[0]
    v = np.ones((m,))
    for _ in range(max_iters):
        u = a / K.dot(v)
        v = b / K.transpose().dot(u)
    return u[:, None] * v[None, :] * K  # P_ij = u_i K_ij v_j
Code
P = sinkhorn(a, b, C_l2_sq, epsilon=1)
print(P.sum(axis=0))
print(P.sum(axis=1))
def sinkhorn(
    a: np.ndarray,
    b: np.ndarray,
    C: np.ndarray,
    epsilon: float,
    max_iters: int = 100,
    tau: float = 1e-4
) -> np.ndarray:
    """
    Sinkhorn's algorithm. It should output the optimal transport plan.
    """
    K = np.exp(-C / epsilon)
    n, m = a.shape[0], b.shape[0]
    v = np.ones((m,))
    for i in range(max_iters):
        u = a / K.dot(v)
        v = b / K.transpose().dot(u)
        if i % 10 == 0:
            # the column sums are exact after the v-update, so only the
            # row sums of D(u) K D(v), i.e. u * Kv, need to be checked
            if np.sum(np.abs(u * K.dot(v) - a)) < tau:
                print('early termination: ' + str(i))
                break
    return u[:, None] * v[None, :] * K  # P_ij = u_i K_ij v_j
Code
P = sinkhorn(a, b, C_l2_sq, epsilon=1, max_iters=1000)
print(P.sum(axis=0))
print(P.sum(axis=1))
Now, we can test the Sinkhorn algorithm on the “croissant” transport example.
Question:
* Complete the following function that takes as input the cost matrix \(C\) and the weight vectors \(a\) and \(b\) and outputs the entropic optimal transport plan and the entropic optimal transport cost using the sinkhorn function. As for the exact transport, it has an option to display the results.
* Use that function on the croissant transport to compute and display the optimal plan and the optimal cost for the \(\ell_1, \ell_2\) and \(\ell_2^2\) geometries.
* Each time you run the Sinkhorn algorithm, use \(\epsilon = 0.1 \cdot \bar{C}\), where \(\bar{C} = \frac{1}{nm} \sum_{i=1}^n \sum_{j=1}^m C_{ij}\) is the mean of the cost matrix. The idea is to adapt the value of \(\epsilon\) to the cost matrix, so as to control the magnitude of the entries of \(C / \epsilon\). Why this strategy? What happens if \(\epsilon\) is too small compared to the entries of \(C\)?
Answer:
Code
def compute_transport_sinkhorn(
    C: np.ndarray,
    a: np.ndarray,
    b: np.ndarray,
    epsilon: float,
    max_iters: int = 10_000,
    tau: float = 1e-4,
    verbose: bool = False,
):
    """
    Compute the entropic optimal transport plan and the entropic optimal
    transport cost for cost matrix ``C`` and weight vectors $a$ and $b$.
    If ``verbose`` is set to True, it displays the results.
    """
    optimal_plan_sinkhorn = sinkhorn(a, b, C, epsilon, max_iters, tau)
    optimal_cost_sinkhorn = np.sum(optimal_plan_sinkhorn * C)
    if verbose:
        print(f"entropic optimal transport plan: \n{optimal_plan_sinkhorn}")
        print(f"entropic transport cost: {optimal_cost_sinkhorn}")
    return optimal_plan_sinkhorn, optimal_cost_sinkhorn
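On the question of a too small \(\epsilon\): since \(K = \exp(-C/\epsilon)\) is computed entrywise in floating point, an \(\epsilon\) much smaller than the entries of \(C\) makes the kernel underflow to zero, after which the updates divide by zero and the plan turns into inf/nan values; scaling \(\epsilon\) with \(\bar{C}\) keeps \(C/\epsilon\) in a moderate range. The sketch below illustrates the failure on a hypothetical 2 × 2 problem and shows one standard remedy, a log-domain Sinkhorn built on scipy.special.logsumexp. This stabilized variant is added for illustration only; it is not the implementation used in this notebook.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical toy problem: epsilon is far smaller than every entry of C
C = np.array([[200.0, 100.0], [100.0, 200.0]])
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
epsilon = 0.1  # C / epsilon >= 1000, and exp(-1000) underflows to 0.0 in float64

# Plain Sinkhorn breaks down: K is identically zero, so u = a / (K v) = inf
with np.errstate(divide="ignore", invalid="ignore"):
    K = np.exp(-C / epsilon)
    u_plain = a / (K @ np.ones(2))
    v_plain = b / (K.T @ u_plain)
print(K, u_plain, v_plain)


def sinkhorn_log(a, b, C, epsilon, n_iters=100):
    """Log-domain Sinkhorn sketch: iterate on the potentials f = eps * log(u) and
    g = eps * log(v), so that exp(-C / eps) is never formed explicitly."""
    f = np.zeros(len(a))
    g = np.zeros(len(b))
    for _ in range(n_iters):
        # log of u = a / (K v), written with a stable log-sum-exp over j
        f = epsilon * np.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        # log of v = b / (K^T u), log-sum-exp over i
        g = epsilon * np.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
    # P_ij = u_i K_ij v_j = exp((f_i + g_j - C_ij) / eps)
    return np.exp((f[:, None] + g[None, :] - C) / epsilon)


P_log = sinkhorn_log(a, b, C, epsilon)
print(P_log)  # close to the anti-diagonal plan, the cheapest one here
```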
The above transport plans are obtained for \(\epsilon = 0.1 \cdot \bar{C}\). Let’s increase epsilon to \(\epsilon = 10 \cdot \bar{C}\) and replot the optimal transport plans to visualize the effect of epsilon.
early termination: 10
early termination: 10
early termination: 10
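Note: If \(\epsilon\) is large, the plan is close to the product coupling \(a b^T\), i.e. \(P_{ij} \approx a_i b_j\): the mass of each source is spread over all targets in proportion to their weights. This plan is uniform only when \(a\) and \(b\) are.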
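Why does a large \(\epsilon\) blur the plan? When \(\epsilon \to \infty\), \(K = \exp(-C/\epsilon)\) tends to the all-ones matrix, and the only matrix of the form \(\text{diag}(u)\,\mathbf{1}\mathbf{1}^T\,\text{diag}(v) = u v^T\) in \(U(a,b)\) is \(a b^T\). The sketch below, on hypothetical random data with non-uniform weights, measures \(\|P_\epsilon - a b^T\|_1\) for increasing \(\epsilon\); sinkhorn_plain is a stripped-down copy of the Sinkhorn loop without early stopping.

```python
import numpy as np


def sinkhorn_plain(a, b, C, epsilon, n_iters=500):
    """Plain Sinkhorn iterations (no early stopping); returns diag(u) K diag(v)."""
    K = np.exp(-C / epsilon)
    v = np.ones(len(b))
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
C = rng.random((6, 4))           # hypothetical cost matrix
a = rng.random(6)
a /= a.sum()                     # non-uniform source weights
b = rng.random(4)
b /= b.sum()                     # non-uniform target weights

plans, dists = {}, {}
for eps_rel in (0.1, 10.0, 1000.0):
    plans[eps_rel] = sinkhorn_plain(a, b, C, eps_rel * C.mean())
    # L1 distance to the product coupling a b^T
    dists[eps_rel] = np.abs(plans[eps_rel] - np.outer(a, b)).sum()
    print(f"epsilon = {eps_rel} * mean(C): ||P - a b^T||_1 = {dists[eps_rel]:.2e}")
```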
Question: What do you observe in relation to the transport plans obtained for the exact optimal transport?
Answer:
2.2.3 Sinkhorn consistency
We now show that this Sinkhorn algorithm is consistent with classical optimal transport, using the “croissant” transport example and focusing on the \(\ell_2\) cost.
Question: Complete the following code to compute, for various values \(\epsilon'\) on a regular grid:
* Set \(\epsilon = \epsilon' \cdot \bar{C}\),
* The deviation of the entropic optimal plan \(P^*_\epsilon\) from the exact optimal plan \(P^*\), namely \(\|P^*_\epsilon - P^*\|_F\).
* The deviation of the entropic optimal cost \(\langle C, P^*_\epsilon \rangle\) from the exact optimal cost \(\langle C, P^* \rangle\), namely \(\langle C, P^*_\epsilon \rangle - \langle C, P^* \rangle\).
Recall that the exact optimal transport plan for the \(\ell_2\) cost is stored as variable optimal_plan_l2_croissant.
Answer:
Code
plan_diff = []
distance_diff = []
grid = np.linspace(0.01, 5, 100)
# normalized weights (probability vectors), as for the exact plan
a_croissant = bakery_prod / bakery_prod.sum()
b_croissant = cafe_prod / cafe_prod.sum()
for epsilon_prime in grid:
    epsilon = epsilon_prime * np.mean(C_l2)
    optimal_plan_sinkhorn_l2_croissant, optimal_cost_sinkhorn_l2_croissant = compute_transport_sinkhorn(
        C=C_l2, a=a_croissant, b=b_croissant, epsilon=epsilon, verbose=False
    )
    # `x != np.nan` is always True, so nan must be detected with np.isnan
    assert not np.isnan(optimal_cost_sinkhorn_l2_croissant), (
        "Optimal cost is nan due to numerical instabilities."
    )
    # Frobenius norm of the plan difference
    plan_diff.append(
        np.linalg.norm(optimal_plan_sinkhorn_l2_croissant - optimal_plan_l2_croissant, 'fro')
    )
    distance_diff.append(
        optimal_cost_sinkhorn_l2_croissant - optimal_cost_l2_croissant
    )
early termination: 2460
early termination: 220
early termination: 50
early termination: 30
early termination: 20
early termination: 20
early termination: 10
/Users/hirofumi48/Library/Python/3.9/lib/python/site-packages/IPython/core/pylabtools.py:152: UserWarning: Data has no positive values, and therefore cannot be log-scaled.
fig.canvas.print_figure(bytes_io, **kw)
Note: The result is different from the lecturer’s.
2.3 II.3 Using OTT
2.3.1 Install OTT
First, you need to install OTT.
Code
%pip install ott-jax
Then we load the required packages.
Code
import jax
import jax.numpy as jnp
import jax.random as random
import ott
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
# note: this module name shadows the numpy `sinkhorn` function defined above
from ott.solvers.linear import sinkhorn
2.3.2 A word about OTT and JAX
OTT is a Python library for computing and differentiating entropic optimal transport. In this lab session, we focus on computing entropic optimal transport, not on differentiating it; differentiation will be tackled later.
OTT is based on JAX, a package similar to PyTorch or TensorFlow, which provides automatic differentiation and GPU programming. It also provides useful primitives for efficient computation, such as just-in-time (jit) compilation and the automatic vectorization map vmap. For more information on JAX, see the tutorial https://jax.readthedocs.io/en/latest/notebooks/quickstart.html.
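A small sketch of the two primitives just mentioned, on hypothetical toy points: vmap turns a function of two vectors into a function of two point clouds, here building a squared Euclidean cost matrix without explicit loops, and jit compiles the result. This is only an illustration, not how OTT builds its cost matrices internally.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sq_dist(u, v):
    """Squared Euclidean distance between two vectors of shape (d,)."""
    return jnp.sum((u - v) ** 2)


# Inner vmap maps over the rows of y, outer vmap over the rows of x -> (n, m) matrix;
# jit compiles the whole computation on the first call.
cost_matrix = jax.jit(
    jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))
)

x = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y = jnp.array([[1.0, 1.0], [2.0, 0.0]])
C = cost_matrix(x, y)
print(np.asarray(C))  # expected entries: [[2, 4], [1, 1], [1, 5]]
```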
Unlike PyTorch or TensorFlow, JAX is very close to numpy thanks to the jax.numpy package, which implements most of the numpy features, but for the JAX data structures. For this lab session, you only need to know how to manipulate jax.numpy Arrays and generate random numbers with jax.random.
First, let’s have a look at jax.numpy and see that it works (almost) exactly like numpy. Usually, one imports jax.numpy as jnp, as done in the above cells, and develops as with numpy, just replacing np by jnp. Note that in recent versions of JAX, jax.numpy arrays are of type jax.Array (older versions called them DeviceArray). For more information on jax.numpy, see https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html.
Code
d = 5
u = 5 * jnp.ones(5)
Id = jnp.eye(5)
print(type(u))
print(f"u = {u}")
print(f"Id = {Id}")
print(f"Id @ u = {jnp.dot(Id, u)}")
print(f"sum(u) = {jnp.sum(u)}")
print(f"var(u) = {jnp.var(u)}")
With numpy.random, you can generate random numbers on the fly without giving a seed. For example, np.random.rand() generates a random number \(x \sim U([0, 1])\). Indeed, numpy.random uses an internal state, which is updated each time a random number generating function is called. With jax.random, on the other hand, we must pass the seed explicitly each time we generate random numbers, so that the randomness is always under our control. Moreover, we do not pass a seed directly but a jax.random.PRNGKey key, which is itself instantiated from a seed. Let’s see it on an example.
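The cell with this example is not shown; here is a minimal sketch using jax.random.PRNGKey, jax.random.split and jax.random.uniform: reusing a key reproduces exactly the same numbers, while a key obtained by splitting yields new ones.

```python
import jax
import numpy as np

key = jax.random.PRNGKey(0)            # key built from the integer seed 0
x1 = jax.random.uniform(key, (3,))     # 3 samples from U([0, 1])
x2 = jax.random.uniform(key, (3,))     # same key -> exactly the same samples
key, subkey = jax.random.split(key)    # derive new keys from the old one
x3 = jax.random.uniform(subkey, (3,))  # fresh key -> new samples
print(np.asarray(x1), np.asarray(x2), np.asarray(x3))
```

The usual pattern is to split a key before each use and never reuse a key that has already produced random numbers.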
Now let’s use the implementation of the OTT Sinkhorn algorithm, on some random weighted point clouds. Then you will, by yourself, use it on the “croissant” transport example.
Let’s first generate the data.
Code
# generate data
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng, 2)
n, m, d = 13, 17, 2
x = jax.random.normal(rng1, (n, d))
y = jax.random.normal(rng2, (m, d)) + 1
a = jnp.ones(n) / n
b = jnp.ones(m) / m
Then, we have to define a PointCloud geometry, which contains:
* the point clouds x and y,
* the cost function cost_fn,
* the entropic regularization strength epsilon.
Note that the geometry does not contain the weight vectors a and b, these are passed later.
The cost_fn should be an instance of ott.geometry.costs.CostFn. Most of the usual costs are implemented, for example the three costs \(\ell_1, \ell_2\) and \(\ell_2^2\). Here, we focus on the \(\ell_2\) cost, implemented by ott.geometry.costs.Euclidean. See https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.geometry.costs.CostFn.html#ott.geometry.costs.CostFn for more information on the provided cost functions.
We still choose epsilon to be \(0.1 \cdot \bar{C}\). To do this, we set relative_epsilon=True when instantiating the geometry. The term relative means that epsilon is chosen relative to the mean of the cost matrix: passing epsilon=0.1, the value of epsilon used by Sinkhorn will be \(0.1 \cdot \bar{C}\).
We then define an optimization problem from this geometry, which is the problem we will solve with the Sinkhorn algorithm. We instantiate this optimization problem as an object of the class linear_problem.LinearProblem. We pass the weight vectors a and b because they define the constraints of the linear problem. Then, we instantiate a Sinkhorn solver, object of the class sinkhorn.Sinkhorn, which we will use to solve this optimization problem.
The OTT library is designed this way because it also handles other optimal transport problems, which do not necessarily have a linear problem structure and which use solvers other than Sinkhorn.
Code
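# create the geometry: l2 cost, epsilon = 0.1 * mean of the cost matrix
geom = pointcloud.PointCloud(
    x, y, cost_fn=costs.Euclidean(), epsilon=0.1, relative_epsilon=True
)
# create optimization problem
ot_prob = linear_problem.LinearProblem(geom, a=a, b=b)
# create sinkhorn solver (the problem is passed when calling the solver)
solver = sinkhorn.Sinkhorn()
# solve the OT problem
ot_sol = solver(ot_prob)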
The output object ot_sol contains several callables and properties, notably a boolean assessing Sinkhorn convergence, the marginal errors throughout the iterations and the optimal transport plan.
Code
print(
    " Sinkhorn has converged: ", ot_sol.converged, "\n",
    "Error upon last iteration: ", ot_sol.errors[(ot_sol.errors > -1)][-1], "\n",
    "Sinkhorn required ", jnp.sum(ot_sol.errors > -1), " iterations to converge. \n",
    "entropic OT cost: ", jnp.sum(ot_sol.matrix * ot_sol.geom.cost_matrix),
)
Sinkhorn has converged: True
Error upon last iteration: 0.00019063428
Sinkhorn required 5 iterations to converge.
entropic OT cost: 29.436863
Question: Compute the entropic optimal transport plan and cost for the “croissant” transport problem, with \(\ell_2\) cost and \(\epsilon = 0.1 \cdot \bar{C}\). Then, plot the optimal transport plan.