学生時代に頑張ったことが何もない

2014-10-21

CaffeでDeep Q-Networkを実装して深層強化学習してみた

概要

深層学習フレームワークCaffeを使って,Deep Q-Networkという深層強化学習アルゴリズムC++で実装して,Atari 2600のゲームをプレイさせてみました.

Deep Q-Network

Deep Q-Network(以下DQN)は,2013年のNIPSのDeep Learning Workshopの”Playing Atari with Deep Reinforcement Learning”という論文で提案されたアルゴリズムで,行動価値関数Q(s,a)を深層ニューラルネットワークにより近似するという,近年の深層学習の研究成果を強化学習に活かしたものです.Atari 2600のゲームに適用され,既存手法を圧倒するとともに一部のゲームでは人間のエキスパートを上回るスコアを達成しています.論文の著者らは今年Googleに買収されたDeepMindの研究者です.

NIPS2013読み会で自分が紹介した際のスライドがこちらになります.

他の方が作成したスライドもあります.

必要なもの

ソースコード

GitHubで公開しています.DQN-in-the-Caffe

ネットワークの構成

ネットワークの構成は元論文の通り,

  1. 入力層:84x84x4(ラスト4フレームのダウンサンプリング&グレイスケール化)
  2. 隠れ層1:8x8のフィルタx8(ストライド4)による畳込み+ReLU
  3. 隠れ層2:4x4のフィルタx16(ストライド2)による畳込み+ReLU
  4. 隠れ層3:fully-connectedノードx256+ReLU
  5. 出力層:fully-connectedノードx18(18種類のアクションそれぞれの行動価値)

としました.このネットワークを逆伝播により学習するためには,複数ある出力のうち1つの出力のみに対して誤差を計算する必要があるのですが,それを可能にするためにCaffeのELTWISEレイヤーを使い,1つの要素のみ1で残りは0であるようなベクトルネットワークの出力に掛け合わせることで望みの出力だけを取り出せるようにしています.Caffeのネットワーク表記でネットワーク全体を書くと下のようになりました.

パラメータの学習

パラメータの学習のためには,「状態{s_t}で行動{a_t}を選択したところ,報酬{r_t}を獲得し,次の状態が{s_{t+1}}であった」という状態遷移{(s_t,a_t,r_t,s_{t+1})}の経験をreplay memoryというメモリに保管していき,パラメータ更新の際にはそこからランダムサンプリングした一定数の遷移それぞれについて

{Q(s_j,a_j) ¥leftarrow r_j + ¥gamma max_{a’} Q(s_{j+1},a’)}

となるように勾配を計算した上で,まとめて更新を行うミニバッチ学習を行います.

論文ではここでRMSPropというパラメータ更新量の自動調節アルゴリズムを用いていますが,Caffeには今のところRMSPropは実装されておらず,その代わりAdaDeltaというRMSPropによく似たアルゴリズムをすでに実装してpull requestを投げている人がいたので,それを使いました.ただし,AdaDeltaをそのまま使用するとパラメータが発散してしまうことが多かったため,AdaDeltaによる更新量にさらに一定の係数(最初の100万イテレーションでは0.2,次の100万イテレーションでは0.02)を掛けて用いました(同じようなことをやっている?人).ミニバッチの大きさは元論文と同じ32,割引率{¥gamma}は元論文では示されていませんが0.95としました.

論文ではreplay memoryの容量は100万フレームでしたが,メモリの都合上,半分の50万で実験しました.

学習時間

実行環境は

です.CaffeはGPUモードで,さらにcuDNNを使いました.この構成でミニバッチ5万個の学習に45分ほどかかりましたが,元論文では5万個分を30分ほどで学習しているので,1.5倍ほど遅い結果となりました.

結果

D

上の動画はPongというゲームを200万イテレーション(およそ30時間)学習させた後のプレイ動画です.右の緑がDQNで,元論文と同じく各フレームごとに5%の確率で完全にランダムにアクションを選ぶようにしています.3回の試行のスコアがそれぞれ16,13,19と,元論文の20という平均スコアには達していませんが,元論文では1000万イテレーションの学習を行っているので,より学習が進めば同等のスコアが出せるかもしれません.

論文ではHuman Expertのスコアは-3とされていますが,現時点でもそれよりは大幅に上回っているので,DQNは人間より強いという結果が再現出来て何よりです.

yoshiyoshi 2014/12/11 20:54 はじめまして.大変興味深い記事を読ませていただきました,

自分でも試したいと思いまして,公開されているdqnブランチとソースコードをいただいたのですが,どうしてもソースコードのコンパイルが自力で通らないので質問させていただきます.

現在,こちらはubuntu環境を使っていて,ノーマルなcaffeとdqnブランチのcaffeのmake(cmake & make all)は通ることを確認しました.どちらも$HOME/caffe以下で行っています.ただ,まずdqnブランチの方ではruntestがどうもAdaDeltaの項で

[ RUN ] AdaDeltaSolverTest/0.TestAdaDeltaLeastSquaresUpdateWithWeightDecay
F1211 07:26:56.105155 10802 solver.cpp:390] Unknown learning rate policy:(以下略)

とエラーメッセージが表示されて停止してしまいます.これはこういうものなのでしょうか?それとも他に解決すべき点があるのでしょうか?

muupanmuupan 2014/12/12 20:42 dqnブランチではAdaDeltaSolverの実装に変更を加えていますが,テストの方は今のところ手を付けていないのでそのようなエラーが出ます.本来は直すべきですが,とりあえずはそういうものとして無視して進めて頂いて問題無いと思います.

yoshiyoshi 2014/12/13 06:09 お返事有り難うございます.

無事CPUモードでコンパイルが通りました.
ひとまずこれを走らせて,のんびり結果を待つことにします.

ちなみにGPUモードでコンパイルしてrunするとsegmentation faultで停止してしまいました.
似たようなissueを見つけたので(https://github.com/BVLC/caffe/issues/678),
おそらくこちらがショボいGPU(Geforce GTX 660)で試しているためだと思います.

もし何かsuggestionありましたら教えていただければ幸いです.

okogeokoge 2014/12/20 13:11 初歩的な質問で申し訳ないのですが、
const_cast<float*>(frames_input.data())
にてなぜconst_castを行うのでしょうか?
この一つ前の記事ではconst_castを行ってはいませんでしたよね?
入力データを二次元配列にするにはconst_castしなければいけないと言った感じなんでしょうか?

okogeokoge 2014/12/30 11:30 ↑凄まじく馬鹿なことを書いてるのに気付きました。申し訳ない。

aizaiz 2016/03/04 14:55 とんでもなく貴重なサイトで感銘をうけています。
私も是非、自分で動かしてみたいのですが、初歩的質問で申し訳ありません。
現在ubuntu14.04でCaffeが動く状態までにはなりました。
GitHubから貴プログラムソースをダウンロードしてきて解凍するとdqn-masterというホルダーができます。この直下にCaffeのホルダーを移動させ、cmakeとするという理解でよろしいでしょうか。
前の質問で、dqnブランチの話がでていますが、意味がわかりません。
初歩的質問で誠に申し訳ありません。

aizaiz 2016/03/05 09:54 GttHubに関し全く無理解でした。勉強し出直します。申し訳ありません。

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証

トラックバック - http://d.hatena.ne.jp/muupan/20141021/1413850461
Connection: close