카테고리 없음

ML/DL(3) - 손실 함수와 경사 하강법의 관계

dongsunseng 2023. 7. 8. 22:12

손실함수와 경사 하강법의 관계를 공부하며 헷갈리고 정확히 무슨말인지 이해가 안 가는 부분들이 있었는데 이들을 짚고 넘어가려 한다. 

1. 왜 가중치와 절편을 업데이트하는데에 손실함수를 미분한 값을 사용하는가

선형 회귀의 손실함수는 '제곱 오차(squared error)'로 예측값과 타깃값의 차이를 제곱한 것이다. 이때 제곱 오차가 최소가 되면 데이터의 경향을 가장 잘 표현하는 직선을 찾을 수 있는 것이다. 따라서 제곱 오차 함수의 최솟값을 알아내야 하는데 제곱 오차함수는 2차 함수이므로 기울기에 따라 함수의 값이 낮은 그래프의 최소에 가깝게 이동해야한다. 2차 함수인 이유는 아래와 같이 정리해보면 알 수 있다. 

출처:https://bskyvision.com/411

x축을 가중치 $w$ 혹은 절편 $b$로 두고 y축을 손실함수로 두었을 때 손실함수의 최소로 이동하려면 어느 방향으로 이동해야하는지를 알아야 하는데 이때 사용되는 값이 기울기이다. 기울기를 알면 어느 방향으로 움직여야 값이 증가 혹은 감소하는지를 알 수 있기 때문이다. 

출처:https://bskyvision.com/411

위의 이미지에 따라 손실함수의 최솟값에 다가가기 위해서는 $w$의 값이 점점 커져야 한다($b$의 경우도 마찬가지). w의 값이 커지는 과정을 일반화해서 수식으로 표현하기 위해서 손실함수를 각 가중치와 절편에 대하여 편미분을 해야하는 것이다. 

 

$w$ 편미분 결과: 

$$ \frac{\partial SE}{\partial w} = -2(y - \hat{y})x $$

 

손실함수에 상수를 곱하거나 나누어도 최종 모델의 가중치나 절편에 영향을 주지 않기 때문에 계산의 편의성을 위해서 아래와 같이 1/2를 없앤다.

 

$$ \frac{\partial SE}{\partial w} = -(y - \hat{y})x $$

 

이와 같이 우리는 가중치에 대한 제곱 오차의 변화율을 구한 것이고, 손실함수의 낮은 쪽으로 이동하기 위해서 $w$에서 변화율을 더하지 않고 뺀다. 이 부분에서 필자는 정확히 이해가 안 가는 느낌을 받았다.

2. 왜 가중치 업데이트를 할 때 w에서 가중치에 대한 제곱오차의 변화율을 더하지 않고 빼는가

식으로 풀어서 설명하면 아래와 같은 식이 우리의 최종 목표이다: w에 양수값을 더해야 w의 값이 커지는 업데이트를 할 수 있고 손실함수의 값이 작아질 수 있기 때문이다.

$$ w = w - \frac{\partial SE}{\partial w} = w + (y - \hat{y})x $$

 

$ -(y - \hat{y})x $는 결론적으로 음수이기 때문에 $w$에서 더하지 않고 빼서 양수값을 더할 수 있게 만드는 것이다. 이 값이 음수인 이유는 일단 예측값이 타깃값이 못 미치는 상황에서 가중치와 절편의 값을 높이며 최적값을 찾아가는 과정이라고 가정했기 때문에 $ y - \hat{y} $와 $x$ 값은 각각 양수이고 앞에 (-) 부호가 붙어있기 때문에  $ -(y - \hat{y})x $은 음수이다.

 

결론적으로 기울기값이 아래 이미지와 같이 w의 값이 커짐에 따라 편미분을 통해 구한 변화율의 값은 작아지고 이는 접선의 경사(기울기)를 뜻하므로 "경사하강법"이라는 이름이 붙은 것이다.

출처:https://bskyvision.com/411

혹시 제가 잘못 이해한 부분이 있다면 편하게 댓글 남겨주세요 :) 감사합니다. 

 

 

 

Nothing good ever comes from worrying or sitting there feeling sorry for yourself. Keep positive, keep pushing on and things will turn good.
- Conor Mcgregor -