인공지능/머신러닝

[머신러닝 - Python] 덧셈, 곱셈 노드 오차 역전파 구현 (Addition, Multiplication Back Propagation Implementation)

바보1 2022. 8. 13. 01:49
class MulLayer:
    # 곱셈 계층

    def __init__(self) -> None:
        self.x = None
        self.y = None


    def forward(self, x, y):
        # 순전파, x와 y의 값을 저장해야만 backward때 사용할 수 있다.
        self.x = x
        self.y = y
        out = x * y

        return out

    
    def backward(self, dout):
        # 역전파로 상위 계층에서의 미분 값 * 반대 노드의 값을 출력한다.
        dx = dout * self.y
        dy = dout * self.x

        return dx, dy


class AddLayer:
    # 덧셈 계층

    def __init__(self) -> None:
        pass


    def forward(self, x, y):
        # 순전파, x와 y 값을 저장하지 않아도 된다.
        out = x + y
        return out


    def backward(self, dout):
        dx = dout * 1
        dy = dout * 1

        return dx, dy

 

해당 구현의 나오는 이론은

https://hi-guten-tag.tistory.com/211

 

[머신러닝 - 이론] 오차 역전파, 오류 역전파 (Back Propagation)

인공지능을 공부하면서 가장 어려웠다고 생각하는 부분입니다. 이해하는 과정이 너무 어려웠고, 구글에 검색해도 죄다 중간 과정은 건너뛰고 결론만 써놓았더라고요.. 아무튼 이해하는데 쉽지

hi-guten-tag.tistory.com

참조하면 될 것 같습니다.