あの論文を検証してみた! - シリーズ第1回 - BERT 論文(解説編)

はじめまして、ブレインズコンサルティングの大下です。

ブレインズコンサルティングでは、過去Blogger で、技術的な情報を公開していましたが、長らく更新が途絶えていたこともあり、 そちらを廃止し、こちらで、新たなテックブログとして開始することになりました。

記念すべき初回記事は、「あの論文を検証してみた!」のシリーズ第1回、今(2018年11月)、話題沸騰中(?)の 論文 [1810.04805] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding の解説です!

なにやら、複数の自然言語処理タスクでSOTAをたたき出して、すごいらしいということは、各種記事により、すぐわかったのですが、具体的にどういう仕組みですごいことができているのか、よくわからなかったので、「論文とGitHub のコードから探ってみよう!」というのが本記事執筆のモチベーションになっています。

そこで、本記事では、論文を読んで、実際に GitHub のコード (GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT) を確認した結果を共有します。

検証環境

まずは、動作確認に使った検証環境を明記しておきます。

  • Ubuntu 16.04 VM / Win10
    • CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz / 4 cores
    • Memory: 30GiB
    • HDD: 約300GiB (Available なサイズ)
  • Python 3.6.5
  • pycharm-community-2018.2.1
  • Tensorflow 1.12.0

環境構築手順

基本的には、GitHub (GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT) の手順に従って実行します。 以下、参考までに環境構築手順のサマリを提示します。ただし、pyenv 等でpython 環境を構築していることを前提とします。

1.まずは、Google Research の GitHub をClone します。

git clone https://github.com/google-research/bert.git
cd bert

2.GLUE データセットをダウンロードします。

url="https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/becd574dd938f045ea5bd3cb77d1d506541b5345/download_glue_data.py"
wget $url
python download_glue_data.py --data_dir glue_data --tasks all   # around 2 minuites

3.pre-trained モデルをダウンロードします。

mkdir -p pre-trained
cd pre-trained

url_list="
https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip
https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip
"
for url in $url_list
do
    wget $url
    z=$(basename $url)
    unzip $z
done

以上で、最低限の環境構築が完了しました。

実行時の引数

基本的には、GitHub (GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT) の引数をそのまま使っています。(i.e. MRPCデータセットを対象にしています。) 参考までに、pycharm のデバッグモードで、動作確認するときの引数例を以下に列挙します。

  • create_pretraining_data.py
--input_file=./sample_text.txt
--output_file=./result/tf_examples.tfrecord
--vocab_file=pre-trained/uncased_L-12_H-768_A-12/vocab.txt
--do_lower_case=True
--max_seq_length=128
--max_predictions_per_seq=20
--masked_lm_prob=0.15
--random_seed=12345
--dupe_factor=5
  • run_pretraining.py
--input_file=result/tf_examples.tfrecord
--output_dir=pretraining_output
--do_train=True
--do_eval=True
--bert_config_file=pre-trained/uncased_L-12_H-768_A-12/bert_config.json
--init_checkpoint=pre-trained/uncased_L-12_H-768_A-12/bert_model.ckpt
--train_batch_size=32
--max_seq_length=128
--max_predictions_per_seq=20
--num_train_steps=20
--num_warmup_steps=10
--learning_rate=2e-5
  • run_classifier.py
--task_name=MRPC
--do_train=true
--do_eval=true
--data_dir=glue_data/MRPC
--vocab_file=pre-trained/uncased_L-12_H-768_A-12/vocab.txt
--bert_config_file=pre-trained/uncased_L-12_H-768_A-12/bert_config.json
--init_checkpoint=pre-trained/uncased_L-12_H-768_A-12/bert_model.ckpt
--max_seq_length=128
--train_batch_size=32
--learning_rate=2e-5
--num_train_epochs=3.0
--output_dir=./result/mrpc_output/

BERTの概要 

BERTは、ELMo, Open AI GPT のような言語表現モデルで、「Bidirectional Encoder Representations from Transformers」の省略語です。 単語の埋め込み/ベクトル化といえば、word2vec や fastText が有名でしたが、BERT も、よしなにベクトル化してくれるモデルで、単語をサブトーク(※1)というレベルで、ベクトル化する手法の1つです。そして、今後デファクトスタンダードになるかもと、噂されるモデルでもあります。

(※1) 論文中では、split word pieces と表現され、コード上では、sub_token という変数に対応します。単語をさらに分離した、単語の部分文字列を指します。

BERTの特徴 / ポイント

BERTは、fine-tune 可能なモデルであり、MLM(Masked LM)と呼ばれる手法とNSP(Next Sentence Prediction)という手法でpre-train します。

コードでは、↓が対応します。(run_pretraining.py) f:id:bci-oshita:20181127161025p:plain

MLM: Masked Language Model

MLM とは、入力系列をtoken(sub_token) 単位でマスクし、マスクした部分(単語)を予測させる言語モデルのことです。   入力文のtoken(sub_token)系列の15%(== masked_lm_prob)をマスク対象とし、 そのマスク対象の単語の内80%の確率で、token(sub_token)を "[MASK]"に置き換え 10%の確率で、vocab辞書からランダムな単語を選択して入れ替え、 残り10%の確率で、変更しないようにします。コードでは、↓の青地部分が対応します。 f:id:bci-oshita:20181126200659p:plain

NSP: Next Sentence Prediction

NSP は、文間の関係性理解を訓練するために、2値化(0, 1 の 2つのラベルを教師データと)した、直後文の予測タスクでpre-train するモデルです。

モデルを構築するコード箇所は、↓の部分(関数)が対応します。(run_pretraining.py) f:id:bci-oshita:20181127161829p:plain

入力層と出力層のFull-Connect 1層(中間層なし)で、活性化関数なしの、すごくシンプルなネットワークを構築しています。 このシンプルなネットワークの出力(logits) から、cross entropy loss を算出するモデルになっています。

 埋め込み / ベクトル化

論文では、↓の図で表現されている箇所を、コードを追って解析してみます。 f:id:bci-oshita:20181126212825p:plain

まず、word_embedding をしている箇所で分かりやすいのは、embedding_lookup() です。 以下コードが、embedding_lookup() を呼び出している箇所です。(modeling.py) f:id:bci-oshita:20181126201146p:plain ここでは、TPUを使う場合は、one-hot から改めて埋め込む処理を行い、 CPU, GPU の場合は、pre-train 時に学習した単語ベクトル(sub_token単位のベクトル)を取得する動きになります。

次の処理は、embedding_postprocessor() を実行します。 ↓のコードが、embedding_postprocessor() を呼び出している箇所です。(modeling.py) f:id:bci-oshita:20181126214106p:plain

この処理では、文単位でのベクトル化(token_type_embeddings)と、 token(sub_token)単位での位置のベクトル化(position_embeddings)をone-hot ベクトルから埋め込みを行います。 特に、単語(sub_token)レベルのone-hot ベクトルとは異なり、文番号のone-hot は、高々2次元、位置のone-hot も高々128次元(==FLAGS.max_seq_length) と低次元のため、 毎回one-hot ベクトルから構築してもさほど問題にならない点に注意しておきます。 そして、埋め込んだ文ベクトルと、位置ベクトルを、単語ベクトルに加算することで、埋め込みベクトルとして完成させます。 (この加算は、文ベクトルと位置ベクトルを単語ベクトルと同じ \mathbb{R}^n空間に埋め込むように学習させる意味を持ちます。)

Transformer / Self-Attention

次は、論文で、↓の図で表現されている箇所を、コードを追って解析してみます。 f:id:bci-oshita:20181126212550p:plain

↑の論文の図に対応するのは、モデルの構築部分です。 該当コードは、transformer_model() ↓です。(modeling.py) f:id:bci-oshita:20181126214410p:plain

このコード部分では、transformer_model() を呼び出しているだけです。 特に、209行目~211行目を見ると、「hidden_size」(隠れ層のユニット数)、「num_hidden_layers」(隠れ層の数)、「num_attention_heads」(アテンションヘッドの数)を引数で渡していることから、transformer_model() で、よしなに処理していると想像がつきます。

では、実際に、transformer_model() のコードを見てみます。(modeling.py)

まずは、入力Tensor のshape を確認します。(適当にブレークポイントを設定して、内容を確認します。) f:id:bci-oshita:20181127150638p:plain

layer_input 変数の shape が、(4096, 768) になっています。これは、(batch_size * seq_length, hidden_size) を意味しています。 ここで、batch_size は、バッチサイズ(32)で、seq_lengthは、最大系列長(128)で、hidden_sizeは、隠れ層のユニット数(768)です。 これらは、起動引数(FLAGSの値)や、JSONファイル("pre-trained/uncased_L-12_H-768_A-12/bert_config.json" 等)の設定値で変更できるようになっています。

同じように、Transformer 1セル分の出力を見てみます。 f:id:bci-oshita:20181127151458p:plain

すると、入力と同じように、layer_output 変数の shapeが、(4096, 768) となっています。 実際に、注意深く、各Tensor のshapeを追っていくと、入力と同じく(batch_size * seq_length, hidden_size) を意味していることが分かります。

以上から、Transformer 1セル分の入出力は、以下のような構成になっていることがわかります。 f:id:bci-oshita:20181127151735p:plain

通常のモデル表現のように、batch_size を無視して入出力を整理すると以下のように表現できます。 f:id:bci-oshita:20181127171315p:plain

Transformer を積層したイメージ(transformer_model() で、構築したモデルのイメージ)は、以下のようになります。 f:id:bci-oshita:20181127152454p:plain

(注) 特に設定(bert_config.json)を変更しなければ、Transformer の積層数(num_hidden_layers) は、12層です。

以上から論文のようなモデルになっていることがわかります。 特に、系列ごとに処理しているとみなして、系列ごとにTransformer のセルを分割して表現すると論文の図になることに注意します。 つまり、系列が混ざることにより、双方向性(Bidirectionality)が実現されています。

未知語について

自然言語系の転移学習時に注意しておきたい未知語(=pre-train したときには存在しなかった単語)の扱いを、どうしているのか、気になったので調べてみました。 まずは、word_embeddings をする前の処理、すなわち各tokenをインデックスに変換している処理で、未知語を処理しているのでは、と仮説を立てて調べてみました。

実行用スクリプトの run_classifier.py の↓の部分からたどっていきます。 f:id:bci-oshita:20181126205252p:plain

しかし、実際にインデックスに変換する処理では、辞書で変換しているだけでした。 ↓ tokenization.py のコード部分で、item(sub_token に分離したピース)文字列をインデックスに変換しています。
f:id:bci-oshita:20181128140700p:plain

単純に考えると、この item が vocab に存在しなかったら(新しい単語だったら)、KeyError になるハズだけど、何もケアしていない。。しかし、天下のGoogle さんが何の考えもなしに危ないこと(エラーハンドルなしでよいという判断)をするハズがないので、そこには何か理由があるハズ!です。

ということで、文字列をインデックスに変換する前の処理:tokenize するときに、未知語判定してるのでは?と仮説を立てて、tokenize しているところを調べてみます。 ↓のコードが、tokenize している箇所の抜粋です。(tokenization.py)  f:id:bci-oshita:20181126205356p:plain

上の青地部分(とその周辺)に注目すると、文字列を後ろから1文字ずつ削っていき、vocab辞書(vocab.txt)に存在するピース(文字列のかけら)を cur_substr (== sub_token) として追加していきます。 ここで、292-293行目に、文字列の先頭ではないピースに対しては、"##" を先頭に付与して、substr (== sub_token) として扱っていることに注意しておきます。 つまり、vocab.txt の "##" で始まるピースをうまく定義してあげれば、fine-tuning や predict 時に、"[UNK]"(BERTにとっての未知語=UNKown)として扱われることも、エラーになることもないということです。これは、BERT論文だけを読んでいては、読み取れなかった(補完しきれなかった)部分になります。

蛇足ですが、文字の数が100を超える単語も、"[UNK]"として扱っているようです。 具体的には、281 行目で、self.unk_token == "[UNK]" をリストに追加している処理があり、その1ステップ前の280行目を見ると、文字の数が self.max_input_chars_per_word == 100 (WordpieceTokenizer クラス生成時のデフォルト値) を超える、という条件になっていることから、わかります。

当然、vocab 辞書にマッチするsub_token がない場合(分割した単語(token)を、sub_tokenで構成できない場合)は、以下のis_bad == True のケースに該当し、 ↓の305行目のように、"[UNK]" として扱われるため、埋め込みのためのインデックス変換でエラーになることがないとわかります。 ([UNK] は、pre-trained モデルに同梱されているいずれの vocab.txt にも、登録されています。) f:id:bci-oshita:20181126205609p:plain

sub_token について、個人的にわかりにくかったので、sub_token への分解について、試したコード例を↓に示します。

>>> import tokenization
>>> vocab_file = "pre-trained/uncased_L-12_H-768_A-12/vocab.txt"
>>> vocab = tokenization.load_vocab(vocab_file)
>>> wordpiece_tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

vocab 辞書の単語ピースにマッチするケース
(うまく、辞書に存在する単語ピースの系列に分割できるケース)
>>> token = "oovae"
>>> wordpiece_tokenizer.tokenize(token)
['o', '##ova', '##e']

vocab 辞書のいずれにもマッチしないケース
(うまく、辞書に存在する単語ピースの系列に分割できないケース)
>>> token = "oovaeY0aixee"
>>> wordpiece_tokenizer.tokenize(token)
['[UNK]']
>>>

日本語対応

BERTのtokenizer は、↓のコードの通り中国語に対応しています。(tokenization.py の BasicTokenizer; 単語単位で区切る処理) f:id:bci-oshita:20181126210907p:plain しかし、すべての日本語の漢字には対応していないため、日本語を処理するには、tokenization を改造したり、vocab辞書をメンテする必要がありそうです。 (↓は、該当コード)  f:id:bci-oshita:20181126211210p:plain

また、日本語のtokenizer も、学習するようにBERTと組み合わせることで、これまでにない tokenizer のSOTA的なものもできそうな印象をうけました。 例えば、tokenizer とセットになりそうなvocab 辞書もうまく学習(更新)する方法を構築できれば、ベクトル化とtokenize を同時にできる良いモデルができそうな予感がします。

  

まとめ

  • MLM、NSPの2つの手法を使って pre-train を行っている
    • MLM: 入力系列をランダムにマスクし、マスクした部分を推定させる手法
    • NSP: 次の文を推定する手法
  • 単語(sub_token)ベクトル、文ベクトル、位置ベクトルを加算したベクトルを埋め込みベクトルとしている
  • Transformer 内の Self-Attention により、出力系列==入力系列を混ぜた新たな系列を生成している
  • 単語のピース(sub_token) をうまく定義すれば、未知語よりも意味があるベクトルの系列に変換できる
  • 日本語対応には、tokenizer と、vocab 辞書をうまく定義してあげる必要がありそう
  • tokenizer と 辞書を同時に学習(更新)できるようなモデルをBERTをベースに作ったら、日本語でもうまくいくかもしれない(感想)

 

参考リンク

[1810.04805] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT

Google AI Blog: Open Sourcing BERT: State-of-the-Art Pre-training for Natural Language Processing

BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding · Issue #959 · arXivTimes/arXivTimes · GitHub

汎用言語表現モデルBERTを日本語で動かす(PyTorch) - Qiita