Optie研

パソコンで絵や動画を作る方法について

Pytorch Lightning で生成モデル — Autoencoder

 最近は主に画像の深層生成モデルに取り組んでいます。ライブラリとしてはしばらくは PyTorch を使っていたのですが、最近、 構造的に書くことのできるラッパーとして Pytorch Lightning を使い始めました。

 今回、練習としてオートエンコーダを実装し、手書き数字データセットである MNIST およびその変種を用いて簡単な実験を行ったので、記事にまとめます。

Autoencoder は次元の圧縮と復元を行う

 オートエンコーダは、次元圧縮を行うエンコーダと、次元復元を行うデコーダのペアからなる生成モデルです。

f:id:Optie_f:20200811063227j:plain

 上図のように、「データセットの元をエンコーダに通して低次元空間で表現したのち、そのままデコーダに通して元の次元に復元した際、できるだけもとのデータを再現できるようにする」という方式で学習が行われます。ここで、データ空間より次元の小さい圧縮先の空間を潜在空間と称しています。

 今回は最も単純な例として、いくつかの 全結合層+活性化関数 のみからなるオートエンコーダを用います。また、潜在空間の次元は図示のために 2 次元としておきます。

Pytorch Lightning で整ったコードを書く

 PyTorch Lightning は、PyTorch コードを構造化するためのライブラリです。通常の PyTorch では、学習を行うモデルのクラスは nn.Module を継承しますが、 Lightning では、nn.Module を継承している pl.LighningModule を継承し、configure_optimizer および training_step メソッドを書きます。

以下2つのコードブロックは公式ドキュメントの例です。

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0005)

 通常の PyTorch での学習では、エポック単位およびバッチ単位で for ループを回していました。 training_step メソッドは、そうしたループの内側の処理に相当するメソッドであり、各バッチを受け取って損失を返します。Lightning では、for ループを回す部分をクラスにまかせることができるため、訓練を実行する部分は以下のように非常に簡単になります。

# dataloader
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# init model
model = LitModel()

# most basic trainer, uses good defaults
trainer = pl.Trainer()
trainer.fit(model, train_loader)

 その他、GPU を利用する際は Trainer のコンストラクタに使用する GPU 数を与えるだけでよく、GPUにデータを移すコードを書く必要がなくなります。また、並列化の際に、データセットのダウンロード処理が一度かつ一度だけ行われることを保証したりなど、細かなところで色々と面倒を見てくれます。pl.LightningModulenn.Module を継承しており、機能的に上位互換であるため、 nn.Module で使えるメソッドはすべて nn.Module に期待する通りに使えます。

 類似のラッパーとしては、catalyst, fastai, ignite などがあります。私は他を実際に試したわけではないので、比較については以下の記事などが参考になるかと思います。

PyTorch 三国志(Ignite・Catalyst・Lightning) - Qiita

 今回、2020年8月10日現在で stable には未実装な pl.DataModule を使用するため 、バージョンは 0.9.0rc2 を使用しました。

pip install pytorch-lightning==0.9.0rc2

Lightning はまだ破壊的変更も多く、過去の記事のコードをコピペしてもそのままでは動かないことが多いです。設計思想が大幅に変わるわけではないにせよ、バージョンの合った公式ドキュメントを確認することをおすすめします。

実装

モデル本体

 import なども含めたコードの全容は github

https://github.com/optie-f/PL_AutoEncoder

を見てもらうとして、主な部分を解説します。

 まず、線形層と活性化関数からなるモジュールを作ります。

class LAN(nn.Module):
    """
    Linear-Activation-Normalization.
    """

    def __init__(self, in_dim, out_dim, activation=nn.ReLU(True), norm=nn.Identity()):
        super(LAN, self).__init__()
        self.L = nn.Linear(in_dim, out_dim)
        self.A = activation
        self.N = norm

    def forward(self, x):
        z = self.L(x)
        z = self.A(z)
        z = self.N(z)
        return z

一応 BatchNormalization などを配置できるようにもしておきましたが、今回は使用しません。

 次に、エンコーダおよびデコーダとなるモジュールです。

class DENcoder(nn.Module):
    def __init__(self, dimentions, mid_activation, last_activation):
        super(DENcoder, self).__init__()
        layers = []
        in_dims = dimentions[:-1]

        for i, in_dim in enumerate(in_dims):
            activation = last_activation if i + 1 == len(in_dims) else mid_activation
            out_dim = dimentions[i + 1]
            layers.append(
                LAN(in_dim, out_dim, activation=activation)
            )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

encoder でも decoder でも使うので DENcoder です。他の名称は思いつきませんでした。コンストラクタでは、dimentionsとして「入力データの次元および各層の出力の次元」の一次元整数配列、mid_activation として中間層の(最終出力以外の)活性化関数、last_activation として最終出力への活性化関数をそれぞれ渡して、インスタンスを作成します。

 なぜ各層を直接書かないか。全結合層を積んでいくときには、 [Linear(d1, d2), Linear(d2, d3), ...] というように、各層の入力次元を前の層の出力次元と揃える必要があります。これを直接書くのではなく、[d1, d2, d3, ... ] という配列を渡して作ることで、より指定しやすく、また、深さや幅を調整しやすくするという意図でこのような実装にしています。

 次に、全体のモデルです。ここで pl.LightningModule が登場します。以下で各メソッドについて説明します。

class AutoEncoder(pl.LightningModule):
    def __init__(self, in_dimentions):
        super(AutoEncoder, self).__init__()
        out_dimentions = list(reversed(in_dimentions))
        self.encoder = DENcoder(in_dimentions, nn.ReLU(True), nn.Identity())
        self.decoder = DENcoder(out_dimentions, nn.ReLU(True), nn.Tanh())
        self.criterion = nn.MSELoss()

    def forward(self, img):
        x = img.view(img.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        img_recon = x_hat.view(img.size())
        return img_recon

    def training_step(self, batch, batch_idx):
        img, _ = batch
        img_recon = self.forward(img)
        loss = self.criterion(img, img_recon)

        if self.global_step % self.trainer.row_log_interval == 0:
            sqrt_nb = int(sqrt(img.size(0)))
            self.logger.experiment.add_image(
                "image/original",
                make_grid(img, sqrt_nb, normalize=True, range=(-1, 1)),
                self.global_step
            )
            self.logger.experiment.add_image(
                "image/reconstructed",
                make_grid(img_recon, sqrt_nb, normalize=True, range=(-1, 1)),
                self.global_step
            )

        return {'loss': loss, 'log': {'loss': loss}}

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)
        return optimizer

 コンストラクタについて。ここでは、エンコーダとデコーダの幅と深さは対称にしているため、コンストラクタでは in_dimentions 引数としてエンコーダ用の配列のみを受け取り、デコーダはこれを反転させることで作っています。また、損失関数を self.criterion に代入しており、今回は平均自乗誤差です。

 forward メソッドについて。このモデルへ入力されるのは (batchsize, C, H, W) という形状の画像ですが、全結合層への入力では一次元にする必要があるので、 (batchsize, CxHxW) というベクトルの形に変形し、出力時に再び画像の形に戻す処理を行っています。

 training_step メソッドについて。前述の通り、訓練ループの内側に相当する関数で、batch および batch の添字を受け取り、損失を返します。

 Lightning では、Trainer のコンストラクタの引数に logger (後述)を渡すことで、 Tensorboard などを用いて学習のリアルタイム計測監視を行うことができます。これに関係する処理もtraining_step に書きます。例えば上記のコードのように、返り値を dict にした上で、'log': {'loss': loss,...} のような要素を追加することで、損失関数など指定したスカラー値の訓練中の変動がプロットされます。

 スカラーのプロットのほか、画像などを表示する場合は明示的にメソッドを呼び出す必要があるようです。ここでは、訓練における現在のステップ数が記録されている self.global_step および、 Trainer のコンストラクタに渡す self.trainer.row_log_interval から、「特定のステップ数のときに、元画像と復元画像を並べて表示する」という処理を書いています。

DataModule

 ここで、pl.LightningDataModule というものを用意しておきます。これは、train,validation,test 用のdataset および Dataloader を統一的に管理するものであり、trainer.fit() あるいは trainer.test() に渡すことで、訓練時・バリデーション時・テスト時のデータセットの読み分けを行います。

dataset_classes = {
    'MNIST': MNIST,
    'KMNIST': KMNIST,
    'FashionMNIST': FashionMNIST,
    'CIFAR10': CIFAR10
}

class NormalizeAllChannel():
    """
    a normalization transform using given mean & std
    but accepts any number of channels.
    By default, [0, 1] -> [-1, 1]
    """

    def __init__(self, mean=0.5, std=0.5):
        self.mean = 0.5
        self.std = 0.5

    def __call__(self, x):
        return (x - self.mean) / self.std

class DataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./data', data_name='', batchsize=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batchsize
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            NormalizeAllChannel()
        ])
        if data_name in dataset_classes:
            self.Dataset = dataset_classes[data_name]
        else:
            raise NotImplementedError

    def prepare_data(self):
        self.Dataset(self.data_dir, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.dataset = self.Dataset(
                self.data_dir,
                train=True,
                transform=self.transform
            )
            size = len(self.dataset)
            t, v = (int(size * 0.9), int(size * 0.1))
            t += (t + v != size)
            self.dataset_train, self.dataset_val = random_split(self.dataset, [t, v])

        if stage == 'test' or stage is None:
            self.dataset_test = self.Dataset(
                self.data_dir,
                train=False,
                transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self.dataset_test,
            batch_size=self.batch_size,
        )

必須なメソッドは prepare_data, setup,train_dataloader の 3 つです。今回は、データセットとして MNIST, KMNIST, FashionMNIST, CIFAR10 のみを扱う予定なので、コンストラクタの引数で文字列で data_name を受け取り、これに応じて Dataset クラスを保持しておきます(インスタンスではなく、クラスそのものです)。

 まず prepare_data が呼ばれます。ここでデータのダウンロード処理を行うことが想定されているようです。ここに書いておくことで、並列化した際にもうまくやってくれるようです。

次に setup が呼ばれます。このとき、呼ばれたのが trainer.fit()trainer.test() かに応じて、stage 引数に 'fit''test' かが渡されるので、これに応じて dataset オブジェクトを準備しておきます。

 train_dataloader はそのあとに呼ばれるようです。DataLoader オブジェクトを返します。

訓練

訓練用のコードは以下です。上から順番に見ていきます。

def train(hparams):
    train_loader = DataModule(data_name=hparams.dataset_name)
    train_loader.prepare_data()
    train_loader.setup()

    in_dim = np.prod(train_loader.dataset[0][0].size())
    dimentions = [in_dim, 512, 128, 64, 12, 2]

    autoEncoder = AutoEncoder(dimentions)
    print(autoEncoder)

    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor='loss',
        mode='min',
        prefix=''
    )

    logger = TensorBoardLogger('log', name=hparams.dataset_name)
    trainer = Trainer(
        logger=logger,
        default_root_dir='./log',
        checkpoint_callback=checkpoint_callback,
        row_log_interval=50,
        max_epochs=hparams.max_epochs,
        gpus=hparams.gpus,
        tpu_cores=hparams.tpu_cores
    )

    trainer.fit(autoEncoder, train_loader)

def main():
    parser = ArgumentParser()
    parser.add_argument('--gpus', default=None)
    parser.add_argument('--tpu_cores', default=None)
    parser.add_argument('--max_epochs', default=50)
    parser.add_argument('--dataset_name', default="MNIST")
    args = parser.parse_args()

    train(args)

if __name__ == "__main__":
    main()

 本来、train_loader.prepare_data() および train_loader.setup() を我々で呼ぶ必要はないですし、むしろおそらく呼ぶべきではないのですが、ここでは入力の次元をデータセットから計算する必要があるために初めに呼んでしまっています。ここはもっと良いやり方があるのかもしれません。

 ModelCheckpoint は学習しながらモデルを保存するためのコールバックを司るクラスで、インスタンスを Trainer のコンストラクタに渡して使います。loss などの指定した指標が最も少なくなった上位 k 件を保存する、などといった指定ができます。

 TensorBoardLogger は、Tensorboard オブジェクトの Lightning 向けラッパーで、こちらもインスタンスを Trainer に渡して使います。

 これらを Trainer のコンストラクタに渡すことでインスタンスを得ています。使用する GPU 数や TPU のコア数もここで指定します。

実験と可視化

MNISTの潜在空間

 学習および結果の確認は Google Colabolatory で行いました。以下のコードブロックはそのままノートブックのセルへの記述です。

 まず、リポジトリからクローンし、Lightning のインストールを行って作業ディレクトリを移動します。

!git clone https://github.com/optie-f/PL_AutoEncoder
!pip install pytorch-lightning==0.9.0rc2
%cd PL_AutoEncoder

  学習しながら確認するため、あらかじめ Tensorboard を表示しておきます。

%load_ext tensorboard
%tensorboard --logdir log/

 その後、訓練を始めます。前述のコードにある通り、--dataset_name のデフォルト値は MNIST になっているので、MNIST に対する訓練が始まります。

!python train.py --gpus 1

 訓練が始まったあと、さきほど表示した Tensorboard を更新すると、リアルタイムで損失や復元画像を確認することができます。以下図は訓練終了後のものですが、左側がデータセットの画像、右側が復元された画像です。

f:id:Optie_f:20200811053435p:plain

ややぼやけており、最上段の「8」や「5」など怪しいものも散見されますが、概ね元の数字形状は視認できるといってよいでしょう。

 また、テストデータから 10000 件をエンコードし、2次元潜在空間における散布図を作成し、ラベル(実際に何の数であるか)ごとに色分けをすると、以下のようになりました。

f:id:Optie_f:20200811053850p:plain

なんとなく各数字ごとにまとまっていた位置に現れているような気はします。これだけではわかりにくいので、今度は逆に、潜在空間内で、$(0, 0)$ を中心とした $50 × 50$, 幅 1 のグリッドをとり、格子点をデコーダを通すことでどのような数字として現れるかを見ることにします。

f:id:Optie_f:20200811054130p:plain

例えば「8」と「2」が隣接した領域にあり、その間には「8」とも「2」ともつかないような数字が並んでいたりすることが観察されます。

 こうした潜在空間における分布は訓練をするたびに異なった現れ方をします。別の試行での潜在空間は以下のようなものでした。

f:id:Optie_f:20200811054558p:plain

様子は異なっていますが、概ね中心から放射状に広がっているように見えます。

潜在空間を移動する点を逐次復元することで、文字が変化するアニメーションなども作れそうなイメージはつきますね。

可視化に使用したコードは以下の通りです。

from modules.data_loader import DataModule
from modules.model import AutoEncoder
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os


data_name = "MNIST"
testsize = 10000

# 画像(とラベルの読み込み)
dataModule = DataModule(data_name=data_name, batchsize=testsize)
dataModule.setup(stage="test")
test_loader = dataModule.test_dataloader()
images, labels = iter(test_loader).next()

# 最新バージョンの最新チェックポイントのパス
versions_dir = os.path.join('./log/', data_name)
versions = sorted(os.listdir(versions_dir))
ckpts_dir = os.path.join(versions_dir, versions[-1], 'checkpoints/')
ckpts = sorted(os.listdir(ckpts_dir))
ckpt_path = os.path.join(ckpts_dir, ckpts[-1])

# モデルの読み込み
dimentions = [np.prod(images[0].size()), 512, 128, 64, 12, 2]
autoEncoder = AutoEncoder(dimentions)
autoEncoder.load_state_dict(torch.load(ckpt_path)["state_dict"])
autoEncoder.freeze()

# テストデータの 10000 枚の画像が潜在空間でどのように分布するかの可視化
z = autoEncoder.encoder(images.view(10000, -1))
np_z = z.to('cpu').detach().numpy().copy()
np_labels = labels.to('cpu').detach().numpy().copy()

sns.set()
plt.figure(figsize=(12, 10))
plt.scatter(np_z[:,0], np_z[:,1], c=np_labels, s=1)
plt.colorbar()
plt.show()

# 潜在空間での格子点を考える。その各点はどのような画像になるかの可視化
latent_grid = np.array([[(x+0.5), (y+0.5)] for x in range(-25,25) for y in range(-25,25)])
z_grid = torch.tensor(latent_grid, dtype=torch.float32)
recon = autoEncoder.decoder(z_grid)
np_recon = recon.view(50, 50, 28, 28).to('cpu').detach().numpy().copy()

sns.set(style='dark')
plt.figure(figsize=(20, 20))
plt.imshow(np_recon.swapaxes(1,2).reshape(28*50, 28*50), cmap='gray')
plt.show()

FashionMNIST, KMNIST, CIFAR10 でも試す

 今回はMNISTのほか、ファッション雑貨の低解像度写真?データセット FasionMNIST、崩し字データセットの KMNIST、カラー写真の CIFAR10 でも同様の実験を行っていました。以下に結果を載せます。

FasionMNIST

 ファッション雑貨データセットです。「MNISTは簡単すぎてベンチマークにもならない」という課題から作られたそうです。

復元の様子。

f:id:Optie_f:20200811055344p:plain

ストライプ模様のシャツなどは模様がぼやけてしまっていますが、種別は保持されていそうです。ただ左上のように暗めの色のズボンなどは何が何だか分からなくなっています。こうした点から、MNISTの数字認識からの難易度上昇を感じさせます。

潜在空間マップ。

f:id:Optie_f:20200811055412p:plain

長袖の上着とカバンが接する領域、長袖からワンピースへの遷移などが面白いですね。

KMNIST

 崩し字です。こちらは比較的最近のデータセットではなかったでしょうか。字種が多く、しかも各種類ごとに数が均等ではないそうで、認識問題としての難易度は跳ね上がりそうです。

復元の様子。

f:id:Optie_f:20200811055956p:plain

「つ」「ハ」「れ」は復元できているのはわかりますが、それらと形が近い字が「つ」「ハ」「れ」に吸収されてしまっている感じがあります。

潜在空間マップ。

f:id:Optie_f:20200811060942p:plain

左左下あたりが存在しない字になっている気がします。

CIFAR10

 一般物体認識に用いられる写真のデータセットです。MNIST が $28 \times 28 \times 1 = 784$ 次元であり、MNIST変種もその程度であったのに対して、CIFAR10 は $32 \times 32 \times 3 = 3072$ 次元あります。そもそも airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck というクラスの写真分類であり、一つのクラス内でもMNISTよりずっとバリエーションが豊かです。つまり、難しいです。

復元の様子。

f:id:Optie_f:20200811061649p:plain

マッピングをしたかったので潜在空間を2次元にすることにこだわっていたのですが、このようにまったく復元できていないので、限界がありそうです。

よくわからなかった点

前項の可視化コードで、モデルを読み込む際に以下のようにしていますが、

dimentions = [np.prod(images[0].size()), 512, 128, 64, 12, 2]
autoEncoder = AutoEncoder(dimentions)
autoEncoder.load_state_dict(torch.load(ckpt_path)["state_dict"])

これは本来、公式ドキュメントの例を見る限り、以下のように書くだけでよいはずです。

autoEncoder = AutoEncoder.load_from_checkpoint(ckpt_path)

しかし、これを実行すると、このようなエラーに遭遇します。

46     def __init__(self, in_dimentions):
     47         super(AutoEncoder, self).__init__()
---> 48         out_dimentions = list(reversed(in_dimentions))
     49         self.encoder = DENcoder(in_dimentions, nn.ReLU(True), nn.Identity())
     50         self.decoder = DENcoder(out_dimentions, nn.ReLU(True), nn.Tanh())

TypeError: 'dict' object is not reversible

これはバグということになるのでしょうか。追って Lightning のコードなどを確認できればしたいところです。

おわりに

PyTorch Lightning でオートエンコーダを実装しました。次があれば、今回のコードを拡張しながら変分オートエンコーダ(VAE)でも実装しようと思います。