読者です 読者をやめる 読者になる 読者になる

DQN-Chainerちょっとだけ中身見た その2

ここから下の内容は僕のメモみたいなものです.最期まで読んでも得られるものは少ないです.
最終目標は報酬をいろいろと弄ることですが,ちょっとハードルが高くて断念しています.
RL_glueやALEのソースコードも載せていますが,ライセンスも確認しましたので大丈夫なはずです.
万が一間違ってたら教えてください.


github.com

dqn_classに実装されているメソッドは以下の6つ.



def agent_init(self, taskSpec):

def agent_start(self, observation):

def agent_step(self, reward, observation):

def agent_end(self, reward):

def agent_cleanup(self):

def agent_message(self, inMessage):



それぞれの役割は以下

agent_init……必要なものの定義.self.DQN = DQN_class()で,同一ファイル内のchainerで学習するためのクラスDQN_classを指定している.



agent_start……observation(多分ゲームからもらう情報)を受け取り,4*84*84のstateをself.stateに入れる.

  action, Q_now = self.DQN.e_greedy(state_, self.epsilon)で,行動と現在のQ値を取ってくる.

  return returnActionで,self.DQN.e_greedy(state_, self.epsilon)で決定したactionを返す



agent_step……agent_startと同じく,4*84*84のstateをself.stateに入れる.

  DQNの学習はLearningとEvaluationに分かれていて,Learningのとき,Random行動を入れることがある.

  action, Q_now = self.DQN.e_greedy(state_, self.epsilon)で,agent_startと同様に行動と現在のQ値を取ってくる.

  self.DQN.stockExperienceとself.DQN.experienceReplayで,状態,行動,報酬,次状態などを保存,保存したものから学習(ネットワークの更新)を行う.ココが核.

  Learningならtime+=1する

agent_end……多分,Episodeの最期に実行される.agent_stepから,Episodeの最期では必要のない処理を抜いたものっぽい.

agent_cleanup……ここはパスしている.

agent_message……いろんなメッセージを書く.experiment_ale.pyを実行したターミナルに表示される.


おそらくこれらのメソッドはrl_gllueで環境と実験とRLを繋いで学習を行う上で定義しなければならないものなのでしょう.
あとは,

if __name__ == "__main__":
    AgentLoader.loadAgent(dqn_agent())

dqn_agent()を渡してやります.
AgentLoader.loadAgentとはどんなやつなのでしょう?

from rlglue.agent import AgentLoader as AgentLoader

ここを見る限り,rlglueの機能のようです.
使い方はexperiment_ale.pyを見れば分かるでしょうか.ちょっと見てみます.

print "\n\nDQN-ALE Experiment starting up!"
RLGlue.RL_init()

while learningEpisode < max_learningEpisode:
    # Evaluate model every 10 episodes
    if np.mod(whichEpisode, 10) == 0:
        print "Freeze learning for Evaluation"
        RLGlue.RL_agent_message("freeze learning")
        runEpisode(is_learning_episode=False)
    else:
        print "DQN is Learning"
        RLGlue.RL_agent_message("unfreeze learning")
        runEpisode(is_learning_episode=True)

ここは繰り返し処理の部分です.runEpisodeメソッドが主機能でしょう.
同じくexperiment_ale.pyの

def runEpisode(is_learning_episode):
    global whichEpisode, learningEpisode

    RLGlue.RL_episode(0)
    totalSteps = RLGlue.RL_num_steps()
    totalReward = RLGlue.RL_return()

    whichEpisode += 1

    if is_learning_episode:
        learningEpisode += 1
        print "Episode " + str(learningEpisode) + "\t " + str(totalSteps) + " steps \t" + str(totalReward) + " total reward\t "
    else:
        print "Evaluation ::\t " + str(totalSteps) + " steps \t" + str(totalReward) + " total reward\t "

ん? totalReward = RLGlue.RL_return()ってことは,rewardはどこか別から受け取っているようです.てっきりexperiment_ale.pyで定義しているのだと思っていました.
dqn_agent_nature.pyでもどこかから受け取っています.いったいどこで定義してるのでしょう?
aleから受け取った勝敗数とかが怪しいと思うのですが,とりあえずRL_glueのpython_codecを見に行ってみましょう.
報酬の合計を取ってくるRL_return()を見てみます.

def RL_return():
    reward = 0.0
    doCallWithNoParams(Network.kRLReturn)
    doStandardRecv(Network.kRLReturn)
    reward = network.getDouble()
    return reward

reward = network.getDouble()だそうです.ではnetwork.getDouble()とは?

26 import rlglue.network.Network as Network
...
63         network = Network.Network()

networkというのはrlglue.network.NetworkのNetwork()というやつらしいです.
network.pyを見に行きます.そこにgetDouble()があるはずです.

  84 kDoubleSize = 8
...
  93         self.recvBuffer = StringIO.StringIO('')
...
145     def getDouble(self):
146         s = self.recvBuffer.read(kDoubleSize)
147         return struct.unpack("!d",s)[0]

うーん...つまり,getDouble()ではバッファーから読み取った文字sを数値化して返すということでしょうか?
じゃあバッファーというのは?
一つずつ追いかけていきましょう.
experiment_ale.pyにおいてまずはじめに実行されるRL_init()を見てみます.

101 def RL_init():
102     forceConnection()
103     doCallWithNoParams(Network.kRLInit)
104     doStandardRecv(Network.kRLInit)
105     #Brian Tanner added
106     taskSpecResponse = network.getString()
107     return taskSpecResponse

104 doStandardRecv(Network.kRLInit)に注目します.

 71 def doStandardRecv(state):
 72     network.clearRecvBuffer()
 73     recvSize = network.recv(8) - 8
 74
 75     glueState = network.getInt()
 76     dataSize = network.getInt()
 77     remaining = dataSize - recvSize
 78
 79     if remaining < 0:
 80         remaining = 0
 81
 82     remainingReceived = network.recv(remaining)
 83
 84     # Already read the header, so discard it
 85     network.getInt()
 86     network.getInt()
 87
 88     if (glueState != state):
 89         sys.stderr.write("Not synched with server. glueState = " + str(glueState) + " but s    hould be " + str(state) + "\n")
 90         sys.exit(1)

82 remainingReceived = network.recv(remaining)に注目します.network.recv()を見てみます.

119     def recv(self,size):
120         s = ''
121         while len(s) < size:
122             s += self.sock.recv(size - len(s))
123         self.recvBuffer.write(s)
124         self.recvBuffer.seek(0)
125         return len(s)

122 s += self.sock.recv(size - len(s))に注目します.sockという名前的に,ソケットでしょうか.
僕はBSDソケットのことをよく知らないのですが,やはりaleから受け取っていると見てよさそうです.

 28 import socket

とありました.ググります.
17.2. socket — 低レベルネットワークインターフェース — Python 2.7.x ドキュメント

このモジュールは、PythonBSD ソケット(socket) インターフェースを利用するために使用します。最近のUnixシステム、Windows, Max OS X, BeOS, OS/2など、多くのプラットフォームで利用可能です。


やっぱりBSDソケットでした.RL_glueを動かした時の動きが見たこと無いなあと思ってたらこういうことだったんですね.勉強になりました.

 64         network.connect(host,port)

もう一度experiment_ale.pyにおいてまずはじめに実行されるRL_init()を見てみます.

101 def RL_init():
102     forceConnection()
103     doCallWithNoParams(Network.kRLInit)
104     doStandardRecv(Network.kRLInit)
105     #Brian Tanner added
106     taskSpecResponse = network.getString()
107     return taskSpecResponse

ここのforceConnection()でBSDソケット通信の初期化をしてると思います.

 37 def forceConnection():
 38     global network
 39     if network == None:
 40
 41         theSVNVersion=get_svn_codec_version()
 42         theCodecVersion=get_codec_version()
 43
 44         host = Network.kLocalHost
 45         port = Network.kDefaultPort
 46
 47         hostString = os.getenv("RLGLUE_HOST")
 48         portString = os.getenv("RLGLUE_PORT")
 49
 50         if (hostString != None):
 51             host = hostString
 52
 53         try:
 54             port = int(portString)
 55         except TypeError:
 56             port = Network.kDefaultPort
 57
 58         print "RL-Glue Python Experiment Codec Version: "+theCodecVersion+" (Build "+theSVN    Version+")"
 59         print "\tConnecting to " + host + " on port " + str(port) + "..."
 60         sys.stdout.flush()
 61
 62
 63         network = Network.Network()
 64         network.connect(host,port)
 65         network.clearSendBuffer()
 66         network.putInt(Network.kExperimentConnection)
 67         network.putInt(0)
 68         network.send()

64 network.connect(host,port)
ここだなあ

101     def connect(self, host=kLocalHost, port=kDefaultPort, retryTimeout=kRetryTimeout):
102         while self.sock == None:
103             try:
104                 self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
105                 self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
106                 self.sock.connect((host, port))
107             except socket.error, msg:
108                 self.sock = None
109                 time.sleep(retryTimeout)
110             else:
111                 break

ここでconnectの初期化を行ったみたいです.
いまだrewardへはたどり着けていません.おもったより長い道のりです.
ともあれ,network.recv()の

119     def recv(self,size):
120         s = ''
121         while len(s) < size:
122             s += self.sock.recv(size - len(s))
123         self.recvBuffer.write(s)
124         self.recvBuffer.seek(0)
125         return len(s)

では,ソケットから受け取った文字列をwriteしていることがわかります.size-len(s)はなんなんでしょう?
sizeはどこかから受け取っていますが,len(s)はイミフです.上でs=''とやっているのだから,len(s)は必ず0になるんじゃないんでしょうか?
まあそれは置いといて,sock.recv(size - len(s))がrewardにつながるのかな? sizeの値によりそうですが.
recvを呼んでいるのが,RLglue.pyの

82     remainingReceived = network.recv(remaining)

でした.
remainingとは?
同じくRLglue.pyのdoStandardRecv(state):から

 73     recvSize = network.recv(8) - 8
 74
 75     glueState = network.getInt()
 76     dataSize = network.getInt()
 77     remaining = dataSize - recvSize

remainingはdataSize - recvSizeです.
recvSizeはnetwork.recv(8) - 8です.
つまり,len(s)から8を引くことになります.
len(s)はself.sock.recv(8)を文字列化したもの文字数です.
sock.recv()とnetwork.recv()は似てますが全く別物です.
socketのrecvメソッドはこちら


len(s)-8というのはつまり,文字数から8を引くということです.だめだ.わけわかんなくなってきた.
dataSizeはnetwork.getInt()で取ってきます.
その直前のglueStateもnetwork.getInt()で取ってきています.同じものでしょうか? 見てみます.

 83 kIntSize = 4
 ...
141     def getInt(self):
142         s = self.recvBuffer.read(kIntSize)
143         return struct.unpack("!i",s)[0]

dataSizeは4のようです.
要するにremainingはsocket(8)-8-4でしょうか.

ギブアップです.

報酬の付近をいろいろと弄ってみたかった(たとえば勝ち負けではなくラリーが続くように学習するなど)のですが,ちょっと調べることが多くなりそうです.知ってる方がいたらぜひコメントください.お願いします.
まあぶっちゃけゲームをプレイさせるのはそんなに重要な事ではないので,今回は潔く諦めることにしましょう.

次はDQN-chainerを参考に,自作の簡単なゲームをDQNで学習させてみたいと思います.