1. ReLU 함수
\(y = \left\{\begin{matrix}
x \,\, (x > 0) \\ 0 \,\, (x\leq 0)
\end{matrix}\right.\)
이므로 x에 대한 y의 미분은 다음과 같습니다.
\(\frac{\partial y}{\partial x} = \left\{\begin{matrix}
1 \,\, (x > 0) \\ 0 \,\, (x\leq 0)
\end{matrix}\right.\)
따라서 순전파 때 입력의 크기인 x가 0보다 크면 역전파는 상류의 값을 그대로 흘려보냅니다.
반면, 순전파 때 x가 0보다 작으면 역전파 때는 하류로 신호를 보내지 않습니다.
2. 구현
class Relu:
def __init__(self) -> None:
self.mask = None
def forward(self, x):
self.mask = (x <= 0)
out = x.copy()
out[self.mask] = 0
return out
def backward(self, dout):
dout[self.mask] = 0
dx = dout
return dx
x는 numpy array입니다.
이때 mask는 x와 동일한 차원의 크기를 가지게 되고, True, False만 들어갑니다.
역전파를 수행할 때는 입력할 때 x가 0이하인 값들을 0으로 만들고, 나머지 x가 0 초과였던 부분은 그대로 하류로 흘러보냅니다.
'인공지능 > 머신러닝' 카테고리의 다른 글
[머신러닝 - 이론] Linear Regression (선형 회귀) (0) | 2022.10.18 |
---|---|
[머신러닝 - Python] SIgmoid 계층 구현 (Sigmoid Class Implementation) (1) | 2022.08.20 |
[머신러닝 - Python] 덧셈, 곱셈 노드 오차 역전파 구현 (Addition, Multiplication Back Propagation Implementation) (0) | 2022.08.13 |
[머신러닝 - 이론] 오차 역전파, 오류 역전파 (Back Propagation) (0) | 2022.08.13 |
[머신러닝 - 이론] 수치 미분 (Numerical Differentiation) (0) | 2022.08.12 |