DQN-Chainerちょっとだけ中身見た その2
ここから下の内容は僕のメモみたいなものです.最期まで読んでも得られるものは少ないです.
最終目標は報酬をいろいろと弄ることですが,ちょっとハードルが高くて断念しています.
RL_glueやALEのソースコードも載せていますが,ライセンスも確認しましたので大丈夫なはずです.
万が一間違ってたら教えてください.
dqn_classに実装されているメソッドは以下の6つ.
def agent_init(self, taskSpec):
def agent_start(self, observation):
def agent_step(self, reward, observation):
def agent_end(self, reward):
def agent_cleanup(self):
def agent_message(self, inMessage):
それぞれの役割は以下
agent_init……必要なものの定義.self.DQN = DQN_class()で,同一ファイル内のchainerで学習するためのクラスDQN_classを指定している.
agent_start……observation(多分ゲームからもらう情報)を受け取り,4*84*84のstateをself.stateに入れる.
action, Q_now = self.DQN.e_greedy(state_, self.epsilon)で,行動と現在のQ値を取ってくる.
return returnActionで,self.DQN.e_greedy(state_, self.epsilon)で決定したactionを返す
agent_step……agent_startと同じく,4*84*84のstateをself.stateに入れる.
DQNの学習はLearningとEvaluationに分かれていて,Learningのとき,Random行動を入れることがある.
action, Q_now = self.DQN.e_greedy(state_, self.epsilon)で,agent_startと同様に行動と現在のQ値を取ってくる.
self.DQN.stockExperienceとself.DQN.experienceReplayで,状態,行動,報酬,次状態などを保存,保存したものから学習(ネットワークの更新)を行う.ココが核.
Learningならtime+=1する
agent_end……多分,Episodeの最期に実行される.agent_stepから,Episodeの最期では必要のない処理を抜いたものっぽい.
agent_cleanup……ここはパスしている.
agent_message……いろんなメッセージを書く.experiment_ale.pyを実行したターミナルに表示される.
おそらくこれらのメソッドはrl_gllueで環境と実験とRLを繋いで学習を行う上で定義しなければならないものなのでしょう.
あとは,
if __name__ == "__main__": AgentLoader.loadAgent(dqn_agent())
でdqn_agent()を渡してやります.
AgentLoader.loadAgentとはどんなやつなのでしょう?
from rlglue.agent import AgentLoader as AgentLoader
ここを見る限り,rlglueの機能のようです.
使い方はexperiment_ale.pyを見れば分かるでしょうか.ちょっと見てみます.
print "\n\nDQN-ALE Experiment starting up!" RLGlue.RL_init() while learningEpisode < max_learningEpisode: # Evaluate model every 10 episodes if np.mod(whichEpisode, 10) == 0: print "Freeze learning for Evaluation" RLGlue.RL_agent_message("freeze learning") runEpisode(is_learning_episode=False) else: print "DQN is Learning" RLGlue.RL_agent_message("unfreeze learning") runEpisode(is_learning_episode=True)
ここは繰り返し処理の部分です.runEpisodeメソッドが主機能でしょう.
同じくexperiment_ale.pyの
def runEpisode(is_learning_episode): global whichEpisode, learningEpisode RLGlue.RL_episode(0) totalSteps = RLGlue.RL_num_steps() totalReward = RLGlue.RL_return() whichEpisode += 1 if is_learning_episode: learningEpisode += 1 print "Episode " + str(learningEpisode) + "\t " + str(totalSteps) + " steps \t" + str(totalReward) + " total reward\t " else: print "Evaluation ::\t " + str(totalSteps) + " steps \t" + str(totalReward) + " total reward\t "
ん? totalReward = RLGlue.RL_return()ってことは,rewardはどこか別から受け取っているようです.てっきりexperiment_ale.pyで定義しているのだと思っていました.
dqn_agent_nature.pyでもどこかから受け取っています.いったいどこで定義してるのでしょう?
aleから受け取った勝敗数とかが怪しいと思うのですが,とりあえずRL_glueのpython_codecを見に行ってみましょう.
報酬の合計を取ってくるRL_return()を見てみます.
def RL_return(): reward = 0.0 doCallWithNoParams(Network.kRLReturn) doStandardRecv(Network.kRLReturn) reward = network.getDouble() return reward
reward = network.getDouble()だそうです.ではnetwork.getDouble()とは?
26 import rlglue.network.Network as Network ... 63 network = Network.Network()
networkというのはrlglue.network.NetworkのNetwork()というやつらしいです.
network.pyを見に行きます.そこにgetDouble()があるはずです.
84 kDoubleSize = 8 ... 93 self.recvBuffer = StringIO.StringIO('') ... 145 def getDouble(self): 146 s = self.recvBuffer.read(kDoubleSize) 147 return struct.unpack("!d",s)[0]
うーん...つまり,getDouble()ではバッファーから読み取った文字sを数値化して返すということでしょうか?
じゃあバッファーというのは?
一つずつ追いかけていきましょう.
experiment_ale.pyにおいてまずはじめに実行されるRL_init()を見てみます.
101 def RL_init(): 102 forceConnection() 103 doCallWithNoParams(Network.kRLInit) 104 doStandardRecv(Network.kRLInit) 105 #Brian Tanner added 106 taskSpecResponse = network.getString() 107 return taskSpecResponse
104 doStandardRecv(Network.kRLInit)に注目します.
71 def doStandardRecv(state): 72 network.clearRecvBuffer() 73 recvSize = network.recv(8) - 8 74 75 glueState = network.getInt() 76 dataSize = network.getInt() 77 remaining = dataSize - recvSize 78 79 if remaining < 0: 80 remaining = 0 81 82 remainingReceived = network.recv(remaining) 83 84 # Already read the header, so discard it 85 network.getInt() 86 network.getInt() 87 88 if (glueState != state): 89 sys.stderr.write("Not synched with server. glueState = " + str(glueState) + " but s hould be " + str(state) + "\n") 90 sys.exit(1)
82 remainingReceived = network.recv(remaining)に注目します.network.recv()を見てみます.
119 def recv(self,size): 120 s = '' 121 while len(s) < size: 122 s += self.sock.recv(size - len(s)) 123 self.recvBuffer.write(s) 124 self.recvBuffer.seek(0) 125 return len(s)
122 s += self.sock.recv(size - len(s))に注目します.sockという名前的に,ソケットでしょうか.
僕はBSDソケットのことをよく知らないのですが,やはりaleから受け取っていると見てよさそうです.
28 import socket
とありました.ググります.
17.2. socket — 低レベルネットワークインターフェース — Python 2.7.x ドキュメント
このモジュールは、PythonでBSD ソケット(socket) インターフェースを利用するために使用します。最近のUnixシステム、Windows, Max OS X, BeOS, OS/2など、多くのプラットフォームで利用可能です。
やっぱりBSDソケットでした.RL_glueを動かした時の動きが見たこと無いなあと思ってたらこういうことだったんですね.勉強になりました.
64 network.connect(host,port)
もう一度experiment_ale.pyにおいてまずはじめに実行されるRL_init()を見てみます.
101 def RL_init(): 102 forceConnection() 103 doCallWithNoParams(Network.kRLInit) 104 doStandardRecv(Network.kRLInit) 105 #Brian Tanner added 106 taskSpecResponse = network.getString() 107 return taskSpecResponse
ここのforceConnection()でBSDソケット通信の初期化をしてると思います.
37 def forceConnection(): 38 global network 39 if network == None: 40 41 theSVNVersion=get_svn_codec_version() 42 theCodecVersion=get_codec_version() 43 44 host = Network.kLocalHost 45 port = Network.kDefaultPort 46 47 hostString = os.getenv("RLGLUE_HOST") 48 portString = os.getenv("RLGLUE_PORT") 49 50 if (hostString != None): 51 host = hostString 52 53 try: 54 port = int(portString) 55 except TypeError: 56 port = Network.kDefaultPort 57 58 print "RL-Glue Python Experiment Codec Version: "+theCodecVersion+" (Build "+theSVN Version+")" 59 print "\tConnecting to " + host + " on port " + str(port) + "..." 60 sys.stdout.flush() 61 62 63 network = Network.Network() 64 network.connect(host,port) 65 network.clearSendBuffer() 66 network.putInt(Network.kExperimentConnection) 67 network.putInt(0) 68 network.send()
64 network.connect(host,port)
ここだなあ
101 def connect(self, host=kLocalHost, port=kDefaultPort, retryTimeout=kRetryTimeout): 102 while self.sock == None: 103 try: 104 self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 105 self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 106 self.sock.connect((host, port)) 107 except socket.error, msg: 108 self.sock = None 109 time.sleep(retryTimeout) 110 else: 111 break
ここでconnectの初期化を行ったみたいです.
いまだrewardへはたどり着けていません.おもったより長い道のりです.
ともあれ,network.recv()の
119 def recv(self,size): 120 s = '' 121 while len(s) < size: 122 s += self.sock.recv(size - len(s)) 123 self.recvBuffer.write(s) 124 self.recvBuffer.seek(0) 125 return len(s)
では,ソケットから受け取った文字列をwriteしていることがわかります.size-len(s)はなんなんでしょう?
sizeはどこかから受け取っていますが,len(s)はイミフです.上でs=''とやっているのだから,len(s)は必ず0になるんじゃないんでしょうか?
まあそれは置いといて,sock.recv(size - len(s))がrewardにつながるのかな? sizeの値によりそうですが.
recvを呼んでいるのが,RLglue.pyの
82 remainingReceived = network.recv(remaining)
でした.
remainingとは?
同じくRLglue.pyのdoStandardRecv(state):から
73 recvSize = network.recv(8) - 8 74 75 glueState = network.getInt() 76 dataSize = network.getInt() 77 remaining = dataSize - recvSize
remainingはdataSize - recvSizeです.
recvSizeはnetwork.recv(8) - 8です.
つまり,len(s)から8を引くことになります.
len(s)はself.sock.recv(8)を文字列化したもの文字数です.
sock.recv()とnetwork.recv()は似てますが全く別物です.
socketのrecvメソッドはこちら
len(s)-8というのはつまり,文字数から8を引くということです.だめだ.わけわかんなくなってきた.
dataSizeはnetwork.getInt()で取ってきます.
その直前のglueStateもnetwork.getInt()で取ってきています.同じものでしょうか? 見てみます.
83 kIntSize = 4 ... 141 def getInt(self): 142 s = self.recvBuffer.read(kIntSize) 143 return struct.unpack("!i",s)[0]
dataSizeは4のようです.
要するにremainingはsocket(8)-8-4でしょうか.
ギブアップです.
報酬の付近をいろいろと弄ってみたかった(たとえば勝ち負けではなくラリーが続くように学習するなど)のですが,ちょっと調べることが多くなりそうです.知ってる方がいたらぜひコメントください.お願いします.
まあぶっちゃけゲームをプレイさせるのはそんなに重要な事ではないので,今回は潔く諦めることにしましょう.