あの論文を検証してみた! - シリーズ第3回 - Neural ODE 論文

こんにちは!ブレインズコンサルティングの大下です。

今回は、「あの論文を検証してみた!」のシリーズ第3回、Neural ODE の論文について解説、検証します。 今回の論文は、Neural Ordinary Differential Equationsで、ResNet と、オイラー法の更新則の類似性に着目し、 連続時間のモデルへ拡張した、新しい考え方、手法を提案した論文です。 個人的には、かなり内容が濃い論文の印象で、勉強になった論文です。

検証環境

まずは、動作確認に使った検証環境を明記しておきます。(前回から環境を変更しています。)

  • Ubuntu 18.04.1 LTS (Bionic Beaver)
    • CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz / 4 cores / 8 processors
    • Memory: 48GiB
    • HDD: 約22GiB (Available なサイズ)
  • Python 3.6.5
  • pycharm-community-2018.3.3
  • pytorch 1.0.0
  • Colaboratory(CPU)

本検証におけるアウトプット

Colabratory notebook

概要

今回対象とする論文 では、ResNetやRNNを離散時間ステップのモデル化として捉え、 更新式を常微分方程式(ODE)に拡張することで、連続時間ステップのモデル化を実現するという新しい手法が提案されています。

深層学習フレームワークで扱えるように、効率的な自動微分の方法(adjoint method)が提案されいる点にも注目したいところです。 モデルが連続時間ステップに拡張できることで、任意の粒度での時系列予測や、欠損値が多いデータを用いた学習にも効果が発揮されると期待できます。

NF:Normalizing Flow を連続化したモデル(CNF: Continuous Normalizing Flow)についても触れられていますが、今回の記事では対象外といたします。

ResNet と ODENet

この論文のアイデアは、ResNet の更新式を拡張し、ODE(常微分方程式)を用いたネットワークへと拡張することです。 順を追ってみていきます。

ResNet の更新式は、隠れ層ベクトル h_tを用いると

\begin{align} h_{t+1} = h_t + f(h_t, \theta_t) \end{align}

と書けます。ただし、 t \in \{0, 1, ..., T\}です。

この式を、微分小に拡張すると、

\begin{align} \frac{dh(t)}{dt} = f(h(t), t, \theta) \end{align}

と表現できます。

この更新式に従うモデルを、ODENet と呼ぶことにします。 特に、 f は、ブラックボックスでよく、ニューラルネットワークで実現することを想定します。

ここで、ODENet のパラメータ \thetaは、時刻 tに依存しないパラメータであることに注意します。 つまり、ResNet では、時刻 t毎(Depth/層毎)にパラメータを保持し、更新する必要がありましたが、 このモデルでは、時刻 tを跨いでパラメータを共有することに注意します。

以下は、ResNet と ODENet のネットワーク構造の比較イメージです。(論文より抜粋)

f:id:bci-oshita:20190206132833p:plain
resnet v.s. odenet

ResNet は、単位時刻ステップと各隠れ層(Depth)が対応しています(離散時点)が、 ODENet では、層という概念がなくなり、任意の時点にまばらにユニットが配置されるような構造(連続時点)になっています。 (実際、githubのexamples/ode_demo.py は、vector field の可視化サンプルになっています。)

ネットワーク構造を比較してみます。

f:id:bci-oshita:20190207143334p:plain
resblock v.s. odeblock

ODENet(ODEBlock)では、時刻系列の定義によって、任意の時間間隔をユニットに持つネットワークを動的に定義できることに注意します。

自動微分への応用

論文における「2 Reverse-mode automatic differentiation of ODE solutions」で、 自動微分への応用方法について記述されています。 提案手法は、ODE ソルバーをブラックボックスとして扱うため、任意のODE ソルバーに適用可能です。 実際、github の examples では、複数のODE ソルバーを選択できるようになっています。 更新に必要な勾配の算出には、「adjoint sensitivity method」という効率的な手法を用います。

ニューラルネットワークに、ODENet を含めたときの学習(パラメータ更新)では、ODENet 内のパラメータ \thetaを更新する必要があります。 具体的には、以下のようなLoss関数を最適化することを考えます。

\begin{align} L(z(t_1)) = L(z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)dt) = L(ODESolve(z(t_0), f, t_0, t_1, \theta)) \end{align}

ODENet のパラメータ \thetaを更新するために、 \frac{dL}{d\theta} を計算できるようにする必要があります。

まず、計算を効率的に行うための工夫として、adjoint と呼ばれるベクトル  a(t)を以下のように定義します。 \begin{align} a(t) := \frac{dL}{dz(t)} \end{align}

(注意) いくつかの解説サイトでは、 a(t) := - \frac{dL}{dz(t)}となぜかマイナスをつけて定義していたりしますが、正しくは論文の通りで、マイナスを付けない定義が正しいです。(論文のAppendix を読めば判断できます。)

このadjoint ベクトル a(t)を用いると、 \frac{dL}{d\theta} は、以下のように記述できます。 (式の導出は、同論文のAppendix を参照ください。ただし、テイラー展開は、 z(t)の周りと記載されていますが、正確には tの周りで展開することに注意します。)

\begin{align} \frac{dL}{d\theta} = - \int_{t_1}^{t_0} a(t)^{\mathsf{T}} \frac{\partial f(z(t), t, \theta)}{\partial \theta} dt \end{align}

つまり、 a(t) z(t) \frac{\partial f}{\partial \theta}を各時点 tに対して、算出できれば、 \frac{dL}{d\theta}を計算できます(積分≒和で近似すればよい)。

 a(t)は、以下の微分方程式(導出方法は、同論文のAppendix 参照)

\begin{align} \frac{d a(t)}{dt} = -a(t)^{\mathsf{T}} \frac{\partial f(z(t), t, \theta)}{\partial z} \end{align}

と、初期値  a(t_1)=\frac{\partial L}{\partial z(t_1)}からODE ソルバーを用いて時系列の逆方向に計算していきます。 (初期値 a(t_1)は、プログラム上では、adjoint.py の OdeintAdjointMethod クラスのbackward メソッドの引数grad_output に対応します。)

同じように、 z(t)は、先のLoss関数の引数内の式( ODESolve(z(t_0), f, t_0, t_1, \theta))のようにODEソルバーを用いて、 z(t) t_1から遡って算出します。

adjoint 法の入出力関係を、以下の図に記載しておきます。

f:id:bci-oshita:20190206212250p:plain
full adjoint sensitivities algorithm input & output

 a(t) z(t)が算出できれば、 -a(t)\cdot \frac{\partial f}{\partial \theta} が算出でき、 時刻 tに対して積分(和)をとることによって、 \frac{\partial L}{\partial \theta}を算出することができ、 無事、パラメータ \theta偏微分係数(つまり、更新式)を得ることができます。

参考として、論文におけるアルゴリズムを載せておきます。

f:id:bci-oshita:20190206193809p:plain
full adjoint sensitivities algorithm

論文、コード対応

githubのコード(特に、adjoint.py)を読むときに、 論文で使われている記号と異なる変数で実装されていたため、非常に解読しにくかったです。 そこで、参考までに、コードの変数が、論文での表記ではどうなるかの対応を表にしておきます。

code paper
adj_y  a(T)
adj_params  a_{\theta}(T) := \mathbf{0}
adj_times  a_t(T) := 0
y, ans_i  z(t)
grad_output_i  a(t_i)
ODEfunc, func  f
func_i  f(z(t_i), t_i)
dLd_cur_t  \frac{dL}{dt}(t_i)
vjp_y  \frac{da}{dt}(t_{i-1})
vjp_t  \frac{da_t}{dt}(〃)
vjp_params  \frac{d a_{\theta}}{dt}(〃)

vjp は、Vector Jacobian Product の省略単語です。

ただし、

\begin{align} a(t) := \frac{dL}{dz}(t) \\ a_t(t) := \frac{dL}{dt}(t) \\ a_{\theta}(t) := \frac{dL}{d\theta}(t) \\ \end{align}

とする。

各adjoint  a_*は、以下の微分方程式を満たす。(導出は、同論文のAppendix 参照)

\begin{align} \frac{da}{dt} = -a\frac{\partial f}{\partial z} \\ \frac{da_t}{dt} = -a\frac{\partial f}{\partial t} \\ \frac{da_\theta}{dt} = -a\frac{\partial f}{\partial \theta} \\ \end{align}

損失関数L の時刻 t_iに対する微分係数は、以下の式から算出する。(導出は、同論文のAppendix 参照) \begin{align} \frac{dL}{dt}(t_i) = -a(t_i) \cdot f(y(t_i), t_i, \theta)\\ \end{align}

例: MNIST

この自動微分を使った例(コード)が、githubexamples/odenet_mnist.py です。

実行例は、以下の通り。

python odenet_mnist.py --network="odenet" --nepochs=5

↓実行結果例(2エポック目で、正答率が98%)

:
Number of parameters: 208266
Epoch 0000 | Time 3.119 (3.119) | NFE-F 32.0 | NFE-B 0.0 | Train Acc 0.0982 | Test Acc 0.0949
Epoch 0001 | Time 2.060 (1.819) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9809 | Test Acc 0.9820
Epoch 0002 | Time 2.116 (1.877) | NFE-F 21.0 | NFE-B 0.0 | Train Acc 0.9859 | Test Acc 0.9855
Epoch 0003 | Time 2.631 (2.091) | NFE-F 23.8 | NFE-B 0.0 | Train Acc 0.9900 | Test Acc 0.9902
Epoch 0004 | Time 2.630 (2.093) | NFE-F 23.9 | NFE-B 0.0 | Train Acc 0.9851 | Test Acc 0.9840

CPUだと結構時間がかかります。(当検証環境において、約2時間) 個人的な感想ですが、MNISTだとResNet 相当のモデルは不要なので、 MNISTレベルの簡単なタスクは、このモデルを使う必要がないのでは、という印象でした。

実験

当記事では、ODERNNモデル(examples参照)で、時系列データの予測を試みました。 論文では、Toyデータセットを使っているので、一般的なデータでどれぐらい予測がうまくいくか検証してみます。 使用したデータセットは、Daily Demand Forecasting Orders Data Set です。

今回、検証で使ったモデルの概観は、以下の通りです。

f:id:bci-oshita:20190207150004p:plain
odernn

論文では、以下の図が対応します。

f:id:bci-oshita:20190208121404p:plain
latent ode model

実際、実験した結果は、以下の通りです。

f:id:bci-oshita:20190208162008p:plain
odernn prediction

actual (observed) が、train set に相当し、actual (unobserved) が、validation set に相当します。 このモデルは、正規分布のVAEベースで潜在変数を推定していますので、波形の平均線をなぞるような軌跡を描くように学習します。 actual (unobserved) (青線)部分の予測(extrapolation (future) の最初の方)についても、概ね、平均をなぞっているように見え、違和感がない結果になっています。

このモデルが他のモデルと違う点は、観測したデータ(train set)の過去に遡って、推定(逆方向に外挿)することができる点です。 (グラフでは、extrapolation (past) に相当します。)

ただ、外挿部分(extrapolation)は、どちらの方向も、0 につぶれていっており、人が考える自然な予測軌道とは異なる結果になっています。 軌道の調整については、軌道数、軌道のサンプル数を調整(グリッドサーチ等)することで、ある程度最適化できると想定しています。

この実験において、Colaboratory のバックエンドの状況によるかもしれないですが、GPU で学習するより、CPUで学習した方が速い印象でした。 また、潜在変数の次元(latent_dim)を大きくすると、学習に時間がかかる点にも注意します。 ODEの右辺関数 f (ニューラルネットワーク) の複雑度(表現力)、学習状況にもよるかもしれませんが、比較的、発散しやすいと感じられました。

実用としては、データ数が2年間よりも少なく(季節性を判断できなく)ても、軌道(微分導関数)を学習するため、直近未来の予測には使えそうな印象です。 ただ、学習毎の予測軌道を安定させるための工夫(予測軌道の平均をとる等)が必要と感じました。

感想

実際に、コードをカスタマイズして検証することで、理解していると思っていた内容を誤解していたことに気づけました。 手を動かすのは、本当に重要ですね。

まとめ

  • ODENet は、ResNet を参考に、離散時間から連続時間へ拡張したモデル
  • ODENet の学習(自動微分)を効率化するための、adjoint 法が提案されている
    • adjoint 法を使った学習の例は、MNISTのサンプルコードを参照
    • MNISTレベルのタスクでは、オーバースペックの印象
  • ODENet をうまく使えば、データ補完、データ拡張に使えそう
  • 実験により、ODERNN は、季節性を計算できないデータ、少量データに対しても、直近未来の予測には使えそう
    • ただし、実用上は、安定した予測軌道を得るための工夫が必要
    • 過去方向への外挿ができる点は、他の時系列モデルや手法にはなかった印象

参考リンク