VAE:変分自己符号化器

PyTorch によるハンズオン

Deep
Sampling
Python
Author

司馬博文

Published

7/28/2024

概要

変分自己符号化器 (VAE) は,データを周辺分布にもつ潜在変数モデルを変分 Bayes 推論によって学習するアルゴリズムである. 従来計算・近似が困難であった変分下界を,ニューラルネットワークによって近似するアプローチである. 学習されたベイズ潜在変数モデルからはサンプリングによって新たなデータを生成することができるため,深層生成モデルの一つに分類されることもある.

1 VAE (D. Kingma and Welling, 2014)

1.1 導入

PyTorch を用いることで詳細を省略し,VAE の構造を概観することとする.

import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid

今回は,MNIST データセットを用い,隠れ次元 400 を通じて潜在次元 200 まで圧縮する.

dataset_path = '~/hirofumi/datasets'

DEVICE = torch.device("mps")

batch_size = 100

x_dim = 784
hidden_dim = 400
latent_dim = 200

lr = 1e-3

epochs = 30
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


mnist_transform = transforms.Compose([
        transforms.ToTensor(),
])

kwargs = {'num_workers': 0, 'pin_memory': True} 

train_dataset = MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False, **kwargs)

PyTorch の Dataset と DataLoader は,訓練やテスト用のデータセットの簡単なアクセスと,それに対する iterable オブジェクトを提供する.

まず,次のようにして仮想環境を用意する:

python3 -m venv VAE
source VAE/bin/activate
pip install torch

M2 Mac では Metal Performance Shaders (MPS) という Apple の GPU アクセラレーション技術が利用可能で,PyTorch 1.12 からはこれをサポートしている.

import torch
print(torch.__version__)
print(torch.backends.mps.is_available())
2.4.0
True

上記のエラーは,DataLoader が並列処理によりデータを読み込むことに失敗したことを意味する.

メモリ不足も考えられるが,num_workers=0 として単一プロセスで実行することでもエラーが抑えられる.

今回は軽量な計算であるから,これで良いということである.

1.2 モデルの定義

1.2.1 エンコーダー

エンコーダーはデータを受け取り,2層の全結合隠れ層を通じて,「平均」と「対数分散」の名前がついた計 400 次元の潜在表現を得る.

class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean  = nn.Linear(hidden_dim, latent_dim)
        self.FC_var   = nn.Linear(hidden_dim, latent_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
        self.training = True
        
    def forward(self, x):
        h_       = self.LeakyReLU(self.FC_input(x))
        h_       = self.LeakyReLU(self.FC_input2(h_))
        mean     = self.FC_mean(h_)
        log_var  = self.FC_var(h_)
        
        return mean, log_var
1
nn.Linear は PyTorch による全結合層 \(y=xA^\top+b\) の実装である.
2
ここまで2層の全結合層にデータを通して,最終的な出力h_を得ており,次の段階で最終的な潜在表現を得る.
3
最後の隠れ層の出力h_に関して平均と対数分散という名前のついた最終的な出力を,やはり全結合層を通じて得る(最終層なので活性化なし).

1.2.2 デコーダー

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        h     = self.LeakyReLU(self.FC_hidden(x))
        h     = self.LeakyReLU(self.FC_hidden2(h))
        
        x_hat = torch.sigmoid(self.FC_output(h))
        return x_hat
1
最後の出力は,エンコーダーとは違い,シグモイド関数を通して確率分布x_hatとする.

1.2.3 モデル

VAE はエンコーダーとデコーダーを連結し,1つのニューラルネットワークとして学習する.

class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(DEVICE)
        z = mean + var*epsilon
        return z
        
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat            = self.Decoder(z)
        
        return x_hat, mean, log_var
1
これは サンプリングイプシロン と呼ばれる値である.
2
ここで reparametrization trick を行っている.
3
入力 x があったならば,まずエンコーダーに通して mean, log_var を得る.
4
元々 log_var の名前の通り対数分散として扱うこととしていたので,2で割り指数関数に通すことで標準偏差を得る.この平均と標準偏差について reparametrization trick を実行し,デコーダーに繋ぐ.
5
デコーダーではデータの潜在表現 z を受け取り,デコードしたものを x_hat とする.
6
返り値は,デコーダーの出力 x_hat だけでなく,潜在表現 mean, log_var も含むことに注意.
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)

model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)
1
.to(DEVICE) により,モデルを M2 Mac の MPS デバイス上に移送している.

1.3 モデルの訓練

最適化には Adam (D. P. Kingma and Ba, 2017) を用い,バイナリ交差エントロピー(BCE)を用いる.これは nn.BCELoss に実装がある.

from torch.optim import Adam

BCE_loss = nn.BCELoss()

def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD      = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD


optimizer = Adam(model.parameters(), lr=lr)

ここでの損失関数は,真のデータ x をデコーダーが復元できているかを交差エントロピーで測った reproduction_loss と,潜在表現がどれだけ \(\mathrm{N}_d(0,I_d),d=200\) に近いかを KL 乖離度で測った KLD の和で定義されている.1

VAE の標準的な目的関数 とは違う形をしていることに注意.

import time

print("Start training VAE...")
model.train()

start_time = time.time()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.view(batch_size, x_dim)
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)
        
        overall_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / (batch_idx*batch_size))

total_time = time.time() - start_time
print("Finish!! Total time: ", total_time)
1
PyTorch のモデルオブジェクトを訓練モードにするメソッド.Dropout や Batch Normalization 層がある場合は,これにより訓練時の挙動を示すようになる.
2
事前に定めた batch_size に従ってバッチを展開.
3
データを GPU に移動.
4
勾配をゼロに初期化するとのこと.
Start training VAE...
    Epoch 1 complete!   Average Loss:  174.45440591089314
    Epoch 2 complete!   Average Loss:  129.366964569856
    Epoch 3 complete!   Average Loss:  116.69913125065213
    Epoch 4 complete!   Average Loss:  112.35357475675605
    Epoch 5 complete!   Average Loss:  109.92540430339629
    Epoch 6 complete!   Average Loss:  108.33972220954195
    Epoch 7 complete!   Average Loss:  107.14160705668301
    Epoch 8 complete!   Average Loss:  106.20647785371452
    Epoch 9 complete!   Average Loss:  105.42789199446995
    Epoch 10 complete!  Average Loss:  104.9622272798414
    Epoch 11 complete!  Average Loss:  104.39062490218072
    Epoch 12 complete!  Average Loss:  103.96445004369261
    Epoch 13 complete!  Average Loss:  103.49826466963168
    Epoch 14 complete!  Average Loss:  103.10589126408598
    Epoch 15 complete!  Average Loss:  102.75672558104654
    Epoch 16 complete!  Average Loss:  102.47924416671015
    Epoch 17 complete!  Average Loss:  102.25440002543301
    Epoch 18 complete!  Average Loss:  102.09212112961707
    Epoch 19 complete!  Average Loss:  101.80551883347245
    Epoch 20 complete!  Average Loss:  101.61719139646807
    Epoch 21 complete!  Average Loss:  101.45744191164962
    Epoch 22 complete!  Average Loss:  101.27843764672892
    Epoch 23 complete!  Average Loss:  101.1599205544397
    Epoch 24 complete!  Average Loss:  101.04871442638773
    Epoch 25 complete!  Average Loss:  100.8713441999687
    Epoch 26 complete!  Average Loss:  100.73917756808223
    Epoch 27 complete!  Average Loss:  100.68475770163815
    Epoch 28 complete!  Average Loss:  100.58486650928631
    Epoch 29 complete!  Average Loss:  100.44012970836812
    Epoch 30 complete!  Average Loss:  100.39852615687604
Finish!! Total time:  130.87328815460205

1.4 モデルの評価

テスト用データの最初のバッチについて処理し,入力データと出力データを見比べてみる.

model.eval()

with torch.no_grad():
    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):
        x = x.view(batch_size, x_dim)
        x = x.to(DEVICE)
        
        x_hat, _, _ = model(x)


        break
1
勾配評価を無効化するコンテクストマネージャーで,メモリの使用を節約できるという.
  0%|          | 0/100 [00:00<?, ?it/s]  0%|          | 0/100 [00:00<?, ?it/s]
import matplotlib.pyplot as plt

def show_image(x, idx):
    x = x.view(batch_size, 28, 28)

    fig = plt.figure()
    plt.imshow(x[idx].cpu().numpy())

show_image(x, idx=0)
show_image(x_hat, idx=0)
(a) 左がテストデータ,右がその VAE による復元
(b)
図 1

左が入力で右が出力である.

1.5 データの生成

ここで,エンコーダを取り外してデコーダーからデータを生成する.

損失関数(第 1.3 節)には,潜在空間におけるデータを標準正規分布に近付けるための項が入っていたため,データの潜在表現は極めて標準正規分布に近いとみなすことにする.

すると,潜在表現と同じ次元の正規乱数から,データセットに極めて似通ったデータが生成できるだろう.

with torch.no_grad():
    noise = torch.randn(batch_size, latent_dim).to(DEVICE)
    generated_images = decoder(noise)

save_image(generated_images.view(batch_size, 1, 28, 28), 'generated_sample.png')
for i in range(4):
    show_image(generated_images, idx=i)
(a)
(b)
(c)
(d)
図 2

2 VQ-VAE (van den Oord et al., 2017)

2.1 導入

DEVICE = torch.device("mps")

batch_size = 128
img_size = (32, 32)

input_dim = 3
hidden_dim = 512
latent_dim = 16
n_embeddings= 512
output_dim = 3
commitment_beta = 0.25

lr = 2e-4

epochs = 50

print_step = 50
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


mnist_transform = transforms.Compose([
        transforms.ToTensor(),
])

kwargs = {'num_workers': 1, 'pin_memory': True} 

train_dataset = CIFAR10(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = CIFAR10(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False,  **kwargs)
Files already downloaded and verified
Files already downloaded and verified

2.2 モデルの定義

2.2.1 エンコーダー

VQ-VAE は画像への応用を念頭に置いているため,エンコーダーには CNN アーキテクチャ を採用する.

class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, kernel_size=(4, 4, 3, 1), stride=2):
        super(Encoder, self).__init__()
        
        kernel_1, kernel_2, kernel_3, kernel_4 = kernel_size
        
        self.strided_conv_1 = nn.Conv2d(input_dim, hidden_dim, kernel_1, stride, padding=1)
        self.strided_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, stride, padding=1)
        
        self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_3, padding=1)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_4, padding=0)
        
        self.proj = nn.Conv2d(hidden_dim, output_dim, kernel_size=1)
        
    def forward(self, x):
        
        x = self.strided_conv_1(x)
        x = self.strided_conv_2(x)
        
        x = F.relu(x)
        y = self.residual_conv_1(x)
        y = y+x
        
        x = F.relu(y)
        y = self.residual_conv_2(x)
        y = y+x
        
        y = self.proj(y)
        return y

3 参考文献

本稿は,Minsu Jackson Kang 氏 による チュートリアル を参考にした.

VAE には数々の変種があるが,その PyTorch による簡単な実装は Anand K Subramanianこのレポジトリ にリストアップされている.

VAE の潜在表現は t-SNE などを用いて可視化でき,(Murphy, 2023, p. 635) の例などでも,潜在空間において手書き数字がクラスごとによく分離されていることが確認できる.

References

Kingma, D. P., and Ba, J. (2017). Adam: A method for stochastic optimization.
Kingma, D., and Welling, M. (2014). Auto-encoding variational bayes. In International conference on learning representations,Vol. 2.
Murphy, K. P. (2023). Probabilistic machine learning: Advanced topics. MIT Press.
van den Oord, A., Vinyals, O., and Kavukcuoglu, K. (2017). Neural discrete representation learning. In Advances in neural information processing systems,Vol. 30.

Footnotes

  1. なお,mean.pow(2) は Julia の mean.^2 に同じ.↩︎