交差エントロピーの計算(nn.CrossEntropy)を調べてみた
はじめに
以下の公式ドキュメントを読みながら本記事を書きました。
ある程度理解したつもりですが、間違っていたらすみません。
pytorch.org
交差エントロピー(nn.CrossEntropy)の計算式
考えやすくするため、本記事では重みづけは行っていません。
そのため、参考サイトの文字を一部変更して、
- N : サンプル数
- n : サンプル番号(1~N)
- C : クラス数
: nサンプル目の正解ラベル番号
: nサンプル目のラベルiのロジット
とすれば、n番目のサンプルの交差エントロピーは、
と計算されます。そして出力は全体の平均であるため、
が出力されます。
計算例
理論だけでは分かりづらいので、適当な例を挙げて計算していきます。
ここでは、ラベルが「リンゴ」、「バナナ」、「ブドウ」の順に1~3となっている、画像の分類をしたとし、その結果を以下に示します。

(正解ラベルの箇所を緑にしているため、)
n = 1について
表より、
だから、
となります(有効数字はガバガバですが...)。
n = 2について
先ほどと同様に表より、
だから、
となります。
最終結果
平均をとってやればいいので、
となります。
プログラムで求める
nn. CrossEntropyを使う場合
ソースコード
import torch.nn as nn import torch # サンプルデータ outputs = torch.tensor([[2.0, 0.5, 0.1], [0.5, 2.0, 0.7]]) # ロジット labels = torch.tensor([0, 1]) # 正解ラベル # 交差エントロピーを用意 criterion = nn.CrossEntropyLoss() # 損失計算 loss = criterion(outputs, labels) print(loss)
実行結果
tensor(0.3597)
となるため、手計算で求めた0.36とほぼ同じとなることが確認できました!
自作した場合
ライブラリを使わない場合は、次のように計算できます(結果は同じなので省略)。
自作関数
import numpy as np import torch # inputは(N, C)の行列とする def calc_crossEntropyLoss(input, labels): # 入力の次元からNとCを取得 N = input.size()[0] C = input.size()[1] # サンプルごとの損失を保存する配列 loss = [] for i in range(N): # 分母(i番目のサンプルに対して式()の分母を計算) bunbo = 0 for j in range(C): bunbo += np.exp(input[i][j]) # 分子(i番目のサンプルに対して式()の分子を計算) bunshi = np.exp(input[i][labels[i]]) # 対数変換してi番目の損失を計算(式()のΣの中身を計算) ln = - np.log(bunshi / bunbo) loss.append(ln) # 損失の平均をとり、返り値とする loss_mean = torch.tensor(np.mean(loss)) return loss_mean
呼び出し
# サンプルデータ outputs = torch.tensor([[2.0, 0.5, 0.1], [0.5, 2.0, 0.7]]) # ロジット labels = torch.tensor([0, 1]) # 正解ラベル # 損失計算 loss = calc_crossEntropyLoss(outputs, labels) print(loss)