CIFAR-10

Takami Torao Python 3.5 Keras 2.2 #CIFAR10 #CNN
  • このエントリーをはてなブックマークに追加

概要

CIFAR-10 は 10 種に分類された 32×32 の 60,000 画像からなるデータセット。80 Million Tiny Images から画像認識のために抽出/分類したサブセットである。

airplane
airplane
automobile
automobile
bird
bird
cat
cat
deer
deer
dog
dog
frog
frog
horse
horse
ship
ship
truck
truck

Keras を使って読み込む

Keras 2.2 は標準で用意されているデータセットの load_data() を利用するだけで CIFAR-10 を使用することができる。

from keras.datasets.cifar10 import load_data

(x_train, y_train), (x_test, y_test) = load_data()
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)

この返値はテンソルの形になっていて、32×32 サイズの RGB 画像が訓練用に 50,000 データ、検証用に 10,000 データ使用できる。

(50000, 32, 32, 3) (10000, 32, 32, 3) (50000, 1) (10000, 1)

ただし、0-9 の分類インデックスに対するラベルを参照する機能は用意されていないようなので、後述する Pickle を使って batches.meta を読み込むか、オンコードでリテラル ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"] を記述する。

Pickle を使って読み込む

自力で画像データをデコードする場合は CIFAR-10 サイトから CIFAR-10 python version をダウンロードする。より詳細に分類された CIFAR-100 python version もあるので目的に応じて使用すれば良い。

Python 版のデータセットは 1000 件ごとに 1 ファイルに納められている。それぞれ学習用のデータ data_batch_n とテスト用のデータ test_batch が対応する。アーカイブを解凍すると以下の内容の dict 構造を持った Pickle ファイルが得られる。

batches.meta
{
  "num_cases_per_batch": 10000,
  "label_names": [
    "airplane", "automobile",
    "bird", "cat", "deer",
    "dog", "frog", "horse",
    "ship", "truck"
  ],
  "num_vis": 3072
}
data_batch_[1-5], test_batch
{
  "batch_label": "training batch 1 of 5",
  "labels": [6, 9, 9, 4, 1, 1, 2, 7, ... ],
  "data": [[ 59  43  50 ... 140  84  72]
   [154 126 105 ... 139 142 144]
   [255 253 253 ...  83  83  84]
   ...
   [ 71  60  74 ...  68  69  68]
   [250 254 211 ... 215 255 254]
   [ 62  61  60 ... 130 130 131]],
  "filenames": [
    "leptodactylus_pentadactylus_s_000004.png",
    "camion_s_000148.png",
    "tipper_truck_s_001250.png",
    "american_elk_s_001521.png",
    ...
  ]
}

読み出しは以下のように行うことができるが、Python 2 との互換性のためか文字列がすべてバイト列型で保存されているため Python 3 から読み込むときは文字列型から変換する必要がある。

def unpickle(file):
  import pickle
  with open(file, "rb") as fo:
    dict = pickle.load(fo, encoding="bytes")
  return dict

データファイルの data フィールドに numpy 形式のデータが 3072 (32×32×3) × 10000 チャネル分保存されている。これはピクセルと RGB チャネルが flat 化されて 1 画像あたりの shape が (3072,) となったフォーマットである。以下の操作を行うことで flat 化された画像データを (32, 32, 3) に復元することができる。

dict[b"data"][i].reshape(3, 32, 32).transpose(1, 2, 0)

CIFAR-10 サイトから取得した CIFAR-10 python version を Python 3 で読み込む場合は文字列のキーを bytes リテラルで指定する必要がある。

# -*- encoding: utf-8 -*-

def unpickle(file):
  import pickle
  with open(file, 'rb') as fo:
    dict = pickle.load(fo, encoding='bytes')
  return dict

if __name__ == "__main__":
  import sys
  import matplotlib.pyplot as plt
  import random

  dict = unpickle(sys.argv[1])
  print("DATA SIZE: %d" % len(dict[b"data"]))
  print("DATA SHAPE: %s" % str(dict[b"data"][0].shape))

  # ランダムに選んだ画像を PyPlot で表示
  images = dict[b"data"]
  i = random.randrange(len(images))
  plt.imshow(images[i].reshape(3, 32, 32).transpose(1, 2, 0))
  plt.title("DATA[%d]" % i)
  plt.show()

実行結果:

$ python list.py cifar-10-batches-py/data_batch_1
DATA SIZE: 10000
DATA SHAPE: (3072,)

種類は不明だが小型の鳥類の画像に復元できたことが分かる。

data_batch_1[9047]
Fig. 1: data_batch1[9047]

参照

  1. CIFAR-10 (Wikipedia)
F