인공지능/머신러닝

[머신러닝 - 이론] 수치 미분 (Numerical Differentiation)

바보1 2022. 8. 12. 22:07

오차 역전파를 정리하려고 했는데, 그전에 수치 미분을 알아야 할 것 같아서 먼저 정리합니다.

 

앞선 글에서 설명했던 경사 하강법을 이용하여 에러를 최소화하기 위해서는 현재 상태에서의 기울기를 알아야 합니다.

정확히는 현재의 가중치 및 편차가 오류에 얼만큼의 영향을 끼치는지 알아야 합니다.

 

만약 손실함수가 \(f(x) = x^2\)이라면 미분을 통하여 \(2x\)를 통해 좌표 (0, 0)이 최소가 됨을 알 수 있습니다.

하지만 현실의 손실 함수는 저렇게 간단하지 않고, 또한 간단히 \(x\)를 통해서 나타나지도 않습니다.

 

따라서 보다 효율적이게 기울기를 찾아야 하는데, 이때 사용되는 것이 수치 미분과 오차 역전파입니다.


1. 수치 미분이란?

 

 

다들 아시겠지만, 미분이란 어느 한 점에서의 기울를 뜻합니다.

즉 한 점에서의 변화량을 뜻하고, 현재의 점의 변화가 결과에 얼만큼의 영향을 끼치는지 알 수 있습니다.

 

예를 들어 현재 속도가 3이라고 가정해봅시다. 이때 속도 함수를 미분하면 가속도 함수를 얻을 수 있습니다.

현재 상태에서의 가속도가 초당 5라면, 1초 후에 현재의 속도는 8이 됩니다.

즉, 미분을 하여 얻은 현재 상태의 변화량이 결과에 영향을 끼친 것이 됩니다.

 

그렇다면 경사 하강법을 사용하기 위해 필요한 것은 현재 상태가 얼만큼 결과에 영향을 끼치는지, 즉 기울기가 필요합니다.

고등학교 때 배운 것처럼 한 순간의 변화량을 얻는 방법은 되게 간단합니다.

이 수식이 끝이고, 수치 미분입니다.


2. 수치 미분이 어떻게 경사 하강법에 사용되는가?

 

 

위의 식에서 x에 특정한 값이 대입된다면 해당 점이 f(x) 함수 위에서의 기울기(= 한 순간의 변화량)를 나타냅니다.

 

경사 하강법에서는 f(x)가 손실 함수이고, x는 현재의 가중치나 편향이 됩니다.

(실제로는 x가 더 많이 얽혀있습니다.)

 

아무튼 간에 x가 가중치라고 가정하고, 미분을 통하여 계산하면 현재 손실 함수에 현재 가중치가 얼만큼의 영향을 끼치는지 알 수 있습니다.

 

따라서 학습의 순서를 설명하자면,

  1. 손실 함수를 통하여 오류를 계산한다.
  2. 수치 미분을 통하여 각각의 가중치와 각각의 편향이 얼마나 손실 함수(=오류)에 영향을 끼치는지 계산한다.
  3. 결론으로 나온 \(\frac{\partial L}{\partial w}\)이나, \(\frac{\partial L}{\partial b}\)를 통하여 오류가 최소가 되게 가중치와 편향을 업데이트한다. (= 경사 하강법에 근거하여 해당 손실 함수의 최솟값을 찾는다.)
  4. 즉 찾은 기울기를 통하여 \(w \leftarrow w - \eta * \frac{\partial L}{\partial w}\) 연산을 수행한다.

라고 볼 수 있겠습니다.

 

끝입니다. 간단하죠?

 

근데 여기서 치명적인 문제가 하나 있습니다.


3. 수치 미분 vs 오차 역전파

 

 

위의 학습의 순서에서 2번을 보면

수치 미분을 통하여 각각의 가중치와 각각의 편향이 얼마나 손실 함수(=오류)에 영향을 끼치는지 계산한다.

라고 나와있습니다.

 

해당 문장의 뜻을 세세하게 풀어보면, (여기서는 가중치만 설명하지만, 똑같이 편향에도 적용됩니다.)

  • 손실 함수의 결과는 오류이다.
  • 즉 손실 함수의 변화량은 오류의 변화량과 같다.
  • 오류를 알아내기 위해서는 입력 값과 타겟 값이 필요하고, 네트워크를 통해 입력 값에 대한 예상 값과 타겟 값의 오차를 계산한다.
  • \(f(w)\)는 현재 가중치에 대한 오류 값이고, \(f(w + h)\)는 현재 가중치에서 아주 작은 값을 더했을 때의 오류 값이다.
  • 따라서 \(\frac{f(w + h) - f(w)}{h}\)에서 h가 0으로 간다면, 현재 가중치가 얼마나 오류에 영향을 끼치는지 알 수 있다.
  • 하지만 \(f(w)\)라는 것은 현재 \(w\)에 대해서의 오류이므로 네트워크를 계산하여 예상 값을 알아내서 오류를 알아야 한다.
  • 그러므로 \(f(w + h), f(w)\)에서는 각각 한 번의 네트워크 계산이 필요하다.
  • 따라서 하나의 가중치에 총 두 번의 네트워크 계산이 필요하다.

 

이해가 가시나요??

 

하나의 가중치를 계산하는데도 두 번의 네트워크 계산이 필요합니다.

한 번 학습(= 1 epochs)을 하는데 가중치 하나, 편향 하나가 있다고 생각하면 학습을 위한 네트워크 계산은 총 4번을 하게 됩니다.

근데 가중치, 편향이 하나일까요?

또 epochs를 한 번만 할까요?

 

단순히 3층 네트워크라고 가정하고, 은닉층의 노드 개수가 10개라고 생각해봅시다.

따라서 가중치 뭉치는 3개, 편차는 3개가 존재하게 됩니다.

학습을 위한 epochs가 1000번이라고 생각한다면,

대충 계산해도 (1 * 2  + 1 * 2 + 1 * 2 + 1 * 2 + 1 * 2 + 1 * 2) * 1000 = 12,000번의 네트워크 계산이 필요합니다.

 

더 계산해본다면, 가중치는 행렬 곱을 통해서 값을 출력하고, 편향은 행렬 덧셈을 통하여 값을 출력합니다.

한 번의 네트워크 계산에 행렬 곱은 3번, 행렬 덧셈은 3번이 수행됩니다.

따라서 하나의 가중치를 학습하는데 행렬 곱은 6번, 행렬 덧셈도 6번입니다.

 

즉 위의 계산에서 행렬 곱은 72,000번, 행렬 덧셈도 72,000번 수행하게 됩니다.

쉽지 않죠?

 

따라서 수치 미분의 치명적인 단점은 속도가 너무 느리다입니다.

 

이를 보완하기 위하여 나온 것이 바로 오차 역전파입니다.

간단히 설명하자면 오차 역전파는 딱 두 번의 네트워크 계산(순전파, 역전파)으로 한 번의 학습을 완료합니다.

위의 예시에서 오차 역전파를 사용한다면 총 2,000번의 네트워크 계산만 필요합니다.

 

아무튼간에 다음 글에서는 오차 역전파에 대해 설명하도록 하겠습니다.

 

감사합니다.

 

 

 

지적 환영합니다.