よっしーの私的空間

機械学習を中心に興味のあることについて更新します

Keras分析結果の再現性確保について

Kerasの再現性確保方法についてまとめます。主にKerasの公式ドキュメントを参考にしていますが、一部注意点があるので、その点を中心に解説します。

1.環境

GPU:GeForce GTX 1070
Python:3.7.3
Keras:2.4.3
tensorflow-gpu:2.4.0

再現性の確保の仕方に影響を与えるか分からないですが、分析モデルのOptimizer(最適化アルゴリズム)はadamを使いました。

2.公式ドキュメント(Keras Documentation)の記述

抜粋すると以下のような内容でした。

import numpy as np
import tensorflow as tf
import random as rn

import os
os.environ['PYTHONHASHSEED'] = '0'

np.random.seed(42)
rn.seed(12345)

session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)

from keras import backend as K
tf.set_random_seed(1234)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)

(参考:FAQ - Keras Documentation
たしかに上記で再現性を確保できるのですが、少し注意する点があります。

3.注意点

注意点1:Tensorflowのバージョンの問題

Tensorflowバージョン2の場合は上記はサポートされておらず、バージョン1として実行してやる必要があります。そのまま実行すると「AttributeError: module 'tensorflow' has no attribute 'Session'」というエラーが出ます。
修正版のソースは以下の通りです。

import numpy as np
import tensorflow as tf
import random as rn

import os
os.environ['PYTHONHASHSEED'] = '0'

np.random.seed(42)
rn.seed(12345)

session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)

from keras import backend as K
tf.compat.v1.set_random_seed(1234)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.set_session(sess)

注意点2:Tensorflowセッションを都度開始・終了する必要がある

上記ではセッション開始は定義されているのですが、終了が定義されていません。終了定義せずにJupyterNotebook等で分析部分のセルを繰り返し実行すると、結果が再現しなくなってしまいますので注意が必要です。原因はおそらく、セッションを終了せずに分析処理を2回以上実行すると、1回目の分析の値が残存し、2回目以降の分析に影響してしまうためと思われます。
終了定義を追記した修正版ソースは以下の通りです。※モデル部分はコメントアウトして割愛しています。

import numpy as np
import tensorflow as tf
import random as rn

import os
os.environ['PYTHONHASHSEED'] = '0'

np.random.seed(42)
rn.seed(12345)

session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)

#セッション開始
from keras import backend as K
tf.compat.v1.set_random_seed(1234)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.set_session(sess)

'''
分析モデル
'''

#セッション終了
K.clear_session() 
【2021/3/21追記】

GPUを使用する場合、上記では不十分らしいです。以下にGPUを使用した場合の方法についてまとめました。
book-read-yoshi.hatenablog.com