UNISIA-SE Tech Blog

気まぐれお勉強日記

[Python] [9] ミニバッチ学習 (交差エントロピー誤差) の実装サンプル

1. 前提条件


このページでは、交差エントロピー誤差のミニバッチ学習についての簡単な実装サンプルを記載する。

以下、必要な前提知識。

▼ 下記ページを理解していること。
[Python] [8] 損失関数 (2 乗和誤差、交差エントロピー誤差) と実装サンプル

▼ Python3.6、NumPyがインストールされていること。
このページでは、venvの仮想環境(Python3.6)上にNumPyをインストールした環境で、Python対話モード(Pythonインタプリタ)にて実装サンプルを記載している。

※ Python対話モード、NumPyについては下記を参考。
Python対話モード:[Python] 対話モード (インタプリタ) の使用方法
NumPy:[Python] [NumPy] インストールとnumpy.ndarrayの使用方法

2. ミニバッチ学習とは


機械学習では、下記 MNIST(※) のような訓練データすべて (学習用データセット 60,000枚) を対象に損失関数を求める必要がある。

60,000枚くらいであれば、それほど問題にならないが、ビッグデータともなれば 数千万のデータ となり、すべて求めると処理時間もサーバー負荷も膨大となり現実的でない。

また、高負荷の割には、100件くらいのランダム抽出した結果と大して変わらず機械学習では、数千万データの近似値として十分有効である。
この学習方法を機械学習の分野では、ミニバッチ学習 と呼び、テレビの視聴率計測など一般的に広く使用されてる。

3. 交差エントロピー誤差のミニバッチ学習 (定義)


下記 (A) は、前ページ (※) でMNISTを例に解説した 交差エントロピー誤差 の定義。

\[ {\normalsize E = -\sum_{i=1}^{n} t_{k} \log \ y_{k}…(A) } \]
\(t_{k}\) : 訓練データ
\(y_{k}\) : ニューラルネットワークの出力
\(k\) : データの次元数

これは、一つのデータ(数字 0 ~ 9 のいずれか)に対して、ニューラルネットワークの出力が 10個の配列(正解予想) と、訓練データの出力が10個の配列(正解が1 、不正解が0) となる損失関数を表している。

これをすべてのデータに対して実施し、その和を表現すると下記 (B) の定義となる。
\[ {\normalsize E = -\frac{1}{N}\sum_{i=1}^{n}\sum_{j=1}^{k} t_{nk} \log \ y_{nk}…(B) } \]
\(N\) : データの個数
※ MNISTの場合、学習用データセットの60,000個。一つあたりの損失平均となるようにNで割る。

\(k\) : データの次元数
※ MNISTの場合、訓練データの種類(数字 0 ~ 9 に対応する10個)

\(t_{nk}\) : 訓練データである\(t_{k}\) が \(N\)個分。
※ MNISTの場合、学習用ラベルデータセットの60,000個。

\(y_{nk}\) : ニューラルネットワークの出力である \(y_{k}\) が \(N\)個分 。
※ MNISTの場合、学習用ラベルデータセットの60,000個。

4. 交差エントロピー誤差のミニバッチ学習 (MNISTの準備)


次にMNISTを使った ミニバッチ学習の準備データの内容 について解説する。

MNIST概要やリポジトリクローンについては下記を参考。

MNISTの 学習用データセット検証用データセット をダウンロードする。

$ cd gitlocalrep    # ローカルのGitリポジトリに移動
$ cd deep-learning-from-scratch/ch03    # Git (deep-learning-from-scratch) のカレントディレクトに移動に移動
$ python
 >>> import sys, os
 >>> sys.path.append(os.pardir)
 >>> import numpy as np
 >>> from dataset.mnist import load_mnist
 >>>
 >>> (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
 >>>
 >>> print(x_train.shape)     # 詳細は下記 ※ 2 に記載
 (60000, 784)
 >>> print(t_train.shape)     # 詳細は下記 ※ 3 に記載
 (60000, 10)
 >>>

▼ 補足

※ 1 load_mnist関数の引数
引数 normalize は、入力画像を 0.0 ~ 1.0 に 正規化するかどうか をBool 値で設定する。
Falseの場合、入力画像のピクセルは 0 ~ 255 となる。

引数 flatten は、入力画像を 1次元にするかどうか をBool 値で設定する。
Falseの場合、入力画像は1 * 28 * 28 の3次元配列として格納され、Trueの場合、1次元配列(要素:784)として格納される。

引数 one_hot_labelは、ラベルを one_hot 表現で格納するかどうか をBool 値で設定する。
one_hot 表現の場合は、正解となるラベルのみ 1 でそれ以外は、0 の配列となる。

戻り値は、(訓練画像、訓練ラベル), (テスト画像, テストラベル) の形式でMNISTデータを返する。

※ 2 x_train.shape (形状)
784列(= 28 × 28)の画像データが学習用データセット数の 60,000枚 あることを表している。

※ 3 t_train.shape (形状)
10列(正解となるラベルのみ 1 でそれ以外は、0 の配列)の教師データが学習用データセット数の 60,000個 あることを表している。

5. 交差エントロピー誤差のミニバッチ学習 (Python実装サンプル)


最後に上記で準備したMNISTのデータセットを使い ミニバッチ学習のPython実装サンプル について解説する。

MNISTの学習用画像データセット (60,000枚) の中から 100枚 抜出して、交差エントロピー誤差の損失関数 を求めるサンプル。

まず、前準備として下記ページの 2. 推論バッチ処理の実行準備 で解説した ch03/neuralnet_mnist_batch.py の init_network() と predict(network, x) を定義する。


$ cd gitlocalrep    # ローカルのGitリポジトリに移動
$ cd deep-learning-from-scratch/ch03    # Git (deep-learning-from-scratch) のカレントディレクトに移動に移動
$ python
 >>> import sys, os
 >>> sys.path.append(os.pardir)
 >>> import numpy as np
 >>> import pickle
 >>> from dataset.mnist import load_mnist
 >>> from common.functions import sigmoid, softmax
 >>>
 >>> def init_network():
 ...     with open("sample_weight.pkl", 'rb') as f:
 ...         network = pickle.load(f)
 ...     return network
 ...
 >>> def predict(network, x):
 ...     W1, W2, W3 = network['W1'], network['W2'], network['W3']
 ...     b1, b2, b3 = network['b1'], network['b2'], network['b3']
 ...     a1 = np.dot(x, W1) + b1
 ...     z1 = sigmoid(a1)
 ...     a2 = np.dot(z1, W2) + b2
 ...     z2 = sigmoid(a2)
 ...     a3 = np.dot(z2, W3) + b3
 ...     y = softmax(a3)
 ...     return y
 ...
 >>>

そして、交差エントロピー誤差のミニバッチ学習を定義。

 >>> def cross_entropy_error(y, t):
 ...     if y.ndim == 1:    # 次元が 1 の場合
 ...         t = t.reshape(1, t.size)
 ...         y = y.reshape(1, y.size)
 ...     batch_size = y.shape[0]
 ...     return -np.sum(t * np.log(y + 1e-7)) / batch_size
 ...
 >>>

y は、ニューラルネットワーク(推論バッチ処理) の出力となり、以降の解説で引数 y に predict(network, x_batch) の戻り値を設定する。

次に、MNISTの学習用データセットと検証用データセットをダウンロードする。

 >>> (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
 >>>

次に np.random.choice を使用しランダムで「100枚」抽出する。

 >>> train_size = x_train.shape[0]
 >>> batch_size = 100
 >>> batch_mask = np.random.choice(train_size, batch_size)
 >>> x_batch = x_train[batch_mask]
 >>> t_batch = t_train[batch_mask]
 >>>
 >>> print(batch_mask)    # np.random.choiceの結果
 [ 2759 48331 20881 29315 30035 55711 47969  1338 54067 23424 14789  9722
 38601 10138 24036 23811   284 43467 41042 39683 49572 20247 29728 23176
 50987  4855 43468  7179  2815 29033 46578 25623 41615 34833 12651 35969
 51498 34685 30303 57205 16641 39057 45010 35152 19620 34228 55637 44070
 25063 14112 45717 32403 32209 26388 27572 53492 46367 15161 38462 26947
 30193 45931 25658 24854 33528 41892 55989 32053 43699 22615 42090  3430
  1568 57173 35969 11839 26384 16123 31217 30323 46844 37015 28731 46525
 15412 19736 16773 12655 37365 52095 11550 46947 34077 31528  9691 44021
  6473 41599  7001  4999]
 >>>

次に100枚 抜き出したニューラルネットワーク(推論バッチ処理)の出力結果 を y_batch に取得する。

 >>> network = init_network()
 >>> y_batch = predict(network, x_batch)
 >>>

そして、最後にニューラルネットワーク(推論バッチ処理)の出力 y_batcht_batch を引数に交差エントロピー誤差を求める。

 >>> cross_entropy_error(y_batch, t_batch)
 0.20627920610480943
 >>>

100枚のミニバッチ学習結果は、約 0.2 という結果なった。

ちなみに上記は load_mnistで 引数「one_hot_label=True」を指定した one_hot表現 のミニバッチ処理ですが、one_hot表現でなくラベルのデータセットをダウンロードした場合は、下記の実装となる。

 >>> def cross_entropy_error(y, t):
 ...     if y.ndim == 1:
 ...         t = t.reshape(1, t.size)
 ...         y = y.reshape(1, y.size)
 ...     batch_size = y.shape[0]
 ...     return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
 ...
 >>>

one_hot表現では、t = 0 のデータはすべて 0 になるが、ラベルデータとなると全てデータが対象となる。

▼ 引数 one_hot_label に関する差異
one_hot_label=True の時

...     return -np.sum(t * np.log(y + 1e-7)) / batch_size

one_hot_label=False の時

...     return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size

以上。


【参考文献】
斎藤 康毅 (2018) 『ゼロから作るDeep Learning - Pythonで学ぶディープラーニングの理論と実装』株式会社オライリー・ジャパン


Copyright UNISIA-SE All Rights Reserved.
s-hama@unisia-se.jp