cross entropy error의 핵심은 정답에 해당하는 확률 값만 계산한다는 점입니다.
원-핫 인코딩 방식일 때의 코드를 보면
return -np.sum(target * np.log(pred + 1e-7)) / batch_size
라는 코드가 있는데, 이때 target을 log에 곱하는 모습을 볼 수 있습니다.
즉, 원-핫 인코딩 방식이므로 정답에 해당하는 인덱스만 1이고, 나머지 인덱스는 0입니다.
따라서 정답에 해당하는 인덱스를 k라고 한다면, pred[k]의 값만 log취해서 계산합니다.
1 * np.log(pred[k])가 되므로 k에 해당하는 에러율만 나타내고, 나머지 인덱스는 계산하지 않습니다.
마찬가지로 레이블 방식일 때를 보면,
return -np.sum(np.log(pred[np.arange(batch_size), target] + 1e-7)) / batch_size
이 또한 마찬가지로, batch_size가 10이라고 가정한다면 0, 1, 3, ... , 9 배열을 생성합니다.
따라서 pred[0, target[0]], ..., pred[9, target[9]]가 생성이 됩니다. (target 또한 array이므로 병렬화로 들어감)
그러므로 pred[0, target[0]]은 0번째를 예측한 값에서, target[0] = 실제 정답에 해당하는 인덱스의 확률을 가져와 log를 취하여 계산합니다.
둘 다 핵심은 정답에 해당하는 인덱스만 뽑아서 오류율을 계산한다는 점에 있습니다.
import os, sys
import numpy as np
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from mnist import load_mnist
def batch_cross_entropy_error(pred, target):
if pred.ndim == 1:
target = target.reshape(1, target.size)
pred = pred.reshape(1, pred.size)
batch_size = pred.shape[0]
# return -np.sum(target * np.log(pred + 1e-7)) / batch_size # 원-핫 인코딩 방식일 때
return -np.sum(np.log(pred[np.arange(batch_size), target] + 1e-7)) / batch_size # 레이블 방식일 때
if __name__ == "__main__":
(input_train, target_train), (input_test, target_test) = load_mnist(normalize=True, one_hot_label=False)
train_size = input_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size, replace = False)
input_batch = input_train[batch_mask]
target_batch = target_train[batch_mask]
pred_batch = np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]])
error = batch_cross_entropy_error(pred_batch, target_batch)
print(error)
2.302584092994546