import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torchvision.utils import save_image, make_grid
A Blog Entry on Bayesian Computation by an Applied Mathematician
$$
$$
1 VAE (Kingma and Welling, 2014)
1.1 導入
PyTorch
を用いることで詳細を省略し,VAE の構造を概観することとする.
今回は,MNIST データセットを用い,隠れ次元 400 を通じて潜在次元 200 まで圧縮する.
= '~/hirofumi/datasets'
dataset_path
= torch.device("mps")
DEVICE
= 100
batch_size
= 784
x_dim = 400
hidden_dim = 200
latent_dim
= 1e-3
lr
= 30 epochs
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
= transforms.Compose([
mnist_transform
transforms.ToTensor(),
])
= {'num_workers': 0, 'pin_memory': True}
kwargs
= MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
train_dataset = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)
test_dataset
= DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
train_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, **kwargs) test_loader
PyTorch の Dataset と DataLoader は,訓練やテスト用のデータセットの簡単なアクセスと,それに対する iterable オブジェクトを提供する.
1.2 モデルの定義
1.2.1 エンコーダー
エンコーダーはデータを受け取り,2層の全結合隠れ層を通じて,「平均」と「対数分散」の名前がついた計 400 次元の潜在表現を得る.
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(Encoder, self).__init__()
self.FC_input = nn.Linear(input_dim, hidden_dim)
self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
self.FC_mean = nn.Linear(hidden_dim, latent_dim)
self.FC_var = nn.Linear(hidden_dim, latent_dim)
self.LeakyReLU = nn.LeakyReLU(0.2)
self.training = True
def forward(self, x):
= self.LeakyReLU(self.FC_input(x))
h_ = self.LeakyReLU(self.FC_input2(h_))
h_ = self.FC_mean(h_)
mean = self.FC_var(h_)
log_var
return mean, log_var
- 1
-
nn.Linear
は PyTorch による全結合層 \(y=xA^\top+b\) の実装である. - 2
-
ここまで2層の全結合層にデータを通して,最終的な出力
h_
を得ており,次の段階で最終的な潜在表現を得る. - 3
-
最後の隠れ層の出力
h_
に関して平均と対数分散という名前のついた最終的な出力を,やはり全結合層を通じて得る(最終層なので活性化なし).
1.2.2 デコーダー
class Decoder(nn.Module):
def __init__(self, latent_dim, hidden_dim, output_dim):
super(Decoder, self).__init__()
self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
self.FC_output = nn.Linear(hidden_dim, output_dim)
self.LeakyReLU = nn.LeakyReLU(0.2)
def forward(self, x):
= self.LeakyReLU(self.FC_hidden(x))
h = self.LeakyReLU(self.FC_hidden2(h))
h
= torch.sigmoid(self.FC_output(h))
x_hat return x_hat
- 1
-
最後の出力は,エンコーダーとは違い,シグモイド関数を通して確率分布
x_hat
とする.
1.2.3 モデル
VAE はエンコーダーとデコーダーを連結し,1つのニューラルネットワークとして学習する.
class Model(nn.Module):
def __init__(self, Encoder, Decoder):
super(Model, self).__init__()
self.Encoder = Encoder
self.Decoder = Decoder
def reparameterization(self, mean, var):
= torch.randn_like(var).to(DEVICE)
epsilon = mean + var*epsilon
z return z
def forward(self, x):
= self.Encoder(x)
mean, log_var = self.reparameterization(mean, torch.exp(0.5 * log_var))
z = self.Decoder(z)
x_hat
return x_hat, mean, log_var
- 1
- これは サンプリングイプシロン と呼ばれる値である.
- 2
- ここで reparametrization trick を行っている.
- 3
-
入力
x
があったならば,まずエンコーダーに通してmean
,log_var
を得る. - 4
-
元々
log_var
の名前の通り対数分散として扱うこととしていたので,2で割り指数関数に通すことで標準偏差を得る.この平均と標準偏差について reparametrization trick を実行し,デコーダーに繋ぐ. - 5
-
デコーダーではデータの潜在表現
z
を受け取り,デコードしたものをx_hat
とする. - 6
-
返り値は,デコーダーの出力
x_hat
だけでなく,潜在表現mean
,log_var
も含むことに注意.
= Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
encoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)
decoder
= Model(Encoder=encoder, Decoder=decoder).to(DEVICE) model
- 1
-
.to(DEVICE)
により,モデルを M2 Mac の MPS デバイス上に移送している.
1.3 モデルの訓練
最適化には Adam (Kingma and Ba, 2017) を用い,バイナリ交差エントロピー(BCE)を用いる.これは nn.BCELoss
に実装がある.
from torch.optim import Adam
= nn.BCELoss()
BCE_loss
def loss_function(x, x_hat, mean, log_var):
= nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
reproduction_loss = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
KLD
return reproduction_loss + KLD
= Adam(model.parameters(), lr=lr) optimizer
ここでの損失関数は,真のデータ x
をデコーダーが復元できているかを交差エントロピーで測った reproduction_loss
と,潜在表現がどれだけ \(\mathrm{N}_d(0,I_d),d=200\) に近いかを KL 乖離度で測った KLD
の和で定義されている.1
VAE の標準的な目的関数 とは違う形をしていることに注意.
1.4 モデルの評価
テスト用データの最初のバッチについて処理し,入力データと出力データを見比べてみる.
eval()
model.
with torch.no_grad():
for batch_idx, (x, _) in enumerate(tqdm(test_loader)):
= x.view(batch_size, x_dim)
x = x.to(DEVICE)
x
= model(x)
x_hat, _, _
break
- 1
- 勾配評価を無効化するコンテクストマネージャーで,メモリの使用を節約できるという.
0%| | 0/100 [00:00<?, ?it/s] 0%| | 0/100 [00:00<?, ?it/s]
import matplotlib.pyplot as plt
def show_image(x, idx):
= x.view(batch_size, 28, 28)
x
= plt.figure()
fig
plt.imshow(x[idx].cpu().numpy())
=0)
show_image(x, idx=0) show_image(x_hat, idx
左が入力で右が出力である.
1.5 データの生成
ここで,エンコーダを取り外してデコーダーからデータを生成する.
損失関数(第 1.3 節)には,潜在空間におけるデータを標準正規分布に近付けるための項が入っていたため,データの潜在表現は極めて標準正規分布に近いとみなすことにする.
すると,潜在表現と同じ次元の正規乱数から,データセットに極めて似通ったデータが生成できるだろう.
with torch.no_grad():
= torch.randn(batch_size, latent_dim).to(DEVICE)
noise = decoder(noise)
generated_images
1, 28, 28), 'generated_sample.png')
save_image(generated_images.view(batch_size, for i in range(4):
=i) show_image(generated_images, idx
2 VQ-VAE (van den Oord et al., 2017)
2.1 導入
= torch.device("mps")
DEVICE
= 128
batch_size = (32, 32)
img_size
= 3
input_dim = 512
hidden_dim = 16
latent_dim = 512
n_embeddings= 3
output_dim = 0.25
commitment_beta
= 2e-4
lr
= 50
epochs
= 50 print_step
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
= transforms.Compose([
mnist_transform
transforms.ToTensor(),
])
= {'num_workers': 1, 'pin_memory': True}
kwargs
= CIFAR10(dataset_path, transform=mnist_transform, train=True, download=True)
train_dataset = CIFAR10(dataset_path, transform=mnist_transform, train=False, download=True)
test_dataset
= DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
train_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, **kwargs) test_loader
Files already downloaded and verified
Files already downloaded and verified
2.2 モデルの定義
2.2.1 エンコーダー
VQ-VAE は画像への応用を念頭に置いているため,エンコーダーには CNN アーキテクチャ を採用する.
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, kernel_size=(4, 4, 3, 1), stride=2):
super(Encoder, self).__init__()
= kernel_size
kernel_1, kernel_2, kernel_3, kernel_4
self.strided_conv_1 = nn.Conv2d(input_dim, hidden_dim, kernel_1, stride, padding=1)
self.strided_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, stride, padding=1)
self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_3, padding=1)
self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_4, padding=0)
self.proj = nn.Conv2d(hidden_dim, output_dim, kernel_size=1)
def forward(self, x):
= self.strided_conv_1(x)
x = self.strided_conv_2(x)
x
= F.relu(x)
x = self.residual_conv_1(x)
y = y+x
y
= F.relu(y)
x = self.residual_conv_2(x)
y = y+x
y
= self.proj(y)
y return y
3 参考文献
本稿は,Minsu Jackson Kang 氏 による チュートリアル を参考にした.
VAE には数々の変種があるが,その PyTorch による簡単な実装は Anand K Subramanian の このレポジトリ にリストアップされている.
VAE の潜在表現は t-SNE などを用いて可視化でき,(Murphy, 2023, p. 635) の例などでも,潜在空間において手書き数字がクラスごとによく分離されていることが確認できる.
References
Footnotes
なお,
mean.pow(2)
は Julia のmean.^2
に同じ.↩︎