よっしーの私的空間

機械学習を中心に興味のあることについて更新します

TensorflowによるBiT(Big Transfer)の実装

2019年にGoogle Brainから発表された画像認識モデルBiT(Big Transfer)をファインチューニングする方法についてまとめます。BiTにかかわる解説は以下が良くまとまっていました。
パラメータ数10億!最新の巨大画像認識モデル「BiT」爆誕 & 解説 - Qiita

簡単に特徴をまとめると…

  • Efficientnetよりも様々なデータセットで良い精度を出している。
  • 超巨大なモデルで、膨大な画像を事前学習した脳筋モデル。
  • 前処理方法や学習時のパラメータ等がBiT Hyper-Ruleというルールで規定されており、BiT Hyper-Ruleに則って実装することでハイパーパラメータチューニングの手間が省ける。BiT Hyper-Ruleは実装に大きく関わるので重要

手っ取り早くソースコードを見たい方は本記事最下段の参考まで飛んでください。

1.BiT(Big Transfer)の概要

1.1.モデルの種類

事前学習データセットモデルの大きさの違いによって分類されています。事前学習データセットの種類は3、モデルの大きさの種類は5なので、3×5=15種類のモデルがあります。

① 事前学習データセットの違い

事前学習データセットの違いによって以下3種類に分かれます。

  • BiT-S(事前学習データ:ImageNet-1k)
  • BiT-M(事前学習データ:ImageNet-21k)
  • BiT-L(事前学習データ:JFT-300M)
② モデルの大きさの違い

BiTの構造はResNetをベースにしていて、以下5種類に分かれます。

  • R50x1(ResNet-50の幅を1倍)
  • R50x3(ResNet-50の幅を3倍)
  • R101x1(ResNet-101の幅を1倍)
  • R101x3(ResNet-101の幅を3倍)
  • R152x4(ResNet-152の幅を4倍)
公開されているモデル

BiT-S,MのR50x1~R152x4の計10モデルがTensorflow Hubで公開されています。

1.2.BiT Hyper-Rule

BiT Hyper-RuleとはBiTを実装するにあたって実施する前処理方法や学習時のパラメータを規定したもので、ハイパーパラメータチューニングの手間が省くためのものです。具体的には以下を規定しており、学習するデータセットの大きさに応じて決定するようです。

  • 入力画像の大きさ
  • MixUpの使用の有無
  • 学習ステップ数および学習率スケジュール
① 入力画像の大きさ

BiT Hyper-Ruleでは訓練するデータセットの大きさに応じて、以下の通りリサイズやクロップ(切り抜き)をするようです。なお、学習データはリサイズとクロップの両方を実施しますが、テストデータに対してはリサイズのみ実施します。

例えばCIFAR10の場合は入力画像の大きさは32×32pxなので、以下の通り学習データをResize⇒Cropします。

(画像はCifar10より)

② MixUpの使用の有無

学習データの数が20,000件より大きい場合にMixUpを使用します。
MixUpにかかわる解説は以下が良くまとまっています。
複数の画像を組み合わせるオーグメンテーション (mixup, CutMix) - け日記

③ 学習ステップ数および学習率スケジュール

学習ステップ数は学習画像数に応じて決定します。具体的には以下の通りです。

また、学習率は学習スケジュールの30%、60%、90%のタイミングで1/10にします。

2.BiTの実装方法(CIFAR10でファインチューニング)

BiTの実装方法についてまとめます。なお、実装にあたって以下のサンプルコードを参考にしていますが、サンプルコードはtf_flowersを使用しているのに対して、本記事ではCIFAR10を対象にファインチューニングをしています。そのため、画像の前処理の方法等を変更しています。また、個人的に使いやすいように色々変更しています。
big_transfer/big_transfer_tf2.ipynb at master · google-research/big_transfer · GitHub

2.1.CIFAR10の画像データをインポート

from tensorflow.keras.datasets import cifar10

# CIFAR-10のインポート
(x_train_load, y_train), (x_test_load, y_test) = cifar10.load_data()

2.2.入力画像のリサイズ

入力画像の大きさに従って画像をリサイズ・クロップします。CIFAR10の場合は32pxなので160pxにリサイズします。なお、クロップについてはデータ拡張(Data Augmentation)時に実施します。

import numpy as np
import cv2

# 変数定義
IMAGE_SIZE  = 32

# BiT Hyper-Ruleに則って画像サイズを定義
if IMAGE_SIZE <= 96:
    RESIZE_TO = 160
    CROP_TO = 128
else:
    RESIZE_TO = 512
    CROP_TO = 480

# 画像をリサイズ(拡大)する関数を定義
def upscale(image):
    size = len(image)
    data_upscaled = np.zeros((size, RESIZE_TO, RESIZE_TO, 3,))
    for i in range(len(image)):
        data_upscaled[i] = cv2.resize(image[i], dsize=(RESIZE_TO, RESIZE_TO), interpolation=cv2.INTER_CUBIC)
    image = np.array(data_upscaled, dtype=np.int)
    
    return image

# 画像リサイズ
x_train = upscale(x_train_load)
x_test  = upscale(x_test_load)

2.3.BiTモデルの定義

① 使用モデルについて

Tensorflow Hubで公開されているBiT-MのR50x1(ImageNet-21kで事前学習されたResNet-50同等規模のモデル)を使用。BiT-MでもBiT-Sでもパラメータの数は変わらないので、より性能の良いBiT-Mを使用することにしました。どうでもいいですが、BiT-Sの存在価値は良く分からないです…。
公式のサンプルコードではモデルを呼び出す際に「trainable=True」を指定していないですが、これを指定しないと重みが更新されずファインチューニングされないはずなので、「trainable=True」を指定しています。

bit_model = hub.KerasLayer("https://tfhub.dev/google/bit/m-r50x1/1", trainable=True)
② 学習率のスケジューリング

optimizerの定義とあわせて学習率のスケジュールを実装します。学習率は学習スケジュールの30%、60%、90%のタイミングで1/10にします。

BiTモデルの定義(ソースコード
import tensorflow_hub as hub
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import optimizers

# 変数定義
NUM_CLASSES = 10 # CIFAR10の予測対象クラス数(=10)
DATASET_SIZE = len(x_train_load) # CIFAR10の訓練用データ数は50k

# BiT Hyper-Ruleに則って学習ステップ数を定義
if DATASET_SIZE < 20000:
    SCHEDULE_LENGTH = 500
    SCHEDULE_BOUNDARIES = [200, 300, 400]
elif DATASET_SIZE < 500000:
    SCHEDULE_LENGTH = 10000
    SCHEDULE_BOUNDARIES = [3000, 6000, 9000]
else:
    SCHEDULE_LENGTH = 20000
    SCHEDULE_BOUNDARIES = [6000, 12000, 18000]

# BiTモデル定義
def buildModel_BiT():
    bit_model = hub.KerasLayer("https://tfhub.dev/google/bit/m-r50x1/1", trainable=True)
    model = tf.keras.Sequential([
        bit_model,
        tf.keras.layers.Dense(NUM_CLASSES, kernel_initializer='zeros', activation="softmax")
    ],
    name = 'BiT')
    
    lr = 0.003 * BATCH_SIZE / 512 
    
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=SCHEDULE_BOUNDARIES, 
                                                                   values=[lr, lr*0.1, lr*0.001, lr*0.0001])
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
    
    model.compile(optimizer=optimizer,
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    
    return model

2.4.訓練用関数の定義

学習用データを訓練データと評価データに分割し、データ拡張(Data Augmentation)をし、訓練を実行する関数を定義します。データ拡張の際にBiT Hyper-Ruleに則ってMixUpとCropをするのですが、デフォルトのImageDataGenerator(tensorflow.keras.preprocessing.image.ImageDataGenerator)ではMixUpやCropはできないので、ImageDataGeneratorを継承した独自ジェネレータを使用しています。独自ジェネレータの作成にあたってこちらを参考にしました。本章では独自ジェネレータのソースについては割愛しますが、ソースを見たい方は参考をご確認ください。

# 訓練用関数の定義
def train_BiT(X, y, STEPS_PER_EPOCH, SCHEDULE_LENGTH, BATCH_SIZE):
    # 訓練データと評価データの分割
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, stratify=y, shuffle=True)
    y_train = to_categorical(y_train)
    y_valid = to_categorical(y_valid)
    
    # Data Augmentation
    datagen = MyImageDataGenerator(horizontal_flip=True,
                                   mix_up_alpha=0.1, # データ数<20kの場合はMixUpは実施しない。
                                   random_crop=(CROP_TO, CROP_TO)
                                  )
    train_generator = datagen.flow(X_train, y_train,batch_size=BATCH_SIZE)
        
    # モデル構築
    model = buildModel_BiT()

    # 学習
    history = model.fit(train_generator,
                        steps_per_epoch=STEPS_PER_EPOCH,
                        epochs=10, #公式の推奨値はint(SCHEDULE_LENGTH/STEPS_PER_EPOCH)。時間がかかりすぎるので便宜上10を設定。
                        validation_data=(X_valid, y_valid),
                        shuffle=True
                       )
    
    return model, history

2.5.訓練実行

バッチサイズの公式推奨値は512なのですが、GPUメモリが足りずResourceExhaustedErrorになってしまうので、128まで下げています。(私の環境はGPUメモリを8GB積んでいるのですが、256だとエラーになっちゃいました。)

# 訓練開始
BATCH_SIZE = 128 #公式の推奨値は512。GPUメモリが足りないので仮で128を設定。
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE
STEPS_PER_EPOCH = 10

model, history = train_BiT(x_train, y_train, STEPS_PER_EPOCH, SCHEDULE_LENGTH, BATCH_SIZE)


参考:ソースコード全文

# 再現性確保
import os
os.environ['PYTHONHASHSEED'] = '0'
import tensorflow as tf
os.environ['TF_DETERMINISTIC_OPS'] = 'true'
os.environ['TF_CUDNN_DETERMINISTIC'] = 'true'

import numpy as np
import random as rn

SEED = 123
def reset_random_seeds():
    tf.random.set_seed(SEED)
    np.random.seed(SEED)
    rn.seed(SEED)

reset_random_seeds()    

session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=32, inter_op_parallelism_threads=32)
tf.compat.v1.set_random_seed(SEED)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)


# ライブラリインポート
import pandas as pd

import tensorflow_hub as hub

from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import cifar10
import cv2
import matplotlib.pyplot as plt

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import optimizers

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


# CIFAR-10のインポート
(x_train_load, y_train), (x_test_load, y_test) = cifar10.load_data()


# 変数定義
## CIFAR10の画像サイズ、データ数、予測対象クラス数を格納
IMAGE_SIZE  = 32
DATASET_SIZE = len(x_train_load) # CIFAR10の訓練用データ数は50k
NUM_CLASSES = 10

## BiT Hyper-Ruleに則って画像サイズを定義
if IMAGE_SIZE <= 96:
    RESIZE_TO = 160
    CROP_TO = 128
else:
    RESIZE_TO = 512
    CROP_TO = 480

## BiT Hyper-Ruleに則って学習ステップ数を定義
if DATASET_SIZE < 20000:
    SCHEDULE_LENGTH = 500
    SCHEDULE_BOUNDARIES = [200, 300, 400]
elif DATASET_SIZE < 500000:
    SCHEDULE_LENGTH = 10000
    SCHEDULE_BOUNDARIES = [3000, 6000, 9000]
else:
    SCHEDULE_LENGTH = 20000
    SCHEDULE_BOUNDARIES = [6000, 12000, 18000]


# 画像をリサイズ(拡大)する関数を定義
def upscale(image):
    size = len(image)
    data_upscaled = np.zeros((size, RESIZE_TO, RESIZE_TO, 3,))
    for i in range(len(image)):
        data_upscaled[i] = cv2.resize(image[i], dsize=(RESIZE_TO, RESIZE_TO), interpolation=cv2.INTER_CUBIC)
    image = np.array(data_upscaled, dtype=np.int)
    
    return image


# 画像リサイズ
x_train = upscale(x_train_load)
x_test  = upscale(x_test_load)
# データ正規化
# BiTでは0~1で表現された画像を使用
x_train  = np.array(x_train/255, dtype=np.float32)
x_test  = np.array(x_test/255, dtype=np.float32)

# リサイズ前後の画像を比較
plt.subplot(121).imshow(x_train_load[0])
plt.subplot(122).imshow(x_train[0])
plt.show()
print("※左図:リサイズ前(32*32)、右図:リサイズ後(160*160)")


# ImageDataGeneratorを継承してMix-upやRandom Croppingのできる独自ジェネレーターを定義
# 参考 https://qiita.com/koshian2/items/909360f50e3dd5922f32
class MyImageDataGenerator(ImageDataGenerator):
    def __init__(self, featurewise_center = False, samplewise_center = False, 
                 featurewise_std_normalization = False, samplewise_std_normalization = False, 
                 zca_whitening = False, zca_epsilon = 1e-06, rotation_range = 0.0, width_shift_range = 0.0, 
                 height_shift_range = 0.0, brightness_range = None, shear_range = 0.0, zoom_range = 0.0, 
                 channel_shift_range = 0.0, fill_mode = 'nearest', cval = 0.0, horizontal_flip = False, 
                 vertical_flip = False, rescale = None, preprocessing_function = None, data_format = None, validation_split = 0.0, 
                 random_crop = None, mix_up_alpha = 0.0):
        # 親クラスのコンストラクタ
        super().__init__(featurewise_center, samplewise_center, featurewise_std_normalization, samplewise_std_normalization, zca_whitening, zca_epsilon, rotation_range, width_shift_range, height_shift_range, brightness_range, shear_range, zoom_range, channel_shift_range, fill_mode, cval, horizontal_flip, vertical_flip, rescale, preprocessing_function, data_format, validation_split)
        # 拡張処理のパラメーター
        # Mix-up
        assert mix_up_alpha >= 0.0
        self.mix_up_alpha = mix_up_alpha
        # Random Crop
        assert random_crop == None or len(random_crop) == 2
        self.random_crop_size = random_crop

    # ランダムクロップ
    # 参考 https://jkjung-avt.github.io/keras-image-cropping/
    def random_crop(self, original_img):
        # Note: image_data_format is 'channel_last'
        assert original_img.shape[2] == 3
        if original_img.shape[0] < self.random_crop_size[0] or original_img.shape[1] < self.random_crop_size[1]:
            raise ValueError(f"Invalid random_crop_size : original = {original_img.shape}, crop_size = {self.random_crop_size}")

        height, width = original_img.shape[0], original_img.shape[1]
        dy, dx = self.random_crop_size
        x = np.random.randint(0, width - dx + 1)
        y = np.random.randint(0, height - dy + 1)
        return original_img[y:(y+dy), x:(x+dx), :]

    # Mix-up
    # 参考 https://qiita.com/yu4u/items/70aa007346ec73b7ff05
    def mix_up(self, X1, y1, X2, y2):
        assert X1.shape[0] == y1.shape[0] == X2.shape[0] == y2.shape[0]
        batch_size = X1.shape[0]
        l = np.random.beta(self.mix_up_alpha, self.mix_up_alpha, batch_size)
        X_l = l.reshape(batch_size, 1, 1, 1)
        y_l = l.reshape(batch_size, 1)
        X = X1 * X_l + X2 * (1-X_l)
        y = y1 * y_l + y2 * (1-y_l)
        return X, y

    def flow(self, x, y=None, batch_size=32, shuffle=True, sample_weight=None,
             seed=None, save_to_dir=None, save_prefix='', save_format='png', subset=None):
        batches = super().flow(x=x, y=y, batch_size=batch_size, shuffle=shuffle, sample_weight=sample_weight,
                               seed=seed, save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format, subset=subset)
        # 拡張処理
        while True:
            batch_x, batch_y = next(batches) 
            if self.mix_up_alpha > 0:
                while True:
                    batch_x_2, batch_y_2 = next(batches)
                    m1, m2 = batch_x.shape[0], batch_x_2.shape[0]
                    if m1 < m2:
                        batch_x_2 = batch_x_2[:m1]
                        batch_y_2 = batch_y_2[:m1]
                        break
                    elif m1 == m2:
                        break
                batch_x, batch_y = self.mix_up(batch_x, batch_y, batch_x_2, batch_y_2)
            # Random crop
            if self.random_crop_size != None:
                x = np.zeros((batch_x.shape[0], self.random_crop_size[0], self.random_crop_size[1], 3))
                for i in range(batch_x.shape[0]):
                    x[i] = self.random_crop(batch_x[i])
                batch_x = x
            # 返り値
            yield (batch_x, batch_y)


# BiTモデル定義
def buildModel_BiT():
    bit_model = hub.KerasLayer("https://tfhub.dev/google/bit/m-r50x1/1", trainable=True)
    model = tf.keras.Sequential([
        bit_model,
        tf.keras.layers.Dense(NUM_CLASSES, kernel_initializer='zeros', activation="softmax")
    ],
    name = 'BiT')
    
    lr = 0.003 * BATCH_SIZE / 512 
    
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=SCHEDULE_BOUNDARIES, 
                                                                   values=[lr, lr*0.1, lr*0.001, lr*0.0001])
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
    
    model.compile(optimizer=optimizer,
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    
    return model

# 訓練用関数の定義
def train_BiT(X, y, STEPS_PER_EPOCH, SCHEDULE_LENGTH, BATCH_SIZE):
    # 訓練データと評価データの分割
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, stratify=y, shuffle=True)
    y_train = to_categorical(y_train)
    y_valid = to_categorical(y_valid)
    
    # Data Augmentation
    datagen = MyImageDataGenerator(horizontal_flip=True,
                                   mix_up_alpha=0.1, # データ数<20kの場合はMixUpは実施しない。
                                   random_crop=(CROP_TO, CROP_TO)
                                  )
    train_generator = datagen.flow(X_train, y_train,batch_size=BATCH_SIZE)
        
    # モデル構築
    model = buildModel_BiT()

    # 学習
    history = model.fit(train_generator,
                        steps_per_epoch=STEPS_PER_EPOCH,
                        epochs=10, #公式の推奨値はint(SCHEDULE_LENGTH/STEPS_PER_EPOCH)。時間がかかりすぎるので便宜上10を設定。
                        validation_data=(X_valid, y_valid),
                        shuffle=True
                       )
    
    return model, history

# 訓練開始
BATCH_SIZE = 128 #公式の推奨値は512。GPUメモリが足りないので仮で128を設定。
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE
STEPS_PER_EPOCH = 10

model, history = train_BiT(x_train, y_train, STEPS_PER_EPOCH, SCHEDULE_LENGTH, BATCH_SIZE)


# 予測
X = x_test
pred = model.predict(X)

# 予測結果の確認
df_pred = pd.DataFrame(pred)
pred = np.array(df_pred.idxmax(axis=1))
df_pred = pd.DataFrame(pred)
df_y = pd.DataFrame(y_test)
df_result = pd.concat([df_y, df_pred], axis=1, join_axes=[df_y.index])
df_result.columns = ['y','pred']
display(df_result)

# 予測結果の評価(混合行列、Accuracy、Precision、Recall、F_score)
print('Confusion Matrix:')
print(confusion_matrix(df_result['y'],df_result['pred']))
print()
print('Accuracy :{:.4f}'.format(accuracy_score(df_result['y'],df_result['pred'])))
print('Precision:{:.4f}'.format(precision_score(df_result['y'],df_result['pred'],average='macro')))
print('Recall   :{:.4f}'.format(recall_score(df_result['y'],df_result['pred'],average='macro')))
print('F_score  :{:.4f}'.format(f1_score(df_result['y'],df_result['pred'],average='macro')))