あの論文を検証してみた! - シリーズ第5回 - Neural Processes 論文(実験編)

こんにちは!ブレインズコンサルティングの大下です。

今回は、「あの論文を検証してみた!」のシリーズ第5回、前回理論的(数学的・確率論的)な側面を解説したNeural Processes の論文の実験編です。 DeepMind社のTensorflow による実装を参考に、PyTorch 版を作成しました。 意外と、DeepMind社の実装のように、ある程度抽象化したりモジュール構成を整理したPyTorch版Neural Processesの実装がない印象だったので、独自作成しました。

今回の実験編では、Few-Shot Leaning をテーマに、いくつかのデータセット(※)で検証していきます。

(※)使ったデータセットは、Toy Dataset、MNIST、Fashion-MNIST、Kuzushiji-MNIST です。

検証環境

まずは、動作確認に使った検証環境を明記しておきます。

  • Ubuntu 18.04.1 LTS (Bionic Beaver)
    • CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz / 4 cores / 8 processors
    • Memory: 48GiB
    • HDD: 約190GiB (Available なサイズ)
  • Python 3.6.5
  • PyCharm-community-2018.3.3
  • PyTorch 1.0.0
  • Colaboratory(GPU) / pytorch 1.0.1.post2

本検証におけるアウトプット

コード

画像

Toy Dataset

まず、DeepMind社がGitHub で公開しているToy Dataset をPyTorch に移植し、実験しました。 ネットワーク構成や各種ハイパーパラメータは、DeepMind社によるTensorflow実装に極力寄せています。 DeepMind社の実装例に従い、高々16個の教師データ(Context とTargetの組)のみで学習して、どれだけ予測ができるかを見ていきます。 このデータセットは、多変量正規分布から400時点(==Targetの座標点の個数)をサンプルすることで、各系列(軌道)が生成されます。 つまり、各時点での確率変数が独立でない一般的な多変量正規分布を、Neural Process (独立した正規分布/正規過程)で近似する検証です。

Trainset

まず、学習によって予測( \hat{y}_T \sigma)がどう変わるかを見てみます。

下図の各グラフについて、横軸が xを表し、縦軸が y です。 点線が、Target( x_T, y_Tの座標の軌跡・系列)、黒点がTargetからランダムサンプリングされたContext( x_C, y_Cの座標点)、青実線が予測( \hat{y}_T)、薄い青帯が予測からの標準偏差 \sigma)を表しています。

f:id:bci-oshita:20190327115923p:plain
trainset-graphs

epoch==1000 ぐらいでは、まだ学習が進んでおらず、10000ぐらいで、まぁまぁそれっぽい感じのグラフになってます。 20000までいくと、一部のContext が集中している箇所では、分散(標準偏差)が0に近くなっている(帯がほぼ線と一致している)箇所があり、ややoverfit 気味かも、という印象です。

学習状況

学習時のloss, 予測グラフ(gif アニメーション)を確認します。(このグラフは、examples スクリプトを --visdom オプションを指定すると参照できるようになります。)

f:id:bci-oshita:20190327184856g:plain
training-toy

trainset の loss は、いい感じに下降傾向です。 各予測グラフは、trainset (size==16)の中の先頭5個の系列(多変量正規分布からサンプルされた系列)を表示しています。 予測の青線( \hat{y}_T)は、Context由来の近似分布 q_Cからサンプルした10個の zから生成したもので、0.2秒おきに異なる zのサンプルによる予測を表示しています。 testset の loss は、上昇していってますね。。(^-^; これは、trainset が高々16(==batch_size)なので、学習サンプル数(バリエーション数)が少ないことが原因として考えられます。

次が最終的な結果(epoch==20000)です。

f:id:bci-oshita:20190327185433p:plain
trained-toy

先の予測グラフと比較すると、多少、オバーフィットしている?と思えるくらいフィットしている感があります。  \hat{y}_T \pm{\sigma} の幅もかなり狭くなっています。(つまり、学習によって、 \sigmaの値が小さくなるように学習している)

この5つのサンプル系列を見る限り、trainset の系列については、 \hat{y}_T \pm{\sigma} の幅の中に概ね入っており、学習自体は十分できていると考えられます。 ただ、loss(test) (右下)のグラフは、収束する気配は見られず、大きくばらつく状況です。 やはり、学習データ(trainset)のバリエーションが、高々16系列で、推定させるのは無理があるという印象です。

実際、batch_size==100 に変更して実行すると、以下のようにloss(test) の分散が縮まります。( y軸の目盛に注目します。)

以下の図が、batch_size==100 の学習後の結果です。(epoch==20000直後のグラフ)

f:id:bci-oshita:20190328144857p:plain
trained-toy batch_size=100

batch_size==1000 辺りで、実行しようとしましたが、時間が結構かかりそうだったので、未実施です。ご興味ある方は、実験してみてください。 もしかしたら、ネットワークの複雑度を上げる必要が出てくるかもしれません。(batch_size によらず、epoch=10000ぐらいは必要そうな印象です。)

Testset

次は、testset に対するグラフを見ていきます。

epoch==1000 直後のtestset の推定

以下のグラフは、epoch==1000 直後のContextを入力としたTargetの予測グラフです。

f:id:bci-oshita:20190326122908p:plain
epoch==1000

Target は、一律400の座標で構成し、Context は、Targetの個数を超えないように、ランダムで個数が選定されます。 このepoch==1000 の例では、Context の5点から、Targetの400点(軌跡)を推定した結果です。 学習があまり進んでいないので、予測が概ね \hat{y}_T = 0 になっています。

epoch==10000 直後のtestset の推定

f:id:bci-oshita:20190326124304p:plain
epoch==10000

10000 epoch あたりから、学習結果がそれっぽくなっています。 このepoch==10000 の例では、Context の13点から、推定した結果です。 予測の線が、Context(黒点のみの系列) の平均・トレンドっぽい線のようになっていますが、Context 情報がない両端は、Target(点線)に比べて大きく離れていってます。 後は、分散が小さいことも踏まえて、trainset にオーバーフィットしている感じがします。

epoch == 20000 直後のtestset の推定

f:id:bci-oshita:20190326143601p:plain
epoch==20000
このepoch==20000 の例では、Context の19点から、推定した結果です。

うーん、、。最後(右端)のContext 点以外、すべて無視した曲線を描いていらっしゃいます。(^-^; これは、trainset に、この系列(Context, Target)に近いデータがなく、学習データも高々16パターンなので、 このパターンは「知るか!」という感じの予測結果ですね。

trainset のところでも触れましたが、batch_size を増やすと幾分、まともな感じになります。 ただ、それでもバリエーションとして網羅はできていないようで、外すときは大きく外す予測(Target)を生成する印象でした。

単純に考えて、400変数を持つ、多変量正規分布の近似を考えるので、相当たくさんの系列を学習させる必要があるのが目に見えています。 相当数の系列の学習が必要な割には、高々16の系列データで、比較的 y=0 から遠く外れていない系列については、それっぽい感じのグラフを描ける点にも驚きです。(だんだん、 y=0から離れていくような、振幅が激しい波を持つようなグラフのフィットはうまくいかない様子)

この辺が、独立した正規分布列での近似モデルっぽいというか、それは、限界だよね、という感じがします。

MNIST

次は、論文でも使われているMNIST で、結果を見てみます。 正答率を確認しやすいように、batch_size==100として、実験しました。 epoch 数は、実行時間と収束具合を踏まえ、epochs==300 としました。 モデルのネットワーク(MLP部分)は、Toy Dataset と同じネットワークだと、loss は下がるものの、 trainsetですら全然数字を再現できなかったので、入力・出力側の次元数を増やすように変更しました。(複雑度を上げました。) 学習率は、0.001 です。(Toy Dataset の学習率は、 10^{-4}と小さすぎる印象だったので、その10倍の 10^{-3}にしました。) Context は、Target (28x28 pixcels)から、25%~75%の間の%で、ランダムサンプルしたバラバラのpixel です。(画像に変換するときは、非サンプリングのpixcelの値をゼロパディングした画像として表示しています。)

ざっくりいうと、50%ぐらいの確率で、testset の生成に成功した印象です。 多量データが必要な深層学習という文脈で考えると、たかが100サンプルで、半分ぐらいを再現できるというのは、悪くない印象です。

Trainset

epoch==300 直後の学習データ画像の再現具合を見てみます。

うまくいった例

左から順に、Context画像、予測画像、Target画像です。

f:id:bci-oshita:20190328152649p:plain
train-example-success

これは、ほぼ再現できています。 さすが学習データということもあり、ほぼほぼTarget画像と同じ感じで再現されています。

うまくいかなかった例

f:id:bci-oshita:20190328153109p:plain
train-example-not-success

小さい画像で見ると、結構似ているんですが、拡大すると「7」に見えます。やはり、pixelの分布が似ている文字の違いは、微妙という感じです。 (正規分布なので、似ている文字との中間画像(平均ベクトル)に落ち着く)

Testset

次は、Testset を見ていきます。

うまくいった例

f:id:bci-oshita:20190328153742p:plain
test-example-success

すばらしい!ちゃんと「3」が生成されています。 ただ、trainset とは異なり、Target とは異なる「3」を生成しています。 これは、「3」の平均画像がうまく学習されており、Context画像に近い、学習した平均画像を生成したと考えられます。

うまくいかなかった例

Testset のうまくいかなかった例は、結構あります(笑) いくつか見ていきましょう。

これはありがち

f:id:bci-oshita:20190328155005p:plain
test-example-not-success_1

Context画像(左)を人間が見たら、まぁ「7」でしょうという感じですが、 確かに、どれだけ欠けているかわからなかったら、「9」という可能性も・・・、 これは、「7」と「9」の分布が近いところにあることがわかります。 実際、真ん中の予測画像を見ると、「9」のまるいところ・部分の左下が薄くなっています。 この薄いところがなかったとしたら、確かに「7」に見える気もします。 迷いがあることが、この中間的な画像として表現されている感じがしますね。

確かに、その可能性も・・・

f:id:bci-oshita:20190328154354p:plain
test-example-not-success_2

おっと!真ん中をつないじゃいましたか! 確かに、Contextだけ見たら、その可能性もありますね。(どれぐらい、pixcel が欠けているかの情報、当然、Contextとしては渡していないですし。) それにしても、よく「0」に近い、雪だるまみたいな「8」をうまいこと生成できましたね。ちょっと感心しました。

むむむ・・・

f:id:bci-oshita:20190328155614p:plain
test-example-not-success_3

見た感じ、「3」と「8」を混ぜたような画像を生成しています。 そして、よく見ると、真ん中あたりにContext 画像の名残も若干見て取れます。 これは、学習データに近い数字画像がなく、混ざった感じがします。 最初は気づかなかったですが、Context画像について、左上から右下に線「白いpixel」を付与したら、確かに、「8」になりそうですね。

その自信はどこから?!

f:id:bci-oshita:20190328160036p:plain
test-example-not-success_4

なんとなく、迷わず「3」を生成している印象を受けました。 これは、Context の白いpixcel のマッチ数が多くなるのが、生成された この「3」なのかな、という印象です。 この辺りの間違いは、人間では、ほぼあり得ないことで、ブラックボックスが怖いとか、信用できないとか言われる所以なのかな、という印象です。 この現象の対策案としては、普通のMNISTのタスク(ラベル「0」から「9」の判定)と組み合わせることで、多少解決しそうな予感がします。 ただ、今回は、純粋なNeural Processes モデルの検証なので、この実験は、スコープ外ですので、悪しからず。

Fashion-MNIST

論文では、CelebA で検証されていましたが、この記事では、より検証が簡単にできるFashion-MNISTを使ってみます。 MNISTで使ったネットワークやハイパーパラメータをそのまま使用し、データセットだけ変更したモデル、条件で実験しました。(examples/train_mnists.py の --datasetオプションを、"fashion" に変更しただけです。)

Fashion-MNIST のラベルは、以下の通り。(引用:zalandoresearch

ラベル 記述
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot

Trainset

うまくいった例

Trainset は、だいたいがうまく再現されています。

f:id:bci-oshita:20190329115840p:plain
fashion-train-example-success

正直、Context 画像を見ただけでは、何のアイテムか想像できなかったです。 なにより、自分の絵描き能力では、ここまで、それっぽく再現できないです(笑)

うまくいかなかった例

Trainset では、ラベルが異なるレベルの間違った画像生成はありませんでした。(可視化に使った1ミニバッチ・100サンプルの中では。)

しいて上げるとすれば、以下の画像です。

f:id:bci-oshita:20190329120509p:plain
fashion-train-example-not-success

サンダルの外形は、うまく再現されているものの、内部的なデザインまでは再現できなかったようです。多少Ankle bootっぽい感じもします。(これは、正規分布ベースなので、仕方ないかな、という印象です。)

学習としては、特に問題なくできている印象を受けました。

Testset

では、testset を見ていきます。

うまくいった例

ほぼ一致

f:id:bci-oshita:20190329121005p:plain
fashion-test-example-success

この画像は、うまくいった画像の中でも比較的、良く再現できている印象の画像です。 この画像は、Sneaker の代表的な画像に近い感じがするので、Targetにも近い画像がうまく生成されたのだと思います。

ラベルは一致

f:id:bci-oshita:20190329121440p:plain
fashion-test-example-success_2

Targetとは、若干異なるが同じ種類のラベル(Pullover)っぽい画像を生成できています。 さすがに、模様は無視されて生成されますが。。(模様は、trainset でもうまく再現できていませんでした。)

うまくいかなかった例

着るとはくの違い

f:id:bci-oshita:20190329122133p:plain
fashion-test-example-not-success

これは、比較的大きく違った画像を生成した例です。 Trouserと、Dress、Coat は、比較的縦長に、白いpixcel が分布するので、誤認しやすいようです。(つまり、白いpixcel のマッチ度は、高い)

大枠はあってる

次の例は、同じ「靴」のカテゴリだけど、という例です。

f:id:bci-oshita:20190329122617p:plain
fashion-test-example-not-success_2

Context を見ると、足の甲のところをすごく重要視して、再現した結果Sandal っぽい感じの画像を生成したように見えます。

再現失敗!

f:id:bci-oshita:20190329123026p:plain
fashion-test-example-not-success_3

この例は、どのラベルに所属するのかが全くわからない画像を生成した例です。 Dress と Bag の中間概念の画像という感じがします。 Context画像だけみたら、確かにDress の可能性も否定できないですね。

なぜだかわからない

f:id:bci-oshita:20190329123127p:plain
fashion-test-example-not-success_4
最後に、なぜかはわかりませんが、Bagっぽい画像を生成した例です。 しいて言うなら、Context とのマッチ度は高い画像という印象がします。

やはりサンプル数(というかバリエーション数)が少なかった(100サンプル)ことが起因している、という印象です。

うまくいった例、うまくいかなかった例を見る限り、少ないサンプル数でも、大枠のカテゴリは一致する可能性は高い印象を受けました。 この印象は、普通のMNISTよりもFashion-MNIST の方が顕著に思います。 これは、自分が区別できないことが印象に影響があるかもしれませんが(笑)

Kuzushiji-MNIST(おまけ)

簡単に実験できるので、Kuzushiji についても、見てみます。 所感としては、正しいかどうか、判断ができない!という感じでした。orz

ラベル

Kuzushiji MNIST の10文字は、以下の通りです。(参考:rois-codh/kmnist

ラベル 文字
0
1
2
3
4
5
6
7
8
9

各文字の画像サンプルは以下の通り。(引用:rois-codh/kmnist

f:id:bci-oshita:20190329151249p:plain
kuzushiji-MNIST-10-letters

・・・、読めません!!(笑)

Trainset

うまくいった例

Trainset は、だいたいがうまく再現できている感じでした。当然かもしれませんが。

f:id:bci-oshita:20190329151703p:plain
kuzushiji-train-example-success

これは、比較的きれいに再現できている例です。

うまくいかなかった例

f:id:bci-oshita:20190329151808p:plain
kuzushiji-train-example-not-success

正直、Kuzushiji として、実は読めるのかどうか判断できていないですが、 Target画像の文字としては認識できないかな、ということで選びました。 ただ、Context画像、予測画像、Target画像を比べると、まぁまぁ近い画像かな、という感じです。 生成された画像(真ん中)だけをみたら、「を」とはわからないです。 ただ、「おきすつなはまやれを」のどれに似てるか? と問われれば、もしかしたら、「を」の可能性に気づけるかもしれませんね。。(^-^;

総じて、学習はできている印象を受けました。

Testset

では、testset についても見ていきます。

うまくいった例

f:id:bci-oshita:20190329152731p:plain
kuzushiji-test-example-success

これは、よく再現できました!という例です。 正直、Context画像(左)からは、「お」になるとは、わかりませんでした。 testset では、この例のように多少複雑な文字の生成で、成功している例は少なく、 うまくいっているのは、「つ」と「は」(というか「ハ」っぽい字)のすごくシンプルで、他の文字とは明確に違う文字の再現には成功しているようでした。これも、正規分布ベースであると考えると、そうだろうな、という印象です。

うまくいかなかった例

「つ」 v.s. 「す」

f:id:bci-oshita:20190329153103p:plain
kuzushiji-test-example-not-success

なるほど、確かに、Context 画像を見ると、「つ」の軌跡に沿った「す」になっています。

「な」 v.s. 「き」

f:id:bci-oshita:20190329154228p:plain
kuzushiji-test-example-not-success_2

これも、Context画像(左)を基準に考えると、似ているといえば似ている・・・かも、という感じです。

「参」?!

f:id:bci-oshita:20190329153302p:plain
kuzushiji-test-example-not-success_3

正直、Target画像(右)が何かわからなかったです。。 Kuzushiji-MNISTのサンプル画像で似た文字を探したら、なんと「つ」!! うそでしょ(笑) そして、生成された画像は、遠目から見ると「参」?!に見えて仕方がないです(笑) これは、多くの現代日本人でもわからないと思いますので、許してもらえる誤り、という気がしています(笑)

読めません!!

f:id:bci-oshita:20190329154426p:plain
kuzushiji-test-example-not-success_4

もう、どれも読めません!!(笑)

Kuzushiji のバリエーションがすごすぎて、漢字を含めて学習しているのか?!と思うぐらい、漢字っぽい画像を生成しています。ひらがなの大元が漢字ということを想起させる実験でした。

感想

モデル実装のときに、まったく学習が進まず、Toy Dataset に対する予測がtrainset ですら、平均0を描いていたりしました。 これは、モデルの学習Parameters(重み)がすべて取得できていなかったことが原因でした。 いつまでたっても、各層のWeight が更新されていないので、最初は勾配消失?でも、まったく変わらないのは、おかしいと思っていたら、そういうことでした。

そして、なにより、Kuzushiji、すごすぎ!(笑)

まとめ

  • Neural Processes は、Few-shot Learning ができる枠組みを持つ
  • ただし、あくまで機械学習・深層学習でできる範囲に留まる
    • 特殊な構造をとらえて、推定するわけではない
    • 各変数が独立した正規分布の枠組みで推定できる範囲に留まる
    • つまり、平均(代表的な像)を推定させるには使えそう
    • 実務的には、データ補完、疑似データ生成等に使えそう
      • これは、z をランダムサンプルする代わりに、[-1, 1] 等の区間を動かして生成する
    • Few-shot と呼べるほど、少量での学習は難しい様子
      • 実用性を考えると、やはりそれなりのバリエーションや、別のモデルが必要な印象
  • 実体(ラベル)は異なるが似た分布画像を生成する(誤りの)ケースに対応するために、判別タスクモデルと組み合わせたモデルを使うことで、より正確性を担保できそう
    • 実用性や解釈性(説明性)を考えると、このような複合モデルは有用そう
  • モデルのMLP部分を、CNN等別のサブモデルに変えてもよい
  • 個人的には、理論面でも実装面でも、勉強になるモデルであった
  • 自戒も込めて、ちゃんと、誤差逆伝播ができるようにモデルを実装することが大切です(笑)

参考リンク