UNISIA-SE Tech Blog

気まぐれお勉強日記

[Python] [13] 重みに対する損失関数の勾配法と実装サンプル

1. 前提条件


このページでは、重みに対する勾配法(勾配降下法) の実装サンプルを記載する。

以下、必要な前提知識。

▼ 下記ページを理解していること。
[Python] [10] 損失関数と数値微分の実装サンプル
[Python] [11] 偏微分と勾配の実装サンプル
[Python] [12] 勾配降下法の実装サンプル

▼ Python3.6、NumPyがインストールされていること。
このページでは、venvの仮想環境(Python3.6)上にNumPyをインストールした環境で、Python対話モード(Pythonインタプリタ)にて実装サンプルを記載している。

※ Python対話モード、NumPyについては下記を参考。
Python対話モード:[Python] 対話モード (インタプリタ) の使用方法
NumPy:[Python] [NumPy] インストールとnumpy.ndarrayの使用方法

2. 重みに対する勾配法とは


重み は、正解に対して 入力値 がどれだけ影響するかを示す 重要度 を表す。

重み という言葉からは全くイメージできない意味を持つが、これは英語の Weigh を直訳し、重み となっているからであり、本来は 大切さ価値重要性 という意味で命名されている。

そして、この 重み(重要度) は、一般的に Weigh の \(w\) を取り、\(w_{0}\)、\(w_{1}\)、\(w_{2}\) …で表現され 評価的に具体的な数値を入れてみて損失結果を計る ためのパラメータとなる。

ニューラルネットワークでは、この 重み に対して、下記前ページで解説した 勾配降下法 を実施し、最適な 重み を求めていく。

また、下記に示す行列のようにいろんなパターンの 重み を行列計算で一斉に 偏微分 し、最適な 重みを求めていく。

\[ W = \begin{pmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} \\ \end{pmatrix} \]
\[ \frac{∂L}{∂W} = \begin{pmatrix} \frac{∂L}{∂w_{11}} & \frac{∂L}{∂w_{12}} & \frac{∂L}{∂w_{13}} \\ \frac{∂L}{∂w_{21}} & \frac{∂L}{∂w_{22}} & \frac{∂L}{∂w_{23}} \\ \end{pmatrix} \]

\(W\) は、一斉に偏微分しようとしている 2行3列の 重み 達。
\(L\) は、対象となる 損失関数 で \(\displaystyle \frac{∂L}{∂W}\) の各要素で偏微分し、損失関数\(L\) の 勾配 を 求めている。

以降は、この 勾配 を求める実装サンプルについて記載する。

3. 重みに対する勾配のPython実装サンプル


以降は、参考文献『ゼロから作るDeep Learning』から提供されている ch04/gradient_simplenet.py を用いたサンプル解説をしていく。

※サンプルコードは、下記 Git からダウンロードする。
Git (deep-learning-from-scratch):

▼ ここでは、より分かりやすく説明するため、ch04/gradient_simplenet.pysimpleNet クラスを実装するにあたり import されている 下記 (1) ~ (3) の関数 をあえてPython対話モードで定義する形で記載する。

(1) ソフトマックス関数:common/functions.py のsoftmax関数

※ ソフトマックス関数の一般的な定義は下記ページを参考。


$ python
 >>> import numpy as np
 >>>
 >>> def softmax(x):
 ...     if x.ndim == 2:    # 次元が2の場合
 ...         x = x.T    # xの転置行列を x に設定
 ...         x = x - np.max(x, axis=0)    # x から xの最大値(列単位)を減算した結果を x に設定
 ...         y = np.exp(x) / np.sum(np.exp(x), axis=0)    # 「eのx乗」÷「eのx乗」の最大値(列単位)の合計を y に設定
 ...         return y.T    # Tの転置行列を返す
 ...     # 次元が 2 でない場合
 ...     x = x - np.max(x)    # オーバーフロー対策
 ...     return np.exp(x) / np.sum(np.exp(x))    # 「eのx乗」÷「eのx乗」の合計を返す
 ...
 >>>

(2) 交差エントロピー誤差:common/functions.py の cross_entropy_error関数

※ 交差エントロピー誤差の処理内容については、下記ページを参考。


 >>> # 上記対話モードの続き
 >>> def cross_entropy_error(y, t):
 ...     if y.ndim == 1:    # 次元が1の場合
 ...         t = t.reshape(1, t.size)    # t の列数を t のsizeに変更する(形状を変形)
 ...         y = y.reshape(1, y.size)    # yの列数を y のsizeに変更する(形状を変形)
 ...
 ...     # 教師データがone-hot-vectorの場合、正解ラベルのインデックスに変換
 ...     if t.size == y.size:
 ...         t = t.argmax(axis=1)    # t の最大要素(行単位)インデックスを設定
 ...
 ...     batch_size = y.shape[0]    # y の初次元要素数を設定
 ...     return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size    # 上記URL(※)を参考。
 ...
 >>>

(3) 勾配処理:common/gradient.py の numerical_gradient関数

※ 勾配の処理内容は下記ページの 勾配関数(num_gradient) を参考。


 >>> # 上記対話モードの続き
 >>> def numerical_gradient(f, x):
 ...     h = 1e-4    # 0.0001
 ...     grad = np.zeros_like(x)    # xと同じ形状の配列で値がすべて 0
 ...
 ...     # ↓↓↓ xの行列すべての要素を1ネストでループできるよう nditer で it を定義している
 ...     # ※ op_flags=['readwrite'] は、ループ内でx[idx]に書き込めるよう書き込みもできるようにしている
 ...     it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
 ...
 ...     while not it.finished:    # it でxの全要素をループ
 ...         idx = it.multi_index
 ...         tmp_val = x[idx]
 ...         x[idx] = float(tmp_val) + h    # ループの該当要素を浮動小数点に変換
 ...         fxh1 = f(x)    # f(x + h)の算出 … (A)
 ...
 ...         x[idx] = tmp_val - h
 ...         fxh2 = f(x) # f(x-h)    # f(x - h)の算出 … (B)
 ...         grad[idx] = (fxh1 - fxh2) / (2*h)
 ...
 ...         x[idx] = tmp_val    # 値を元に戻す
 ...         it.iternext()
 ...
 ...     return grad
 ...
 >>>

▼ 重みに対する勾配処理:上記 (1) ~ (3) の関数 を呼び出した形でch04/gradient_simplenet.pysimplenet クラスを実装する。


 >>> # 上記対話モードの続き
 >>> import sys, os
 >>>
 >>> class simpleNet:
 ...     def __init__(self):
 ...         # 重みパラメータ:2x3行列の標準正規分布(ガウス分布)でランダムな数値を出力する
 ...         self.W = np.random.randn(2,3)
 ...
 ...     def predict(self, x):
 ...         return np.dot(x, self.W)    # x と self.W の積を返す
 ...
 ...     def loss(self, x, t):
 ...         z = self.predict(x)
 ...         y = softmax(z)
 ...         loss = cross_entropy_error(y, t)
 ...
 ...         return loss
 ...
 >>>

「x」は、入力データ「t」教師データ

predict関数 は、入力データ 「x」 と __init__ で設定した仮(ランダム)の重みパラメータ「self.W」の 評価結果(積) を返す。

loss関数 は、predict関数softmax関数 を実施した「y」と教師データ「t」の 損失関数(交差エントロピー誤差:cross_entropy_error関数) を返す。

4. 実装サンプルの実行確認


以降、上記 simpleNet の実行例を基に解説する。

▼ インスタンスの結果確認

 >>> # 上記対話モードの続き
 >>> net = simpleNet()
 >>> print(net.W)    # 重みパラメータ
 [[ 0.49236891 -1.2239298  -1.13722119]
  [ 0.05405365 -0.79897152 -0.08066587]]
 >>>

simpleNet() により、2x3行列のランダムな数値が生成される。

評価結果(積) の確認

 >>> # ↑↑↑ 上記対話モードの続き
 >>> x = np.array([0.6, 0.9])
 >>> p = net.predict(x)
 >>> print(p)
 [ 1.05414809 0.63071653 1.1328074  ]
 >>> np.argmax(p)    # 最大値のインデックス
 2
 >>>

predict(x) により、入力データ [ 0.6, 0.9 ] と重みパラメータ「net.W」の 評価結果(積) を算出している。
(最大インデックスは、2)

損失関数 の結果確認

 >>> # 上記対話モードの続き
 >>> t = np.array([0, 0, 1])    # 正解ラベル
 >>> net.loss(x, t)
 0.92806853663411326
 >>>

上記で最大インデックスとなった 2 が正解ラベルとなる状態で損失関数の結果は、約 0.93

勾配 の結果確認

 >>> # 上記対話モードの続き
 >>> def f(W):
 ...     return net.loss(x, t)
 ...
 >>> dW = numerical_gradient(f, net.W)
 >>> print(dW)
 [[ 0.21924763  0.14356247 -0.36281009 ]
  [ 0.32887144  0.2153437 -0.544211514]]
 >>>

net.loss(x, t)f(W) とし、勾配処理(numerical_gradient) を実施している。

f(W)W は、(3) 勾配処理:common/gradient.py の numerical_gradient関数 の \((A)\), \((B)\)と整合性が取れるように定義したもの。

▼ 結果から見る結論
重みパラメータ \(w_{11}\) と重みパラメータ \(w_{23}\) にスポットを当てた結果を見る。

\(w_{11}\) が \(h\)分増加すると、\(\displaystyle \frac{∂L}{∂W}\) の \(\displaystyle \frac{∂L}{∂w_{11}}\) は、約 0.2 となり 0.2 h 増加 しています。

一方、\(w_{23}\) は \(h\) 分増加すると、\(\displaystyle \frac{∂L}{∂W}\) の \(\displaystyle \frac{∂L}{∂w_{23}}\) が 約 -0.5 となり 0.5h 減少している。

よって、
重みパラメータ \(w_{11}\) は、マイナス方向 に。重みパラメータ \(w_{23}\) は、プラス方向更新すべき という結論となる。

※ 以上の要領で重みパラメータを より損失が少ない重みパラメータへ 更新していくことが目的。

以上。


【参考文献】
斎藤 康毅 (2018) 『ゼロから作るDeep Learning - Pythonで学ぶディープラーニングの理論と実装』株式会社オライリー・ジャパン


Copyright UNISIA-SE All Rights Reserved.
s-hama@unisia-se.jp