Basic Deep Learning/Dive into Deep Learning 리뷰

[D2L] 4.5 Weight Decay

needmorecaffeine 2022. 7. 11. 15:10

[ 이론 ]

4.4에서는 머신러닝 차원에서의 overfitting과 underfitting에 관해 다뤘다.

요약하자면 다음과 같다.

  • generalization error는 training error로 추정될 수 없으며 training loss만 최소화하고자 하는 것은 반드시 generalization error도 함께 최소화하지 않는다. 따라서 머신러닝 알고리즘은 overfitting에 주의해야 한다.
  • validation set을 통해 generalization error를 측정한다.
  • underfitting은 모델이 training error를 줄이지 못함을 말한다. 위 사진처럼 overfitting은 generalization loss가 training loss보다 클 경우를 말한다.
  • overfitting을 막기 위해서는 충분한 양질의 데이터를 사용해야 한다.

 

이러한 overfitting을 막기 위해서 사용하는 방법으로 소개되는 것이 weight decay이다. 데이터는 충분하다고 가정한 상태이다.

weight decay는 L2 norm regularization 이라는 다른 이름을 가지고 있고 머신러닝 모델에서 정규화를 위해 자주 사용되는 방법이다.

 

그러면 L1, L2 norm에 대해 더 자세히 알아보면 다음과 같다. 다음 포스팅에서 매우 설명을 잘 되어 있어 참고하였다. "빛나는 나무"

 

1. L1 & L2

우선 norm은 벡터의 크기를 측정하는 방법으로 두 벡터사이의 거리를 측정한다.

일반화한 Lp norm의 식은 다음과 같다.

p는 norm의 차수이고 p=1,2일 때를 각각 L1, L2 norm이라고 한다.

L1 norm은 다음과 같이 벡터의 각 원소들의 차이의 절대값의 합이다.

L2 norm은 두 벡터 사이의 유클리디안 거리이다. 다른 한 벡터가 0으로 구성되어 있다면 원점에서부터의 거리를 의미하기도 한다.

그 loss 값을 자세히 살펴보자. yi가 실제 값, f(xi)가 예측치이고 L1 loss는 다음과 같다.

L2 loss는 다음과 같다.

이런 계산 방식을 보면 알겠지만 L2 loss는 오차의 제곱을 더하기 때문에 outlier에 더 큰 영향을 받는다. 따라서 L1과 L2는 어떨 때 어떤 것을 쓰는거야 라는 질문에 일차적으로 답한다면 outlier에 대해 민감하게 다루고 싶다면 L2 loss를 사용하면 된다.

 

2. Regularization

즉 정규화는 위에서 언급했듯, overfitting을 막기 위한 방법이다. 모델의 목표가 loss function의 최소화만을 고려한다면, 즉 trainig error의 최소화만을 고려할 시 generalization error는 나빠진다. 그리고 training error를 계속해서 최소화한다면 특정 weight이 매우 큰 값을 가지게 된다. 그래서 loss function에 penalty term을 더한다. penalty term은 weight의 연산결과를 말하며 아래 L1, L2 에서 다르게 더해지며 자세한 내용은 아래에서 설명하겠다.

 

이렇게 penalty term을 더하는 이유는 방금 언급한 "minimizing the prediction loss on the training lables" 라는 목표 대신 "minimizing the sum of prediction loss and penalty term"을 목표로 하기 위함이다.이번 장에서는 여러 정규화 방법 중 L1, L2 regularization에 대해 살펴본다.

 

2.1 L1 Regularization

기존 loss function(cost function)에 추가된 항을 볼 수 있다. 이 때 lambda는 regularization constant로 non-negative hyperparameter이고 정규화의 정도를 결정하며 0에 가까울수록 정규화의 정도는 약하다. 이를 사용하는 regression은 Lasso Regression이다.

2.2 L2 Regularization

이번 포스팅의 메인인 L2 regularization이 wieght decay이다. 이는 ridge regression에 사용된다.

L1과 L2를 더 자세히 비교하자면 L1은 차이의 절대값의 합이기에 경우에 따라 다른 feature(벡테)가 같은 loss값을 가질 수 있다. 하지만 L2는 고유한 값을 가지게 된다.

때문에 L1 penalty는 여전히 weight의 값이 커지는 것에 대한 penalty 효과가 약하다. 그리고 그 연산으로 다른 weight이 0이 되는 경우가 있는데 이를 feature selection이라고 부른다. 반면 L2 penalty는 제곱합으로 penalty를 부여하기에 더 많은 weight들이 커지지 않게 할 수 있다. 추가적으로 제곱합 연산의 특성으로 미분 연산이 쉽다는 장점이 있어 gradient 연산도 용이하다.

 

이러한 이유로 현재 학습하는 딥러닝 내용에서는 weight decay, 즉 L2 regularization을 사용한다.

 


[ 실습 ]

그렇다면 weight decay가 실제 어떤 이점을 가지고 있는지 살펴보자.

다음의 polynomial dataset을 만들어 확인해보자.

weight decay의 이점을 파악하기 위해 데이터셋의 크기가 20개로 매우 작으면서 feature는 200개인 데이터셋을 만들었다.

# small training dataset, 200 dimensionality

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5 
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

그리고 initializing, L2 norm penalty, training 함수를 정의한다.

# Initializing 

def init_params():
  w = torch.normal(0,1, size = (num_inputs, 1), requires_grad = True)
  b = torch.zeros(1, requires_grad = True)
  return [w, b]
  
  
# L2 norm penalty

def l2_penalty(w):
  return torch.sum(w.pow(2))/2
  
# training loop

def train(lambd):
  w, b = init_params()
  net, loss = lambda X : d2l.linreg(X, w, b), d2l.squared_loss
  num_epochs, lr = 100, 0.003
  animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
  
  for epoch in range(num_epochs):
    for X, y in train_iter:
      # broadcast the the L2 norm penalty
      l = loss(net(X), y) + lambd * l2_penalty(w)
      l.sum().backward()
      d2l.sgd([w,b], lr, batch_size)
    if (epoch + 1)%5 == 0:
      animator.add(epoch+1, (d2l.evaluate_loss(net, train_iter, loss),
                             d2l.evaluate_loss(net, test_iter, loss)))
  print("L2 norm of w : ", torch.norm(w).item())

이제 weight decay의 이점을 확인해보자. 

람다 값을 0으로 설정해 weight decay가 구현되지 않은 경우다. 그래프에서 알 수 있듯이 test error는 거의 똑같이 유지되지만 train error는 급격히 감소하고 있음을 알 수 있다. overfitting이 일어난 것이다.

 

람다값을 3으로 설정해 weight decay를 구현했다.  구현하지 않았을 때와 비교해 test error도 꾸준히 감소하고 있음을 알 수 있다.

 

위 과정을 좀 더 빠르고 한번에 진행할 수 있는 코드라인을 마지막으로 이번 장을 마친다.

bias는 decay를 적용하지 않았다!

# specify weight decay hyperparameter directly through weight_decay
# not decay the bias (Pytorch decays both originally)

def train_concise(wd):
  net = nn.Sequential(nn.Linear(num_inputs, 1))
  for param in net.parameters():
    param.data.normal_()
  loss = nn.MSELoss(reduction = 'none')
  num_epochs, lr = 100, 0.003
  trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
  animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                          xlim=[5, num_epochs], legend=['train', 'test'])
  for epoch in range(num_epochs):
      for X, y in train_iter:
          trainer.zero_grad()
          l = loss(net(X), y)
          l.mean().backward()
          trainer.step()
      if (epoch + 1) % 5 == 0:
          animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                    d2l.evaluate_loss(net, test_iter, loss)))
  print('L2 norm of w:', net[0].weight.norm().item())