tensorflow が遅くなる話

masm11 です。

最近、tensorflow を使って機械学習しています。 tensorflow がだんだん遅くなることがあって、気づいたことがあるので、書いてみます。 ただし、以下は私の想像であることをはじめにお断りしておきます。

まず、

a = tf.Variable(...)
b = tf.placeholder(...)

を実行します。この時、以下のように、2つのオブジェクトができて、 a, b がそれぞれのオブジェクトを指します。

f:id:masm11:20180404104135p:plain

次に、

c = a * b

を実行します。この時点では、実際の掛け算は行われません。掛け算を表すオブジェクトから a, b への参照ができます。以下のようになります。

f:id:masm11:20180404104150p:plain

そして、placeholder に値を設定して c を評価します。

d = c.eval(session=sess, feed_dict={b:...})

この時、ようやく実際に掛け算が行われます。

機械学習をしている時、epoch ごとに精度を評価することもあると思います。

for epoch in range(epochs):
    # 学習する
    # ...
    # 評価する
    d = c.eval(session=sess, feed_dict={b:...})

これは問題はありません。

ここで、以下のように書き換えてみます。

for epoch in range(epochs):
    # 学習する
    # ...
    # 評価する
    c = a * b
    d = c.eval(session=sess, feed_dict={b:...})

一見、変数 c を直前に作っているだけで、何も問題はないように思えます。

しかし、

c = a * b

この部分は、実際に掛け算を行わず、掛け算を表すオブジェクトを作るので、 以下のようにたくさんのオブジェクトが作られてしまいます。

f:id:masm11:20180404104206p:plain

さて、ここまで、コードのイメージでしか説明しませんでしたので、 実際のコードで時間を測定してみます。

まずは、ループの前で c を作った場合:

$ cat fast.py
#!/usr/bin/env python

import tensorflow as tf

a = tf.constant(2, dtype=tf.float32)
b = tf.constant(3, dtype=tf.float32)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

c = a * b
for i in range(1000):
    c.eval(session=sess)
$ for i in `seq 5`; do time python fast.py; done

real    0m1.092s
user    0m1.092s
sys 0m0.200s

real    0m1.110s
user    0m1.104s
sys 0m0.208s

real    0m1.108s
user    0m1.116s
sys 0m0.196s

real    0m1.106s
user    0m1.100s
sys 0m0.208s

real    0m1.094s
user    0m1.080s
sys 0m0.216s
$ 

次に、ループの中で c を作った場合です:

$ cat slow.py
#!/usr/bin/env python

import tensorflow as tf

a = tf.constant(2, dtype=tf.float32)
b = tf.constant(3, dtype=tf.float32)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(1000):
    c = a * b
    c.eval(session=sess)
$ for i in `seq 5`; do time python slow.py; done

real    0m7.812s
user    0m7.808s
sys 0m0.196s

real    0m7.915s
user    0m7.896s
sys 0m0.208s

real    0m7.834s
user    0m7.824s
sys 0m0.200s

real    0m7.841s
user    0m7.812s
sys 0m0.224s

real    0m7.852s
user    0m7.812s
sys 0m0.232s
$ 

7倍も時間がかかっています。

以上の結果から、tensorflow のオブジェクト(式)は使い捨てにしない方が良さそうです。