TensorFlowのRNN実装はサンプルが少なく、
かつそういったサンプルコードでは、
限定された一部のAPIしか使っていないなど全体を網羅しづらい感じがあるので、
なるべく全体感を思い出しやすいように、自分用にメモ。
(と言う割に基本的なAPIしか使ってないが...)
TensorFlowのRNNのAPIを使わないRNNの実装
まずは2(タイム)ステップだけ動作するRNNをつくる。
各ステップで入れるデータをX0, X1としてplaceholderをつくる。
各ステップでは、入力ユニット数(=4)にあたるデータを用意。
今回は隠れ層は1つとして、
唯一の隠れ層のユニット数は6。
n_inputs = 4 n_neurons = 6 X0 = tf.placeholder(tf.float32, [None, n_inputs]) X1 = tf.placeholder(tf.float32, [None, n_inputs]) Wx = tf.Variable(tf.random_normal(shape=[n_inputs, n_neurons],dtype=tf.float32)) Wy = tf.Variable(tf.random_normal(shape=[n_neurons,n_neurons],dtype=tf.float32)) b = tf.Variable(tf.zeros([1, n_neurons], dtype=tf.float32)) Y0 = tf.tanh(tf.matmul(X0, Wx) + b) Y1 = tf.tanh(tf.matmul(Y0, Wy) + tf.matmul(X1, Wx) + b) init = tf.global_variables_initializer()
ミニバッチは4つデータを含む。
# t = 0 X0_batch = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) # t = 1 X1_batch = np.array([[15, 14, 13, 12], [0, 0, 0, 0], [8, 7, 6, 5], [4, 3, 2, 1]]) with tf.Session() as sess: init.run() Y0_eval, Y1_eval = sess.run([Y0, Y1], feed_dict={X0: X0_batch, X1: X1_batch}) print(Y0_eval) print(Y1_eval)
tf.contrib.rnn.static_rnn()を使う実装
上記実装はシンプルだが、
ステップ数を増やそうとするとネットワークが複雑になっていく。
そこでTensorFlowのstatic_rnnというAPIを使う。
static_rnnを使うと、TensorFlowで用意されたRNNのcellをrollingせずに記述できる。
cell(下ではBasicRNNCellを使用)というのは、各ステップに対応するRNNユニットと考えるとよい。
static_rnnで、それらのcellと各ステップの入力をAPI内で繋げてくれる。
具体的なstatic_rnnの内部の処理としては、
入力があるごとにcellをコピーして、上の実装のように繋げているようだ。
n_inputs = 4 n_neurons = 6 X0 = tf.placeholder(tf.float32, [None, n_inputs]) X1 = tf.placeholder(tf.float32, [None, n_inputs]) basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons) output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell, [X0, X1], dtype=tf.float32) Y0, Y1 = output_seqs init = tf.global_variables_initializer()
# t = 0 X0_batch = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) # t = 1 X1_batch = np.array([[15, 14, 13, 12], [0, 0, 0, 0], [8, 7, 6, 5], [4, 3, 2, 1]]) with tf.Session() as sess: init.run() Y0_eval, Y1_eval = sess.run([Y0, Y1], feed_dict={X0: X0_batch, X1: X1_batch}) print(Y0_eval) print(Y1_eval)
各ステップの入力をまとめて1つとして扱う
上の実装では、2ステップ分の入力のplaceholderをそれぞれX0, X1と定義していた。
これもステップ数が多くなるとX0, X1, X2... と増やさないといけなくなるので、
入力をまとめて扱うようにする。
具体的には、
[バッチサイズ、ステップ数、入力ユニット数]としたplaceholderを用意してまとめて扱い、
それらを各ステップの入力ごとにバラす
(tf.unstackで1次元目に沿って、別々のリストに分ける)。
出力のところでは逆にtf.stackを使って、
各ステップでの出力結果を結合して、
[バッチサイズ,ステップ数,出力ユニット数]を得ている。
n_steps = 2 n_inputs = 4 n_neurons = 6 X = tf.placeholder(tf.float32, [None, n_steps, n_inputs]) X_seqs = tf.unstack(tf.transpose(X, perm=[1, 0, 2])) basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons) output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell, X_seqs, dtype=tf.float32) outputs = tf.transpose(tf.stack(output_seqs), perm=[1, 0, 2]) init = tf.global_variables_initializer()
X_batch = np.array([ # t = 0 t = 1 [[0, 1, 2, 3], [15, 14, 13, 12]], [[4, 5, 6, 7], [0, 0, 0, 0]], [[8, 9, 10, 11], [8, 7, 6, 5]], [[12, 13, 14, 15], [4, 3, 2, 1]], ]) with tf.Session() as sess: init.run() outputs_val = outputs.eval(feed_dict={X: X_batch})
dynamic_rnnを使う
上の実装にはまだ問題がある。
逆伝搬時の勾配計算のために、
順伝搬時の各ステップの計算結果を保持して置く必要があるため、
メモリを圧迫してしまう。
これを避けるためにstatic_rnnの代わりにdynamic_rnnを使うとよい。
dynamic_rnnに渡す引数としてswap_memoryをTrueにしておけば、
GPUからCPUにメモリ退避してくれる。
また、別の利点として各ステップの入出力に、
[バッチサイズ,ステップ数,ユニット数]をそのまま使えるので、
上の実装にあるようにunstackやstackをしなくてよくなる。
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
ステップごとの入力を可変長とする
ここまでは各ステップでの入力長は固定だった。
でも文字列みたいに可変長のステップを入れないといけない場合もある。
そういう場合には、dynamic_rnnやstatic_rnnを呼ぶときに、sequence_lengthを設定する。
seq_length = tf.placeholder(tf.int32, [None])
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32,
sequence_length=seq_length)
設定したsquence_lengthに満たない入力は0でpaddingして入力しなければならない
X_batch = np.array([ # step 0 step 1 [[0, 1, 2], [9, 8, 7]], # [[3, 4, 5], [0, 0, 0]], # 0でパディング [[6, 7, 8], [6, 5, 4]], # [[9, 0, 1], [3, 2, 1]], # ]) seq_length_batch = np.array([2, 1, 2, 2])
その他
いまBasicRNNCellをつかっているが、
cellはLSTMやGRUに入れ替えることも可能。