こんにちは!ブレインズコンサルティングの大下です。
今回は、「あの論文を検証してみた!」のシリーズ第2回、BERTの可視化実験を紹介します。 BERTの枠組みで学習したTransformer (Self-Attention) が、入力系列のどこを注目しているのか、を可視化し、解釈を試みます。
先月ぐらいに、政府が考えるAIの7原則が記事になっていました。 その中に、「企業に決定過程の説明責任」というものがあり、一部で話題になっていたと記憶しています(批判が多かった印象)。 日本の戦略を考えると、量で質をカバーする方法では、もはや米国、中国には叶わないということもありそうなので、 仮に少量でも、日本らしい?質(==説明責任による安心・安全)を担保して差異化を図りたい、という流れになるのかもしれません。
ということで、説明責任に繋がるといいなぁという願いを込めて、BERTのアテンションの可視化に挑戦します!
検証環境
まずは、動作確認に使った検証環境を明記しておきます。(前回と同じ環境です)
- Ubuntu 16.04 VM / Win10
- CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz / 4 cores
- Memory: 30GiB
- HDD: 約300GiB (Available なサイズ)
- Python 3.6.5
- pycharm-community-2018.2.1
- Tensorflow 1.12.0
環境構築の補足
attention を可視化するために、seaborn を使っています。 seaborn を入れてないpyenv環境は、再構築する必要があります。
単純に、seaborn をインストールした後、tk ライブラリ関係でimport エラーになることがあります(私は、なりました T-T)。 そこで、以下のような手順で、Python 環境を準備します。
1. tk-dev をインストールします。
sudo apt-get install -y tk-dev
2.pyenv 環境を構築し直します。(例:pyenv の環境名:「lab」) ※ 環境を再構築しないと、tk-dev をインストールした後も、import エラーが継続します。
pyenv uninstall lab pyenv uninstall 3.6.5 pyenv install -s 3.6.5 pyenv virtualenv 3.6.5 lab
今回の実験のアウトプット
実験用コード:bert_visualizer.zip (約21KB)
- 詳しくは、bert_visualizer.zip/README.md を参照ください。
サンプル出力結果:attention.zip (約23MB)
- 今回の実験用コードの出力結果です。ダウンロードに多少時間がかかります、ご注意ください。
- 分析用Excelシートのサンプル:trainset-00000@attention_head-08.xlsx (約48KB)
BERTのアテンションの可視化
BERTの可視化は、世の中でも結構実施されているらしく、Twitter や こちらの記事などでも紹介されています。 これらの記事では、わかりやすいように、文章のどの単語(token)に注目しているか、End-End を意識して解説されています。 当記事では、まず純粋にアテンション自体の可視化を行い、BERTで使用されているTransformer(Self-Attention)のアテンションが何者か、の解釈から進めていきます。
当記事のBERTの可視化では、MRPCデータセットを使って行います。 MRPC は、2組の文が言い換え(同義文)か否かを判定するためのデータセットです。
アテンションの可視化概要
まずは、1つアテンションを可視化した結果を見てみます。
trainset の1番目のデータ(token への分割後の系列)を見てみます。
[CLS] am ##ro ##zi accused his brother , whom he called " the witness " , of deliberately di ##stor ##ting his evidence . [SEP] referring to him as only " the witness " , am ##ro ##zi accused his brother of deliberately di ##stor ##ting his evidence . [SEP]
[CLS] が先頭を表す記号で、[SEP] が1つの文の終わりを意味しています。 MRPCデータは、1データ(系列)が、2つの文で構成されています。 つまり、この2つの文が言い換え(同義文)か否かを推定するためのデータです。
これ↑は、12個のアテンションヘッドを可視化した図になります。 (アテンションヘッドの数は、設定ファイル「bert_config.json」の「num_hidden_layers」の値に対応します。)
アテンションヘッド可視化画像の見方
1つサンプル画像を例に、見方と解釈について概説します。
縦軸がSelf-Attention の入力系列で、横軸がSelf-Attention の出力系列(==入力系列)です。 一般的に、Transformer のアテンションヘッドは、入力系列の各token に対して、出力系列のどのtoken に注目すべきかをパターン化するデータ(実数値行列)であると解釈できます。 Self-Attention では、入力系列==出力系列なので、係り受けのような関係を学習・表現できるのではないかと理解しています。
文単位に注意するアテンションヘッド
では、1番最初(一番左上)のアテンションヘッドの可視化画像を見てみます。
真ん中あたり(実はよく見ると、[SEP])を境界に色が変わっています。 これは、MRPCの1データ(入力系列==出力系列)が2つの文を含んでいますが、同じ文のtokenに注意し、違う文のtoken には注意しない注意方法とわかります。
周辺単語に注意するアテンションヘッド
次に3番目(一番上の行の左から3番目)のアテンションヘッドの可視化画像を見てみます。
注意(濃い赤色)が、対角線上に並んでいます。 これは、入力系列の周辺単語(前後数単語・数token)に注意する注意方法とわかります。
文単位で、周辺単語に注意するアテンションヘッド
次は、一番下の行の一番左のアテンションヘッドの可視化画像を見てみます。 真ん中あたりでうっすら境界があり、比較的対角線を中心に色が濃くなっています。 これは、同じ文の中もまんべんなく弱く注意しつつ、周辺単語を強めに注意する注意方法とわかります。 先の2つのアテンションヘッドの注意方法を合わせたような注意方法とわかります。
このように、アテンションヘッドは、入力系列のどのtoken が出力系列のどの単語に影響を与えるかのパターンを学習しているとわかります。 単純に考えて、アテンションヘッドを増やすことは、注意するパターンが増えるので、精度に貢献することが容易に想像がつきます。
pre-trained と fine-tuned とアテンションとの関係
pre-train したモデルと、pre-train したモデルをfine-tune したモデルで、アテンションに違いがあるか見てみます。
pre-trained と fine-tuned モデルの精度
アテンションとの関係を見る前に、fine-tune によって、モデルの精度が上がっているかを確認しておきます。
まず、pre-trained モデルの混同行列です。(対角線の合計が正解数です。) すべて、0 (言い換えでない)と判定しています。このタスクを一切学習していないので、妥当といえば妥当です。
次に、fine-tuned モデルの混同行列を見てみます。 すばらしい!!1つを除いて、すべて正解です。
ということで、fine-tune により、モデル精度が上がっていることが確認できました。 実験で使った20サンプルに対する正解率は、19 / 20 = 0.95 = 95% とかなりの精度です。
trainset 3番目のデータ
無事、fine-tune によって、精度が上がったことがわかったので、アテンションとの関係を探ります。
まずは、MRPC の trainset の3番目のデータ
[CLS] they had published an advertisement on the internet on june 10 , offering the cargo for sale , he added . [SEP] on june 10 , the ship ' s owners had published an advertisement on the internet , offering the explosives for sale . [SEP]
で確認します。
この3番目のデータ(attention.zip の 「00002」ディレクトリ)に対しては、pre-trained モデルでは、「言い換えでない」(0)と判定していますが、 fine-tuned モデルでは、「言い換えである」(1)と判定しています。正解ラベルは、fine-tuned モデルの通り「言い換えである」(1)です。
↓まずは、pre-trained モデル(プレトレーニング後)のアテンションの可視化です。
↓次に、fine-tuned モデル(ファインチューニング後)のアテンションの可視化です。
この2つのアテンションの可視化を比較すると、ほとんど同じであることがわかります。 細かいことを言うと、入力系列の各tokenの注意確率を見ると、微妙に異なっています。
trainset 11番目のデータ
同じように、11番目のデータ(attention.zip の 「00010」ディレクトリ)に対しても見てみます。
[CLS] legislation making it harder for consumers to erase their debts in bankruptcy court won overwhelming house approval in march . [SEP] legislation making it harder for consumers to erase their debts in bankruptcy court won speedy , house approval in march and was endorsed by the white house . [SEP]
このデータは、pre-trained モデル、fine-tuned モデルの両方とも、「言い換えでない」(0)と判定しているデータです。 正解ラベルは、「言い換えでない」(0)です。
※ 先ほどの混同行列を思い出すと、pre-trained モデルでは、サンプリングした最初の20データすべてを固定的に「言い換え出ない」(0)と判定していました。
↓pre-trained モデルで学習したアテンションの可視化です。
↓そして、fine-tuned モデルで学習したアテンションの可視化です。
このデータに対しても、pre-trained モデルとfine-tuned モデルに大きな差が見られない結果になりました。 これは、fine-tune による学習効果がない、というよりは、pre-trained モデルですでに、アテンションのパターンを十分学習できていると考えられます。 (実際、fine-tuned モデルでは、言い換えか否かの判定の精度が上がっていたため、fine-tune による学習効果があることがわかります。) つまり、このことが、BERTが複数のモデルでSOTAをたたき出せている理由の一つになると考えられます。
また、trainset の 3番目のデータと11番目のデータの可視化を比べてみると、微妙に異なる点はありますが、 文単位の注意、周辺単語の注意、というような大枠のパターンは類似しているように見えます。 このことからも、アテンションヘッドがある程度の共通的な注意パターンを学習し、データに応じて多少のバリエーションを表現する能力を持っていると考えられます。
別の視点で見ると、判定の正誤の理由を探るためにデータごとの注意パターンを見るためには、 全データを跨いだ同一ヘッドに対する平均値(平均パターンをヘッドの代表パターンとみなす)との差分を見たり、 微分を使った可視化(Grad-CAMのような可視化)手法と組み合わせる必要がありそうです。
依存する単語の可視化
文単位の注意、周辺単語の注意は、直観的にもわかりやすいので、これ以上触れず、 若干、離れた単語を注意しているアテンションヘッドを探ってみます。 上記から、pre-trained モデルとfine-tuned モデルに大きなパターンの差がないことがわかったので、ここではpre-trained モデルのみに絞ってみていきます。
trainset の1番目のデータ(token への分割後の系列)を再掲します。
[CLS] am ##ro ##zi accused his brother , whom he called " the witness " , of deliberately di ##stor ##ting his evidence . [SEP] referring to him as only " the witness " , am ##ro ##zi accused his brother of deliberately di ##stor ##ting his evidence . [SEP]
このデータの文単位で、周辺単語に注意するアテンションヘッド(一番下の行の一番左のアテンションヘッド)を対象にします。
以下の表は、いくつかの単語(input_token)に絞って、注意確率が高い順に5つのtokenを選定した表です。
各セルの文字列のフォーマットは、「token{系列上のインデックス}(注意確率)」です。 「am」、「##ro」、「##zi」は、「Amrozi」という人名っぽい1単語のピースです。 これらを見てみると、「accused」、「his」、「brother」、「witness」とほぼ同じ単語群に注意していることがわかります。 「witness」以外は、動詞「accused」とその目的語「his brother」に、注意しているので、なんとなく納得がいきます。
input_token が「his」、「brother」については、「whom」を最も高い注意確率を持っている点もなんとなく納得がいき、 「whom」に対しても、「accused」にさかのぼって注意し、「brother」が最も高い注意確率である点も納得しやすいです。
こう見てみると、人間の理解は、文法構造による関係や単語の係り受け(統計的な関係)を注意の基準にしているのに対し、 BERT/Transformer でのアテンションでは、単語の係り受けに近い、統計的な関係を注意の基準にしているように見えます。 イメージ的には、文法を習っていない人の単語(token)の注意に似ているのでは、と想像します。
なので、文法構造を踏まえた注意を自然もしくは恣意的に学習させることができれば、 もう少し、人間にとって理解しやすい注意の可視化につながりそうです。
Tips
当記事の実験では、使いませんでしたが、BERTの可視化ツール(bertviz)があるようです(弊社メンバに教えてもらいました)。 Pytorch版のBERTを使った可視化ツールのようです。
- GitHub - jessevig/bertviz: BertViz: Visualize Attention in Transformer Models (BERT, GPT2, BART, etc.)
- https://colab.research.google.com/drive/1vlOJ1lhdujVjfH857hvYKIdKPTD9Kid8
↓Colaboratory のサンプル画像です。こんな感じで、 各レイヤ、各アテンションヘッドを動的に分析ができるので、デモや発表等で使うと盛り上がりそうな印象です。
感想
BERT のソースコード(Tensorflow)の解析に苦労しました。。 フレームワークのように抽象化されて、かつ、深いので、どこをどう直せばアテンションの確率を表示できるか、探るのに一苦労でした。 define and run なので、仕方ないとも言えますが。。
まとめ
- アテンションヘッドは、注意方法のパターンを学習する
- pre-trained モデルとfine-tuned モデルのアテンションの可視化パターンはほぼ同じ
- これは、pre-trained モデルが、すでに十分アテンションを学習できていることを意味する
- アテンションヘッド単位で大枠の注意方法(パターン)が決まるが、データによってバリエーションを表現する
- 判定の正誤の理由を説明するためには、工夫が必要
- 全アテンションの同一ヘッドに対する平均値との差分を可視化
- 微分を使った可視化(Grad-CAMのような可視化)手法と組み合わせて可視化、等
- 人が英語を理解するような注意とBERT/Transformerによるアテンションの注意とでは、少し違うが似ている点もある
- BERT/Transformerによるアテンションの注意は、統計的な単語(token)間の関係を学習しているように見える
- 解釈性を上げるには、文法構造を踏まえた注意を学習させる必要がありそう