Chain Rule (calculus)
Scalar case
y=g(x) 그리고 z=f(g(x))=f(y) 인 경우 chain rule 은 다음과 같다.
dxdz=dydzdxdy
Vector case
크기가 n 인 vector x 가 있고, g,f 는 각각 입력과 출력이 n×k , k×m 이라고 하자.
∂x∂f(g(x))=∂g∂f∂x∂g
위 식은 Jacobian matrix 로 표현이 가능하다. 해당 행렬은 fi 를 gi 에 대하여 가능한 모든 조합과 gi 를 xi 에 대하여 가능한 모든 조합을 포함하고 있다.
∂x∂f(g(x))=∂g1∂f1∂g1∂f2∂g1∂fm∂g2∂f1∂g2∂f2∂g2∂fm⋯⋯⋯∂gk∂f1∂gk∂f2∂gk∂fm∂x1∂g1∂x1∂g2∂x1∂gk∂x2∂g1∂x2∂g2∂x2∂gk…⋯…∂xn∂g1∂xn∂g2∂xn∂gk
사실 행렬의 대각 원소가 아니면 ∂xj∂wi 에서 i=j 인 경우는 0 의 값을 가진다. 결과적으로 위 식은 다음과 같이 diagonal matrix 로 간소화될 수 있다.
∂x∂f(g(x))=diag(∂gi∂fi)diag(∂xi∂gi)=diag(∂gi∂fi∂xi∂gi)
예시
vector 내적 y=f(w)⋅g(x)=∑in(wixi)=sum(w⊗x) 에 대해서 dxdy 를 계산한다고 해보자. 이때, u=w⊗x 를 통해 치환하고, chain rule 을 활용하면 다음과 같다.
- dxdu=dxd(w⊗x)=diag(w)
- dudy=dudsum(u)=1T
- dxdy=dudy×dxdu=1T×diag(w)=wT
References