Basic Deep Learning/Dive into Deep Learning 리뷰

[D2L] 4.6 Dropout

needmorecaffeine 2022. 7. 11. 19:12

[ 이론 ]

overfitting을 막기 위한 방법으로 이전 장에서는 weight decay에 대해 배웠고 이번 장에서는 dropout에 대해 다뤄본다.

이전 장 weight decay의 핵심은 평균이 0인 가우시안 분포의 값을 가지는 weight을 가정하였고 model로 하여금 그 weight을 많은 feature들에 분산시키고자 하였다. 소수의 weight만 값이 커지는 문제를 방지하기 위해서이다.

overfitting은 보통 feature이 example보다 많을 때 발생한다고 알려져있다. 또한 다루었던 linear model은 feature간 interaction을 반영하지 못하는데 더 직관적으로 말하자면 positive, negative weight 값만을 규정할 뿐, context를 반영하지 못한다.

 

이를 bias-variance tradeoff로 확대할 수 있는데 generalizability와 flexibility를 둘 다 추구하기 힘들다는 것이다. linear model에서 이를 확인해보면 linear model은 큰 bias를 가지고 있어 일부 class of function만 표현할 수 있고 동시에 낮은 variance를 가지고 있어 데이터의 어떤 random sample에도 비슷한 결과를 가진다.

 

반면 neural network는 각각의 feature에 대해 따로 다루는 것이 아닌 group of feature간의 interaction을 표현한다. 위에서 말한대로라면 example이 feature보다 많으면 overfitting이 이뤄지지 않는 것 같지만 deep neural network에서는 overfitting이 발생할 수 있다. deep neural network는 매우 높은 flexibility, 즉 모든 데이터에 대해 설명할 수 있는 능력을 가지고 있기 때문이다.

 

이전 장의 weight decay regularization의 norm의 파라미터는 simplicity를 측정한다. 이를 다시 말하면 smoothness인데 function이 input의 작은 변화에 민감하게 반응하지 않는 것을 의미한다. 그래서 우리는 input 값에 random noise를 더하곤 했다. 

여기서 Dropout이 도입되게 되었다.

 

Dropout은 각각의 layer에 noise를 더하는 방법으로 더하는 그 타이밍은 그 다음 layer에 대한 연속 연산을 수행하기 전이다. 이렇게 noise를 inject하게 되면 smoothness를 확보할 수 있기 때문이다.

다시 정의하면 다음과 같다.

Dropout = inject noise while computing each internal layer during forward propagation

dropout은 말 그대로 training 중 몇몇의 neuron을 dropout하는 것으로 nodes의 일부를 zeroing하여 그 다음 연산을 수행하는 것이다. 이전 층의 activation이 다음 층의 activation과 관련되는 것이 co-adaptation이라고 하는데 이 co-adaptation을 깨는 것이 dropout의 목표이다.

 

그렇다면 noise는 어떻게 inject할 것인가?

한가지 방법은 unbiased 방식으로 noise를 더해 각 layer의 기댓값이 noise를 더하기 이전 layer의 값과 같게 하는 것이다.

dropout에 대해 도식화하면 다음과 같이 표현된다. 말그래도 몇개의 node를 버리는 것이다.

 


[ 실습 ]

가장 먼저 이번 장에서 배운 dropout layer부터 정의해본다.

def dropout_layer(X, dropout):
  assert 0 <= dropout <= 1
  if dropout == 1 :
    return torch.zeros_like(X) # drop all
  if dropout == 0: # all kept
    return X
  mask = (torch.rand(X.shape) > dropout).float()
  return mask * X / (1.0 - dropout)

위 연산에 대해 쪼개서 살펴보면 다음과 같다.

dropout을 training에 포함하는 과정이다. 이전에서 말했던 것 처럼, 다음 연속층으로 넘어가기 전에 dropout연산을 수행하고 droput은 training과정에서만 수행하는 것을 주의하자!

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

dropout1, dropout2 = 0.2, 0.5

class Net(nn.Module):
  def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, is_training = True):
    super(Net, self).__init__()
    self.num_inputs = num_inputs
    self.training = is_training
    self.lin1 = nn.Linear(num_inputs, num_hiddens1)
    self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
    self.lin3 = nn.Linear(num_hiddens2, num_outputs)
    self.relu = nn.ReLU()

  def forward(self, X) : 
    H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
    # use dropout only when training
    if self.training == True:
      H1 = dropout_layer(H1, dropout1) # add dropout layer
    H2 = self.relu(self.lin2(H1))
    if self.training == True:
      H2 = dropout_layer(H2, dropout2) # add dropout layer
    out = self.lin3(H2)
    return out

net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction = 'none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr = lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

연산 결과이다. 이전 MLP에서 dropout 연산 없이 했던 결과보다 더 좋고 overfitting도 일어나지 않음을 알 수 있다.

위 과정을 API로 간단히 구현해보자.

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Dropout(dropout1),
                    nn.Linear(256, 256),
                    nn.Dropout(dropout2),
                    nn.Linear(256, 10))

def init_weights(m):
  if type(m) == nn.Linear:
    nn.init.normal_(m.weight, std = 0.01)

net.apply(init_weights)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)