ランドML

機械学習とかの備忘録

交差エントロピーの計算(nn.CrossEntropy)を調べてみた

はじめに

以下の公式ドキュメントを読みながら本記事を書きました。
ある程度理解したつもりですが、間違っていたらすみません。 pytorch.org

交差エントロピー(nn.CrossEntropy)の計算式

考えやすくするため、本記事では重みづけは行っていません。
そのため、参考サイトの文字を一部変更して、

  • N : サンプル数
  • n : サンプル番号(1~N)
  • C : クラス数
  •  y_{n} : nサンプル目の正解ラベル番号
  •  x_{n, i} : nサンプル目のラベルiのロジット

とすれば、n番目のサンプルの交差エントロピーは、

 l_n = -log \frac{\exp(x_{n, y_n})} { \sum _ {i=1} ^ {C} \exp(x _ {n, i}) }

と計算されます。そして出力は全体の平均であるため、

 loss = \frac{ \sum _ {i=1} ^ {N} l_n }{N}

が出力されます。

計算例

理論だけでは分かりづらいので、適当な例を挙げて計算していきます。 ここでは、ラベルが「リンゴ」、「バナナ」、「ブドウ」の順に1~3となっている、画像の分類をしたとし、その結果を以下に示します。
(正解ラベルの箇所を緑にしているため、 y_1 = 1, y_2 = 2)

n = 1について

表より、

  •  y_1 = 1
  •  x_{1,1}=2.0
  •  x_ {1,2} = 0.5
  •  x_ {1,3} = 0.1

だから、

 l_1 = -log \frac{\exp(x_{n, y_1})} { \sum _ {i=1} ^ {3} \exp(x _ {n, i})}

 = -log \frac{\exp(2.0)} { \exp(2.0) + \exp(2.0) + \exp(0.1) }

 \sim -log (0.728)

 \sim 0.32

となります(有効数字はガバガバですが...)。

n = 2について

先ほどと同様に表より、

  •  y_2 = 2
  •  x_{2,1} = 0.5
  •  x_ {2,2} = 2.0
  •  x_ {2,3} = 0.7

だから、

 l_2 = -log \frac{\exp(2.0)} { \exp(0.5) + \exp(2.0) + \exp(0.7) }

 \sim -log (0.669)

 \sim 0.40

となります。

最終結果

平均をとってやればいいので、

 loss = \frac{l_1 + l_2}{2} = 0.36

となります。

プログラムで求める

nn. CrossEntropyを使う場合

ソースコード

import torch.nn as nn
import torch

# サンプルデータ
outputs = torch.tensor([[2.0, 0.5, 0.1], [0.5, 2.0, 0.7]])  # ロジット
labels = torch.tensor([0, 1])  # 正解ラベル

# 交差エントロピーを用意
criterion = nn.CrossEntropyLoss()

# 損失計算
loss = criterion(outputs, labels)
print(loss)

実行結果

tensor(0.3597)

となるため、手計算で求めた0.36とほぼ同じとなることが確認できました!

自作した場合

ライブラリを使わない場合は、次のように計算できます(結果は同じなので省略)。

自作関数

import numpy as np
import torch

# inputは(N, C)の行列とする
def calc_crossEntropyLoss(input, labels):

  # 入力の次元からNとCを取得
  N = input.size()[0]
  C = input.size()[1]
  
  # サンプルごとの損失を保存する配列
  loss = []
  for i in range(N):

    # 分母(i番目のサンプルに対して式()の分母を計算)
    bunbo = 0
    for j in range(C):
      bunbo += np.exp(input[i][j])

    # 分子(i番目のサンプルに対して式()の分子を計算)
    bunshi = np.exp(input[i][labels[i]])

    # 対数変換してi番目の損失を計算(式()のΣの中身を計算)
    ln = - np.log(bunshi / bunbo)
    loss.append(ln)

  # 損失の平均をとり、返り値とする
  loss_mean = torch.tensor(np.mean(loss))
  return loss_mean

呼び出し

# サンプルデータ
outputs = torch.tensor([[2.0, 0.5, 0.1], [0.5, 2.0, 0.7]])  # ロジット
labels = torch.tensor([0, 1])  # 正解ラベル

# 損失計算
loss = calc_crossEntropyLoss(outputs, labels)
print(loss)