よっしーの私的空間

機械学習を中心に興味のあることについて更新します

ViTとEfficientnetをCIFAR-10で試してみた

画像分類モデルには色々なものがありますが、個人的にはViT(Vision Transformer)とEfficientnetが気になってます。これらのモデルを実際に動かしてみて、速度や精度等を比較してみたいと思います。ViTとEfficientnetについては以下でまとめましたので、良ければ参考にしてください。
TensorflowによるEfficientNetの実装 - よっしーの私的空間
TensorflowによるViT(Vision Transformer)の実装 - よっしーの私的空間

1. 使用データ

CIFAR-10というデータを使用しました。10種類のカラー画像データで、大きさは32×32ピクセルです。訓練データが50,000データ、テストデータが10,000データ用意されています。
f:id:t-yoshi-book:20210327165147p:plain
(引用:CIFAR-10 and CIFAR-100 datasets

以下のコードで画像データをインポートします。

# CIFAR-10のインポート
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 比較結果

ViTとEfficientnetをfinetuneして速度や精度等を測りました。結果は以下の通りです。
f:id:t-yoshi-book:20210327174054p:plain

全体的にEfficientnetの方がViTよりも良さそうです。おそらくですが、画像サイズが小さすぎるせいなのではないかと考えます。ViTでは画像をパッチに分割して分析しますが、画像サイズ≒パッチサイズになってしまっているので、画像が分割されずViTの効果が十分に発揮されなかったのではないかと推測します。今度は大きめの画像でも試してみようと思います。

3. 実行コード

# 再現性確保
import os
os.environ['PYTHONHASHSEED'] = '0'
import tensorflow as tf

import numpy as np
import random as rn

SEED = 123
def reset_random_seeds():
    tf.random.set_seed(SEED)
    np.random.seed(SEED)
    rn.seed(SEED)

reset_random_seeds()

# ライブラリインポート
import pandas as pd

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D,BatchNormalization,Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from efficientnet.tfkeras import EfficientNetB0 #使用するモデルにあわせて変更する(B0~B7) 
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import tensorflow_addons as tfa

from vit_keras import vit, utils

from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import cifar10
import datetime

# CIFAR-10のインポート
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# 変数定義
image_size = 32   #画像のサイズ(CIFAR-10は32×32)
input_shape=(image_size,image_size,3)
num_classes = 10 #画像の種類の数(CIFAR-10は10種類)

# モデル定義
## Efficientnetモデルの定義(ファインチューニング)
def buildModel():
    model = Sequential()
    model.add(EfficientNetB0(    #使用するモデルにあわせて変更する(B0~B7) 
        include_top=False,
        weights='imagenet',
        input_shape=input_shape))
    model.add(GlobalAveragePooling2D())
    model.add(Dense(num_classes, activation="softmax"))
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-4), loss="categorical_crossentropy", metrics=["accuracy"])
    
    return model

## ViTモデルの定義(ファインチューニング)
def buildModel_ViT():
    vit_model = vit.vit_b16(    #使用するモデルにあわせて変更する(b16~l32) 
        image_size = image_size,
        activation = 'sigmoid',
        pretrained = True,
        include_top = False,
        pretrained_top = False)
    model = tf.keras.Sequential([
        vit_model,
        tf.keras.layers.Flatten(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(21, activation = tfa.activations.gelu),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(num_classes, 'softmax')
    ],
    name = 'vision_transformer')
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-4), loss="categorical_crossentropy", metrics=["accuracy"])
    
    return model

# 訓練用関数の定義
## efficientnet用訓練関数
def train_efficientnet(X, y, steps_per_epoch, epochs, batch_size, callbacks):
    # 訓練データと評価データの分割
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, stratify=y, shuffle=True)
    y_train = to_categorical(y_train)
    y_valid = to_categorical(y_valid)

    # Data Augmentation
    datagen = ImageDataGenerator(rotation_range=20, horizontal_flip=True, width_shift_range=0.2, zoom_range=0.2)
    train_generator = datagen.flow(X_train, y_train,batch_size=batch_size)

    # モデル構築
    model = buildModel()

    # 学習
    history = model.fit_generator(train_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        validation_data=(X_valid, y_valid),
                        callbacks=callbacks,
                        shuffle=True
                       )
    
    return model, history

## ViT用訓練関数
def train_vit(X, y, steps_per_epoch, epochs, batch_size, callbacks):
    # 訓練データと評価データの分割
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, stratify=y, shuffle=True)
    y_train = to_categorical(y_train)
    y_valid = to_categorical(y_valid)
    
    # Data Augmentation
    datagen = ImageDataGenerator(rotation_range=20, horizontal_flip=True, width_shift_range=0.2, zoom_range=0.2)
    train_generator = datagen.flow(X_train, y_train,batch_size=batch_size)
        
    # モデル構築
    model = buildModel_ViT()

    # 学習
    history = model.fit_generator(train_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        validation_data=(X_valid, y_valid),
                        callbacks=callbacks,
                        shuffle=True
                       )
    
    return model, history

# 訓練開始
steps_per_epoch = 1250
epochs = 1000
batch_size = 32

reduce_lr = ReduceLROnPlateau(monitor='val_accuracy',factor=0.2,patience=2,verbose=1,
                              min_delta=1e-4,min_lr=1e-6,mode='max')
earlystopping = EarlyStopping(monitor='val_accuracy', min_delta=1e-4, patience=5,
                                                 mode='max', verbose=1)
callbacks = [earlystopping, reduce_lr]

print('開始時間:',datetime.datetime.now())
# 以下はvitの場合。Efficientnetの場合はtrain_efficientnet()を呼び出す。
model, history = train_vit(x_train, y_train, steps_per_epoch, epochs, batch_size, callbacks)
print('終了時間:',datetime.datetime.now())

# 予測
preds = []
X = x_test
pred = model.predict(X)

# 予測結果の確認
df_pred = pd.DataFrame(pred)
pred = np.array(df_pred.idxmax(axis=1))
df_pred = pd.DataFrame(pred)
df_y = pd.DataFrame(y_test)
df_result = pd.concat([df_y, df_pred], axis=1, join_axes=[df_y.index])
df_result.columns = ['y','pred']
print(df_result)

# 予測結果の評価(混合行列、Accuracy、Precision、Recall、F_score)
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
print('Confusion Matrix:')
print(confusion_matrix(df_result['y'],df_result['pred']))
print('Accuracy :{:.4f}'.format(accuracy_score(df_result['y'],df_result['pred'])))
print('Precision:{:.4f}'.format(precision_score(df_result['y'],df_result['pred'],average='macro')))
print('Recall   :{:.4f}'.format(recall_score(df_result['y'],df_result['pred'],average='macro')))
print('F_score  :{:.4f}'.format(f1_score(df_result['y'],df_result['pred'],average='macro')))