Kerasの再現性確保方法についてまとめます。主にKerasの公式ドキュメントを参考にしていますが、一部注意点があるので、その点を中心に解説します。
1.環境
GPU:GeForce GTX 1070
Python:3.7.3
Keras:2.4.3
tensorflow-gpu:2.4.0
再現性の確保の仕方に影響を与えるか分からないですが、分析モデルのOptimizer(最適化アルゴリズム)はadamを使いました。
2.公式ドキュメント(Keras Documentation)の記述
抜粋すると以下のような内容でした。
import numpy as np import tensorflow as tf import random as rn import os os.environ['PYTHONHASHSEED'] = '0' np.random.seed(42) rn.seed(12345) session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) from keras import backend as K tf.set_random_seed(1234) sess = tf.Session(graph=tf.get_default_graph(), config=session_conf) K.set_session(sess)
(参考:FAQ - Keras Documentation)
たしかに上記で再現性を確保できるのですが、少し注意する点があります。
3.注意点
注意点1:Tensorflowのバージョンの問題
Tensorflowバージョン2の場合は上記はサポートされておらず、バージョン1として実行してやる必要があります。そのまま実行すると「AttributeError: module 'tensorflow' has no attribute 'Session'」というエラーが出ます。
修正版のソースは以下の通りです。
import numpy as np import tensorflow as tf import random as rn import os os.environ['PYTHONHASHSEED'] = '0' np.random.seed(42) rn.seed(12345) session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) from keras import backend as K tf.compat.v1.set_random_seed(1234) sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf) tf.compat.v1.keras.set_session(sess)
注意点2:Tensorflowセッションを都度開始・終了する必要がある
上記ではセッション開始は定義されているのですが、終了が定義されていません。終了定義せずにJupyterNotebook等で分析部分のセルを繰り返し実行すると、結果が再現しなくなってしまいますので注意が必要です。原因はおそらく、セッションを終了せずに分析処理を2回以上実行すると、1回目の分析の値が残存し、2回目以降の分析に影響してしまうためと思われます。
終了定義を追記した修正版ソースは以下の通りです。※モデル部分はコメントアウトして割愛しています。
import numpy as np import tensorflow as tf import random as rn import os os.environ['PYTHONHASHSEED'] = '0' np.random.seed(42) rn.seed(12345) session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) #セッション開始 from keras import backend as K tf.compat.v1.set_random_seed(1234) sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf) tf.compat.v1.keras.set_session(sess) ''' 分析モデル ''' #セッション終了 K.clear_session()
【2021/3/21追記】
GPUを使用する場合、上記では不十分らしいです。以下にGPUを使用した場合の方法についてまとめました。
book-read-yoshi.hatenablog.com