Pytorch Lightning で生成モデル — Autoencoder
最近は主に画像の深層生成モデルに取り組んでいます。ライブラリとしてはしばらくは PyTorch を使っていたのですが、最近、 構造的に書くことのできるラッパーとして Pytorch Lightning を使い始めました。
今回、練習としてオートエンコーダを実装し、手書き数字データセットである MNIST およびその変種を用いて簡単な実験を行ったので、記事にまとめます。
Autoencoder は次元の圧縮と復元を行う
オートエンコーダは、次元圧縮を行うエンコーダと、次元復元を行うデコーダのペアからなる生成モデルです。
上図のように、「データセットの元をエンコーダに通して低次元空間で表現したのち、そのままデコーダに通して元の次元に復元した際、できるだけもとのデータを再現できるようにする」という方式で学習が行われます。ここで、データ空間より次元の小さい圧縮先の空間を潜在空間と称しています。
今回は最も単純な例として、いくつかの 全結合層+活性化関数 のみからなるオートエンコーダを用います。また、潜在空間の次元は図示のために 2 次元としておきます。
Pytorch Lightning で整ったコードを書く
PyTorch Lightning は、PyTorch コードを構造化するためのライブラリです。通常の PyTorch では、学習を行うモデルのクラスは nn.Module
を継承しますが、 Lightning では、nn.Module
を継承している pl.LighningModule
を継承し、configure_optimizer
および training_step
メソッドを書きます。
以下2つのコードブロックは公式ドキュメントの例です。
import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy class LitModel(pl.LightningModule): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.0005)
通常の PyTorch での学習では、エポック単位およびバッチ単位で for ループを回していました。 training_step
メソッドは、そうしたループの内側の処理に相当するメソッドであり、各バッチを受け取って損失を返します。Lightning では、for ループを回す部分をクラスにまかせることができるため、訓練を実行する部分は以下のように非常に簡単になります。
# dataloader dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) train_loader = DataLoader(dataset) # init model model = LitModel() # most basic trainer, uses good defaults trainer = pl.Trainer() trainer.fit(model, train_loader)
その他、GPU を利用する際は Trainer のコンストラクタに使用する GPU 数を与えるだけでよく、GPUにデータを移すコードを書く必要がなくなります。また、並列化の際に、データセットのダウンロード処理が一度かつ一度だけ行われることを保証したりなど、細かなところで色々と面倒を見てくれます。pl.LightningModule
は nn.Module
を継承しており、機能的に上位互換であるため、 nn.Module
で使えるメソッドはすべて nn.Module
に期待する通りに使えます。
類似のラッパーとしては、catalyst, fastai, ignite などがあります。私は他を実際に試したわけではないので、比較については以下の記事などが参考になるかと思います。
PyTorch 三国志(Ignite・Catalyst・Lightning) - Qiita
今回、2020年8月10日現在で stable には未実装な pl.DataModule
を使用するため 、バージョンは 0.9.0rc2
を使用しました。
pip install pytorch-lightning==0.9.0rc2
Lightning はまだ破壊的変更も多く、過去の記事のコードをコピペしてもそのままでは動かないことが多いです。設計思想が大幅に変わるわけではないにせよ、バージョンの合った公式ドキュメントを確認することをおすすめします。
実装
モデル本体
import なども含めたコードの全容は github
https://github.com/optie-f/PL_AutoEncoder
を見てもらうとして、主な部分を解説します。
まず、線形層と活性化関数からなるモジュールを作ります。
class LAN(nn.Module): """ Linear-Activation-Normalization. """ def __init__(self, in_dim, out_dim, activation=nn.ReLU(True), norm=nn.Identity()): super(LAN, self).__init__() self.L = nn.Linear(in_dim, out_dim) self.A = activation self.N = norm def forward(self, x): z = self.L(x) z = self.A(z) z = self.N(z) return z
一応 BatchNormalization などを配置できるようにもしておきましたが、今回は使用しません。
次に、エンコーダおよびデコーダとなるモジュールです。
class DENcoder(nn.Module): def __init__(self, dimentions, mid_activation, last_activation): super(DENcoder, self).__init__() layers = [] in_dims = dimentions[:-1] for i, in_dim in enumerate(in_dims): activation = last_activation if i + 1 == len(in_dims) else mid_activation out_dim = dimentions[i + 1] layers.append( LAN(in_dim, out_dim, activation=activation) ) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x)
encoder でも decoder でも使うので DENcoder です。他の名称は思いつきませんでした。コンストラクタでは、dimentions
として「入力データの次元および各層の出力の次元」の一次元整数配列、mid_activation
として中間層の(最終出力以外の)活性化関数、last_activation
として最終出力への活性化関数をそれぞれ渡して、インスタンスを作成します。
なぜ各層を直接書かないか。全結合層を積んでいくときには、 [Linear(d1, d2), Linear(d2, d3), ...]
というように、各層の入力次元を前の層の出力次元と揃える必要があります。これを直接書くのではなく、[d1, d2, d3, ... ]
という配列を渡して作ることで、より指定しやすく、また、深さや幅を調整しやすくするという意図でこのような実装にしています。
次に、全体のモデルです。ここで pl.LightningModule
が登場します。以下で各メソッドについて説明します。
class AutoEncoder(pl.LightningModule): def __init__(self, in_dimentions): super(AutoEncoder, self).__init__() out_dimentions = list(reversed(in_dimentions)) self.encoder = DENcoder(in_dimentions, nn.ReLU(True), nn.Identity()) self.decoder = DENcoder(out_dimentions, nn.ReLU(True), nn.Tanh()) self.criterion = nn.MSELoss() def forward(self, img): x = img.view(img.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) img_recon = x_hat.view(img.size()) return img_recon def training_step(self, batch, batch_idx): img, _ = batch img_recon = self.forward(img) loss = self.criterion(img, img_recon) if self.global_step % self.trainer.row_log_interval == 0: sqrt_nb = int(sqrt(img.size(0))) self.logger.experiment.add_image( "image/original", make_grid(img, sqrt_nb, normalize=True, range=(-1, 1)), self.global_step ) self.logger.experiment.add_image( "image/reconstructed", make_grid(img_recon, sqrt_nb, normalize=True, range=(-1, 1)), self.global_step ) return {'loss': loss, 'log': {'loss': loss}} def configure_optimizers(self): optimizer = optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5) return optimizer
コンストラクタについて。ここでは、エンコーダとデコーダの幅と深さは対称にしているため、コンストラクタでは in_dimentions
引数としてエンコーダ用の配列のみを受け取り、デコーダはこれを反転させることで作っています。また、損失関数を self.criterion に代入しており、今回は平均自乗誤差です。
forward
メソッドについて。このモデルへ入力されるのは (batchsize, C, H, W) という形状の画像ですが、全結合層への入力では一次元にする必要があるので、 (batchsize, CxHxW) というベクトルの形に変形し、出力時に再び画像の形に戻す処理を行っています。
training_step
メソッドについて。前述の通り、訓練ループの内側に相当する関数で、batch および batch の添字を受け取り、損失を返します。
Lightning では、Trainer のコンストラクタの引数に logger (後述)を渡すことで、 Tensorboard などを用いて学習のリアルタイム計測監視を行うことができます。これに関係する処理もtraining_step
に書きます。例えば上記のコードのように、返り値を dict にした上で、'log': {'loss': loss,...}
のような要素を追加することで、損失関数など指定したスカラー値の訓練中の変動がプロットされます。
スカラーのプロットのほか、画像などを表示する場合は明示的にメソッドを呼び出す必要があるようです。ここでは、訓練における現在のステップ数が記録されている self.global_step
および、 Trainer のコンストラクタに渡す self.trainer.row_log_interval
から、「特定のステップ数のときに、元画像と復元画像を並べて表示する」という処理を書いています。
DataModule
ここで、pl.LightningDataModule
というものを用意しておきます。これは、train,validation,test 用のdataset および Dataloader を統一的に管理するものであり、trainer.fit()
あるいは trainer.test()
に渡すことで、訓練時・バリデーション時・テスト時のデータセットの読み分けを行います。
dataset_classes = { 'MNIST': MNIST, 'KMNIST': KMNIST, 'FashionMNIST': FashionMNIST, 'CIFAR10': CIFAR10 } class NormalizeAllChannel(): """ a normalization transform using given mean & std but accepts any number of channels. By default, [0, 1] -> [-1, 1] """ def __init__(self, mean=0.5, std=0.5): self.mean = 0.5 self.std = 0.5 def __call__(self, x): return (x - self.mean) / self.std class DataModule(pl.LightningDataModule): def __init__(self, data_dir='./data', data_name='', batchsize=32): super().__init__() self.data_dir = data_dir self.batch_size = batchsize self.transform = transforms.Compose([ transforms.ToTensor(), NormalizeAllChannel() ]) if data_name in dataset_classes: self.Dataset = dataset_classes[data_name] else: raise NotImplementedError def prepare_data(self): self.Dataset(self.data_dir, download=True) def setup(self, stage=None): if stage == 'fit' or stage is None: self.dataset = self.Dataset( self.data_dir, train=True, transform=self.transform ) size = len(self.dataset) t, v = (int(size * 0.9), int(size * 0.1)) t += (t + v != size) self.dataset_train, self.dataset_val = random_split(self.dataset, [t, v]) if stage == 'test' or stage is None: self.dataset_test = self.Dataset( self.data_dir, train=False, transform=self.transform ) def train_dataloader(self): return DataLoader( self.dataset_train, batch_size=self.batch_size, ) def val_dataloader(self): return DataLoader( self.dataset_val, batch_size=self.batch_size, ) def test_dataloader(self): return DataLoader( self.dataset_test, batch_size=self.batch_size, )
必須なメソッドは prepare_data
, setup
,train_dataloader
の 3 つです。今回は、データセットとして MNIST, KMNIST, FashionMNIST, CIFAR10 のみを扱う予定なので、コンストラクタの引数で文字列で data_name を受け取り、これに応じて Dataset クラスを保持しておきます(インスタンスではなく、クラスそのものです)。
まず prepare_data
が呼ばれます。ここでデータのダウンロード処理を行うことが想定されているようです。ここに書いておくことで、並列化した際にもうまくやってくれるようです。
次に setup
が呼ばれます。このとき、呼ばれたのが trainer.fit()
か trainer.test()
かに応じて、stage
引数に 'fit'
か 'test'
かが渡されるので、これに応じて dataset オブジェクトを準備しておきます。
train_dataloader
はそのあとに呼ばれるようです。DataLoader オブジェクトを返します。
訓練
訓練用のコードは以下です。上から順番に見ていきます。
def train(hparams): train_loader = DataModule(data_name=hparams.dataset_name) train_loader.prepare_data() train_loader.setup() in_dim = np.prod(train_loader.dataset[0][0].size()) dimentions = [in_dim, 512, 128, 64, 12, 2] autoEncoder = AutoEncoder(dimentions) print(autoEncoder) checkpoint_callback = ModelCheckpoint( save_top_k=1, verbose=True, monitor='loss', mode='min', prefix='' ) logger = TensorBoardLogger('log', name=hparams.dataset_name) trainer = Trainer( logger=logger, default_root_dir='./log', checkpoint_callback=checkpoint_callback, row_log_interval=50, max_epochs=hparams.max_epochs, gpus=hparams.gpus, tpu_cores=hparams.tpu_cores ) trainer.fit(autoEncoder, train_loader) def main(): parser = ArgumentParser() parser.add_argument('--gpus', default=None) parser.add_argument('--tpu_cores', default=None) parser.add_argument('--max_epochs', default=50) parser.add_argument('--dataset_name', default="MNIST") args = parser.parse_args() train(args) if __name__ == "__main__": main()
本来、train_loader.prepare_data()
および train_loader.setup()
を我々で呼ぶ必要はないですし、むしろおそらく呼ぶべきではないのですが、ここでは入力の次元をデータセットから計算する必要があるために初めに呼んでしまっています。ここはもっと良いやり方があるのかもしれません。
ModelCheckpoint は学習しながらモデルを保存するためのコールバックを司るクラスで、インスタンスを Trainer のコンストラクタに渡して使います。loss などの指定した指標が最も少なくなった上位 k 件を保存する、などといった指定ができます。
TensorBoardLogger は、Tensorboard オブジェクトの Lightning 向けラッパーで、こちらもインスタンスを Trainer に渡して使います。
これらを Trainer のコンストラクタに渡すことでインスタンスを得ています。使用する GPU 数や TPU のコア数もここで指定します。
実験と可視化
MNISTの潜在空間
学習および結果の確認は Google Colabolatory で行いました。以下のコードブロックはそのままノートブックのセルへの記述です。
まず、リポジトリからクローンし、Lightning のインストールを行って作業ディレクトリを移動します。
!git clone https://github.com/optie-f/PL_AutoEncoder
!pip install pytorch-lightning==0.9.0rc2
%cd PL_AutoEncoder
学習しながら確認するため、あらかじめ Tensorboard を表示しておきます。
%load_ext tensorboard %tensorboard --logdir log/
その後、訓練を始めます。前述のコードにある通り、--dataset_name
のデフォルト値は MNIST になっているので、MNIST に対する訓練が始まります。
!python train.py --gpus 1
訓練が始まったあと、さきほど表示した Tensorboard を更新すると、リアルタイムで損失や復元画像を確認することができます。以下図は訓練終了後のものですが、左側がデータセットの画像、右側が復元された画像です。
ややぼやけており、最上段の「8」や「5」など怪しいものも散見されますが、概ね元の数字形状は視認できるといってよいでしょう。
また、テストデータから 10000 件をエンコードし、2次元潜在空間における散布図を作成し、ラベル(実際に何の数であるか)ごとに色分けをすると、以下のようになりました。
なんとなく各数字ごとにまとまっていた位置に現れているような気はします。これだけではわかりにくいので、今度は逆に、潜在空間内で、$(0, 0)$ を中心とした $50 × 50$, 幅 1 のグリッドをとり、格子点をデコーダを通すことでどのような数字として現れるかを見ることにします。
例えば「8」と「2」が隣接した領域にあり、その間には「8」とも「2」ともつかないような数字が並んでいたりすることが観察されます。
こうした潜在空間における分布は訓練をするたびに異なった現れ方をします。別の試行での潜在空間は以下のようなものでした。
様子は異なっていますが、概ね中心から放射状に広がっているように見えます。
潜在空間を移動する点を逐次復元することで、文字が変化するアニメーションなども作れそうなイメージはつきますね。
可視化に使用したコードは以下の通りです。
from modules.data_loader import DataModule from modules.model import AutoEncoder import numpy as np import torch import matplotlib.pyplot as plt import seaborn as sns import os data_name = "MNIST" testsize = 10000 # 画像(とラベルの読み込み) dataModule = DataModule(data_name=data_name, batchsize=testsize) dataModule.setup(stage="test") test_loader = dataModule.test_dataloader() images, labels = iter(test_loader).next() # 最新バージョンの最新チェックポイントのパス versions_dir = os.path.join('./log/', data_name) versions = sorted(os.listdir(versions_dir)) ckpts_dir = os.path.join(versions_dir, versions[-1], 'checkpoints/') ckpts = sorted(os.listdir(ckpts_dir)) ckpt_path = os.path.join(ckpts_dir, ckpts[-1]) # モデルの読み込み dimentions = [np.prod(images[0].size()), 512, 128, 64, 12, 2] autoEncoder = AutoEncoder(dimentions) autoEncoder.load_state_dict(torch.load(ckpt_path)["state_dict"]) autoEncoder.freeze() # テストデータの 10000 枚の画像が潜在空間でどのように分布するかの可視化 z = autoEncoder.encoder(images.view(10000, -1)) np_z = z.to('cpu').detach().numpy().copy() np_labels = labels.to('cpu').detach().numpy().copy() sns.set() plt.figure(figsize=(12, 10)) plt.scatter(np_z[:,0], np_z[:,1], c=np_labels, s=1) plt.colorbar() plt.show() # 潜在空間での格子点を考える。その各点はどのような画像になるかの可視化 latent_grid = np.array([[(x+0.5), (y+0.5)] for x in range(-25,25) for y in range(-25,25)]) z_grid = torch.tensor(latent_grid, dtype=torch.float32) recon = autoEncoder.decoder(z_grid) np_recon = recon.view(50, 50, 28, 28).to('cpu').detach().numpy().copy() sns.set(style='dark') plt.figure(figsize=(20, 20)) plt.imshow(np_recon.swapaxes(1,2).reshape(28*50, 28*50), cmap='gray') plt.show()
FashionMNIST, KMNIST, CIFAR10 でも試す
今回はMNISTのほか、ファッション雑貨の低解像度写真?データセット FasionMNIST、崩し字データセットの KMNIST、カラー写真の CIFAR10 でも同様の実験を行っていました。以下に結果を載せます。
FasionMNIST
ファッション雑貨データセットです。「MNISTは簡単すぎてベンチマークにもならない」という課題から作られたそうです。
復元の様子。
ストライプ模様のシャツなどは模様がぼやけてしまっていますが、種別は保持されていそうです。ただ左上のように暗めの色のズボンなどは何が何だか分からなくなっています。こうした点から、MNISTの数字認識からの難易度上昇を感じさせます。
潜在空間マップ。
長袖の上着とカバンが接する領域、長袖からワンピースへの遷移などが面白いですね。
KMNIST
崩し字です。こちらは比較的最近のデータセットではなかったでしょうか。字種が多く、しかも各種類ごとに数が均等ではないそうで、認識問題としての難易度は跳ね上がりそうです。
復元の様子。
「つ」「ハ」「れ」は復元できているのはわかりますが、それらと形が近い字が「つ」「ハ」「れ」に吸収されてしまっている感じがあります。
潜在空間マップ。
左左下あたりが存在しない字になっている気がします。
CIFAR10
一般物体認識に用いられる写真のデータセットです。MNIST が $28 \times 28 \times 1 = 784$ 次元であり、MNIST変種もその程度であったのに対して、CIFAR10 は $32 \times 32 \times 3 = 3072$ 次元あります。そもそも airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck というクラスの写真分類であり、一つのクラス内でもMNISTよりずっとバリエーションが豊かです。つまり、難しいです。
復元の様子。
マッピングをしたかったので潜在空間を2次元にすることにこだわっていたのですが、このようにまったく復元できていないので、限界がありそうです。
よくわからなかった点
前項の可視化コードで、モデルを読み込む際に以下のようにしていますが、
dimentions = [np.prod(images[0].size()), 512, 128, 64, 12, 2] autoEncoder = AutoEncoder(dimentions) autoEncoder.load_state_dict(torch.load(ckpt_path)["state_dict"])
これは本来、公式ドキュメントの例を見る限り、以下のように書くだけでよいはずです。
autoEncoder = AutoEncoder.load_from_checkpoint(ckpt_path)
しかし、これを実行すると、このようなエラーに遭遇します。
46 def __init__(self, in_dimentions): 47 super(AutoEncoder, self).__init__() ---> 48 out_dimentions = list(reversed(in_dimentions)) 49 self.encoder = DENcoder(in_dimentions, nn.ReLU(True), nn.Identity()) 50 self.decoder = DENcoder(out_dimentions, nn.ReLU(True), nn.Tanh()) TypeError: 'dict' object is not reversible
これはバグということになるのでしょうか。追って Lightning のコードなどを確認できればしたいところです。
おわりに
PyTorch Lightning でオートエンコーダを実装しました。次があれば、今回のコードを拡張しながら変分オートエンコーダ(VAE)でも実装しようと思います。