Keras: CNN画像分類 (Pre-trained CNN Model)

Takami Torao Python 3.5 Keras 2.2 TensorFlow 1.8 #Keras #TensorFlow #VGG16
  • このエントリーをはてなブックマークに追加
Query Image
0%
43.9%: mushroom / キノコ
キノコ
32.6%: bolete / イグチ科のキノコ
イグチ科のキノコ
16.3%: agaric / ハラタケ科のキノコ
ハラタケ科のキノコ
3.2%: hen-of-the-woods / マイタケ
マイタケ
3.1%: stinkhorn / スッポンタケ科のキノコ
スッポンタケ科のキノコ

TODO: Python を常駐すればロード時間は省略可能。

目的

CNN モデルを使用した画像分類スクリプトを作成する。Keras は自分でニューラルネットワークを組み立てデータセットを用意してモデルを構築することもできるが、過去のコンペで優秀な成績を収めたいくつかの CNN がすぐに利用可能な形で含まれていて、推測部分にフォーカスして試すのであればそれらを利用するのが早い。

Keras にバンドルされている CNN はいずれも ImageNet のデータセットに基づいた 1,000 分類ラベルで学習済みである。もしこの分類ラベルに (運良く) 目的とする分類が含まれていれば利用可能な精度が出ているかをすぐに試すことができるだろう (分類ラベル一覧参照)。含まれていなかった場合でも、これらのモデルから目的の出力層とデータセットで転移学習を行うことで 1 から学習するよりはるかに効率的で高精度なモデルを作成することができるだろう。

ImageNet はスタンフォード大学が研究目的で収集/分類した画像処理のための大規模データセットである。物体検出や識別精度を競うコンテスト ILSVRCKaggle でもよく利用されている。ImageNet のデータセットを使用した ILSVRC 2012 の結果は CNN による高精度な分類で深層学習がブレイクするきっかけとなった。2018年6月時点で 1,420万画像 / 21,841分類が登録されている。

利用可能なモデル

Keras 2.2.0 ではオープンライセンスで利用可能な以下の CNN が用意されており、それぞれ ImageNet で学習済みのモデルが用意されている。バージョンが上がると新たなモデルも追加されるため最新の状況 Keras Document - Applications を参照。

Model Size Group License Description
VGG16 Oxford CC BY 4.0 2014年ILSVRC
VGG19 548MB Oxford CC BY 4.0 VGG16の畳み込み層の調整版
Xception 88M Google MIT Keras作者考案
ResNet50 Microsoft Research MIT 2014年ILSVRC 物体検知部門優勝
InceptionV3 Google Apache 2014年ILSVRC 分類部門優勝
InceptionResNetV2 214MB Apache
MobileNet 17MB Apache 軽量
DenseNet BSD 3-Cause
NASNet Apache
MobileNetV2 13MB Apache 軽量

これらの CNN に対して ImageNet で学習したモデルは Keras で使用する最初の一回目にダウンロードされ ~/.keras/models/ に hdf5 形式で保存されている。

実装

ここでの実装は ImageNet の検証データで正解率の高かった InceptionResNetV2 を使用している。API が統一されているため他の CNN モデルや自分で学習したモデルでもほぼ同じ実装でよい。

Python 3.5 環境で pip install keras tensorflow Pillow で必要なライブラリをインストールしておく。それぞれのバージョンは以下の通り。

> python
Python 3.5.4 (v3.5.4:3f56838, Aug  8 2017, 02:17:05) [MSC v.1900 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.]
>>> import keras, tensorflow, PIL
>>> keras.__version__
'2.2.0'
>>> tensorflow.__version__
'1.8.0'
>>> PIL.__version__
'5.1.0'

summary() メソッドを使用してモデルのレイヤー構成を表示することができる。例えば構造がシンプルな VGG16 のネットワークは以下のような 23 レイヤーで構成されている。

  1. 入力層: batch × 224[pixels] × 224[pixels] × 3[channels] の画像 (テンソル)
  2. 中間層×5 + 分類器×1
  3. 出力層: batch × 1,000[probs] の分類ラベルそれぞれの生起確率

InceptionResNetV2 はより複雑重厚で 782 レイヤーで構成されている。

>>> from keras.applications.inception_resnet_v2 import InceptionResNetV2
Using TensorFlow backend.
>>> model = InceptionResNetV2()
>>> model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 3 864         input_1[0][0]
__________________________________________________________________________________________________
...
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 1536)         0           conv_7b_ac[0][0]
__________________________________________________________________________________________________
predictions (Dense)             (None, 1000)         1537000     avg_pool[0][0]
==================================================================================================
Total params: 55,873,736
Trainable params: 55,813,192
Non-trainable params: 60,544
__________________________________________________________________________________________________

InceptionResNetV2 の入力層の縦x高さはデフォルト 299x299 であるため入力画像のサイズをそれに合わせている (縦横比が 1 でない場合の埋め方法は未調査)。またカラーチャンネルは RGB ではないため preprocess_input() でモデルに合わせたチャネルに変換する必要がある。初回の実行は 90MB を超えるモデルのダウンロードが行われるため少し時間がかかるだろう。

# -*- encoding: utf-8 -*-
import sys
import numpy as np
from keras.applications.inception_resnet_v2 import InceptionResNetV2, decode_predictions, preprocess_input
from keras.preprocessing.image import load_img, img_to_array

model = InceptionResNetV2()

def predict(files, top=10):
  # すべての画像を読み込み
  images = [load_img(f, target_size=(299, 299)) for f in files]   # 299x299 リサイズ済み RGB 画像の読み込み
  images = [img_to_array(image) for image in images]              # Pillow 形式を Numpy 配列に変換
  images = [preprocess_input(image) for image in images]          # RGB からモデル学習時の入力形式に変換
  images = np.stack(images)

  # すべての画像の分類ごとの生起確率を求める
  probs = model.predict(images)

  # 生起確率の配列を分類ラベルと対応させ上位 n 件を抽出する
  result = []
  for labeled_prob in decode_predictions(probs, top=top):
    # (id, label, probability) × top
    result.append(sorted(labeled_prob, key=lambda x: -x[2])[:top])
  return result

if __name__ == "__main__":
  files = sys.argv[1:]
  if len(files) > 0:
    for file, probs in zip(files, predict(files)):
      print("[%s]" % file)
      for (id, label, prob) in probs:
        print("  %.3f %s" % (prob, label))

実行結果

上記のスクリプトをいくつかのサンプル画像に実行し、確率の高い分類ラベルを表示する。

1. 自動車

推測結果は sports_car が 80% 近くであり、その他も車に関連するラベルが上位に来ていることから自動車の一種と認識されている事がわかる。

Sapmple 1 Mushroom
>python classifier.py vehicle.jpg
[vehicle.jpg]
  0.791 sports_car
  0.039 grille
  0.031 car_wheel
  0.031 racer
  0.021 convertible
  0.005 beach_wagon
  0.004 pickup
  0.002 cab
  0.002 passenger_car
  0.001 limousine
79.1%: sports_car / スポーツカー
スポーツカー
3.9%: grille / ラジエーターグリル
ラジエーターグリル
3.1%: car_wheel / ホイール
ホイール
3.1%: racer / レーシングカー
レーシングカー
2.1%: convertible / 折りたたみルーフ付き自動車
折りたたみルーフ付き自動車

2. タマゴテングタケ

タマゴテングタケそのものは分類ラベルにはないが、bolete が 80% を超え、それ以外もキノコの一種が上位に来ていることからキノコと認識されている事がわかる。

Sapmple 1 Mushroom
>python classifier.py mushroom.jpg
[mushroom.jpg]
  0.803 bolete
  0.105 mushroom
  0.024 agaric
  0.008 stinkhorn
  0.004 gyromitra
  0.003 hen-of-the-woods
  0.002 earthstar
  0.001 coral_fungus
  0.000 cheeseburger
  0.000 buckeye
80.3%: bolete / イグチ科のキノコ
イグチ科のキノコ
10.5%: mushroom / キノコ
キノコ
2.4%: agaric / ハラタケ科のキノコ
ハラタケ科のキノコ
0.8%: stinkhorn / スッポンタケ科のキノコ
スッポンタケ科のキノコ
0.4%: gyromitra / シャグマアミガサタケ属のキノコ
シャグマアミガサタケ属のキノコ

3. 仏像

仏像に対する分類は最上位でも 10% 程度である。これは、仏像やそれに類するブロンズ像のような分類ラベルは学習に使用したデータセットには含まれていないことから。性能の良い CNN は学習済みのラベルに対して強い確率を出す一方で、学習対象から外れた対象物に対して低い確率を示す (つまりエントロピーの高さと CNN の「自信」は逆相関する)。

Sapmple 2 Buddha
>python classifier.py buddha.jpg
[buddha.jpg]
  0.108 whiskey_jug
  0.061 pedestal
  0.056 chime
  0.052 saltshaker
  0.051 pitcher
  0.041 perfume
  0.040 brass
  0.033 altar
  0.030 goblet
  0.023 birdhouse
10.8%: whiskey_jug / ウイスキー甕
ウィスキー甕
6.1%: pedestal / 台座
台座
5.6%: chime / 管鐘
管鐘 5.2%: saltshaker / 塩入れ
塩入れ
5.1%: pitcher / 水差し
水差し

問題と対処

モデルの破損

モデルのダウンロード中にプロセスを強制終了するとローカルに不完全なキャッシュが残ってしまい以後の実行がエラーになることがある。

Traceback (most recent call last):
  File "C:\Users\Takami Torao\AppData\Local\Programs\Python\Python36\lib\site-packages\keras_applications\mobilenet.py", line 328, in MobileNet
    model.load_weights(weights_path)
  File "C:\Users\Takami Torao\AppData\Local\Programs\Python\Python36\lib\site-packages\keras\engine\network.py", line 1171, in load_weights
    with h5py.File(filepath, mode='r') as f:
  File "C:\Users\Takami Torao\AppData\Local\Programs\Python\Python36\lib\site-packages\h5py\_hl\files.py", line 312, in __init__
    fid = make_fid(name, mode, userblock_size, fapl, swmr=swmr)
  File "C:\Users\Takami Torao\AppData\Local\Programs\Python\Python36\lib\site-packages\h5py\_hl\files.py", line 142, in make_fid
    fid = h5f.open(name, flags, fapl=fapl)
  File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py\h5f.pyx", line 78, in h5py.h5f.open
OSError: Unable to open file (truncated file: eof = 40960, sblock->base_addr = 0, stored_eof = 17225924)

この場合 ~/.keras/models/ の該当モデルのファイルを削除し再実行することで回復することができる。

カラーモデルの変換

画像処理向けの一般的な CNN は色情報に 3 次元のカラーモデルを使用するが、学習時に使用したカラーモデルが RGB ではない事がある。実際、Keras に含まれている ImageNet のモデルは BGR である (BGR を標準とする OpenCV をどこかで使っているのかもしれない)。Pillow で読み込んだ画像は RGB カラーモデルであるため、モデルごとに用意されている preprocess_input() で適切なカラーモデルに変換する必要がある。

参照

  1. Keras Documenttion - Applications
  2. Francois Chollet (2018)PythonとKerasによるディープラーニング, マイナビ出版
  3. 太田満久, 須藤広大, 黒澤匠雅, 小田大輔 (2018) TensorFlow開発入門 Kerasによる深層学習モデル構築手法, 翔泳社
  4. Keras: VGG16 の分類ラベル一覧