인공지능/머신러닝

[머신러닝 - Python] 손실 함수 - 배치 교차 엔트로피 오차 구현 (Loss Function - Batch Cross Entropy Error Implementation)

바보1 2022. 7. 27. 19:16

 

cross entropy error의 핵심은 정답에 해당하는 확률 값만 계산한다는 점입니다.

원-핫 인코딩 방식일 때의 코드를 보면

return -np.sum(target * np.log(pred + 1e-7)) / batch_size

라는 코드가 있는데, 이때 target을 log에 곱하는 모습을 볼 수 있습니다.

즉, 원-핫 인코딩 방식이므로 정답에 해당하는 인덱스만 1이고, 나머지 인덱스는 0입니다.

 

따라서 정답에 해당하는 인덱스를 k라고 한다면, pred[k]의 값만 log취해서 계산합니다.

1 * np.log(pred[k])가 되므로 k에 해당하는 에러율만 나타내고, 나머지 인덱스는 계산하지 않습니다.

 

마찬가지로 레이블 방식일 때를 보면,

return -np.sum(np.log(pred[np.arange(batch_size), target] + 1e-7)) / batch_size

이 또한 마찬가지로, batch_size가 10이라고 가정한다면 0, 1, 3, ... , 9 배열을 생성합니다.

따라서 pred[0, target[0]], ..., pred[9, target[9]]가 생성이 됩니다. (target 또한 array이므로 병렬화로 들어감)

그러므로 pred[0, target[0]]은 0번째를 예측한 값에서, target[0] = 실제 정답에 해당하는 인덱스의 확률을 가져와 log를 취하여 계산합니다.

 

둘 다 핵심은 정답에 해당하는 인덱스만 뽑아서 오류율을 계산한다는 점에 있습니다.

 

 

import os, sys

import numpy as np
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from mnist import load_mnist


def batch_cross_entropy_error(pred, target):
    
    if pred.ndim == 1:
        target = target.reshape(1, target.size)
        pred = pred.reshape(1, pred.size)

    batch_size = pred.shape[0]
    # return -np.sum(target * np.log(pred + 1e-7)) / batch_size     # 원-핫 인코딩 방식일 때
    return -np.sum(np.log(pred[np.arange(batch_size), target] + 1e-7)) / batch_size     # 레이블 방식일 때


if __name__ == "__main__":
    (input_train, target_train), (input_test, target_test) = load_mnist(normalize=True, one_hot_label=False)

    train_size = input_train.shape[0]
    batch_size = 10
    batch_mask = np.random.choice(train_size, batch_size, replace = False) 

    input_batch = input_train[batch_mask]
    target_batch = target_train[batch_mask]
    pred_batch = np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                           [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                           [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                           [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]])

    error = batch_cross_entropy_error(pred_batch, target_batch)

    print(error)
2.302584092994546