AI TECH
CNN의 역전파
prefer_all
2022. 9. 22. 13:56
<이전에 간단히 배운 내용>
역전파 알고리즘을 통하여 gradient vector을 찾고, 각 층에 사용된 parameter을 학습함
역전파 알고리즘은 연쇄법칙 기반 자동미분을 사용함
초록색 화살표: 순전파
빨간색 화살표: 역전파
- ∂L/∂y 는 y에 대한 Loss의 변화량. 즉, 이를 정답과 비교하여 Loss를 구함.
- 우리의 목적은 뉴럴네트워크의 오차를 줄이는데 있기 때문에, 각 Parameter 별로 Loss에 대한 gradient를 구한 뒤 gradient들이 향하는 쪽으로 param을 업데이트함.
- ∂L/∂x 는 x에 대한 Loss의 변화량. 미분의 연쇄법칙을 사용함.
- ∂L/∂y 는 Loss로부터 흘러들어온 gradient이고, ∂y/∂x 는 현재 입력값에 대한 현재 연산결과의 변화량 (Local Gradient)임
- => 현재 입력값에 대한 Loss의 변화량은 Loss로부터 흘러들어온 gradient에 Local gradient를 곱해서 구한다
- 그리고 이 gradient는 다시 앞쪽에 배치되어 있는 노드로 역전파됨
실제로 역전파 과정을 살펴보자!
w1은 빨강, w2는 파랑, w3는 노랑, w4는 초록 (커널)
[역전파 사용해 input을 gradient로 표현하기] x22의 gradient는 흘러들어온 gradient d11에 local gradient(w4)를 곱해서 구할 수 있음. 마찬가지로 w4의 gradient는 흘러들어온 gradient d11에 local gradient(x22)을 곱해 계산함.
그런데 이런 식으로 하나하나 다 따져가면서 구하려면 번거로움
x11, x12,.. x33에 대해 직접 구하지 말고 간단하게 구할 수는 없을까?
그래서 커널이 흘러들어온 gradient 행렬(2*2 size)를 슬라이딩 하면서 값을 구한다는 아이디어 도입
이때, 커널 요소의 순서는 정반대로 (원래는 빨-파-노-초였는데 여기서는 초-노-파-빨)
위에서 x11은 w1*d11임을 확인했는데 위의 사진에서도 확인가능. 오른쪽 3*3 칸에서 d11에 빨간색으로 칠해져 있는 w1*d11 값임.
이때 필터의 gradient는 어떻게 구할까?
예를 들어, d11은 x11, x12, x21, x22와 연결되어 있음. 필터의 gradient는 흘러들어온 gradient(x11, x12, x21, x22)에 local gradient를 곱해서 구한다. local gradient는 합성곱 필터 가중치로 연결된 값이므로 dw11은 x11*d11 + x12*d12 + x21* d21 + x22* d22 이다.