출간 : ICLR 2021
저자 : Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby
주제 : Image Classification
0. Abstract
본 논문은 이전 CNN 구조에서 특정 부분을 Transformer로 대체하여 만든 Vision Transformer(ViT)라는 모델입니다. image를 patch 단위로 나누어 Transformer에 통과시켜서 image classification task를 수행하게 되며, sota 모델들만큼 좋은 결과를 보여줌과 동시에 연산량을 매우 낮추었습니다.
1. Introduction
NLP task에선 Transformer 구조가 지배적인 위상을 갖고 있다는 내용들이 나옵니다.
그래서 본 기술의 핵심은 Transformer를 image에 적용시켜보자는 것입니다.
이를 위해 image를 patch 단위로 나누고, 이 patch들을 embedding 하여 transformer에 입력하는 방식을 사용합니다.
Transformer는 CNN에 비해서 inductive biases가 약합니다.
따라서 데이터의 양이 부족하면 translation equivariance나 locality 같은 속성들을 잘 잡지 못하여 모델이 잘 학습되지 않습니다.
* inductive bias : 쉽게 모델이 데이터의 규칙을 얼마나 잘 찾느냐(찾다 < 가정하다)라고 생각하면 될 것 같습니다.
하지만 일정 데이터에서 나오는 규칙은 일반적인(generative) 규칙이 아닐 수 있기 때문에 inductive bias가 너무 크면 overfitting 되기 쉽지 않을까 생각합니다.
오히려 transformer는 inductive bias가 약해서 문제라고 합니다.
본 논문에서 제시하는 inductive bias의 종류로는 translation equivariance와 locality가 있습니다.
CNN은 그 태생이 sequential한 데이터가 아니라 이미지를 지역적으로 보기 위해 사용되는 기술이기 때문에 위의 성질들을 갖고 있으나 transformer는 그렇지 않죠.
(translation equivariance : input의 변화에 맞추어 output이 잘 변화하는지, locality : 지역성)
자세한 내용은 다른 글을 참조 바랍니다.
그러나 data의 양이 많아지면 inductive bias를 이겨낸다고 합니다. 데이터의 양이 아~주 많을 때에요.
2. Relation work
pass
3. Method
ViT는 위의 그림처럼 이미지를 patch 단위로 나누고, class token 부분만을 MLP에 통과시켜서 loss를 산출합니다.
Transformer Encoder는 기존의 Transformer와 같은 구조입니다. position information은 특별한 수식 없이 embedding에 더해집니다.(왼쪽에 검은색 글씨로 표시되어 있습니다.)
3.1 Vision Transformer (ViT)
이미지 원본 데이터는 높이x넓이x채널(RGB) 의 모양입니다. 우선 이 이미지를 patch 단위로 나눈 뒤 각 patch들을 1차원으로 만듭니다.
patch의 크기를 P로 정한다면 위의 표기처럼 이미지는 N x (C x \(P^{2}\)) 의 모양이 됩니다.(파이토치 코드 기준으로 이미지 채널의 위치를 논문보다 왼쪽으로 옮겼습니다.) 이때 N은 이미지에서 나눈 patch의 총개수입니다.
이때 N개의 patch를 각각 D(latent vector size) 차원으로 embedding 합니다.
RGB 이미지라면 3 채널에서 D채널로 변경하는 것입니다.
따라서 patch embeddings를 통과한 이미지의 shape은 N x D x \(P^{2}\) 입니다.
그다음 class token을 추가합니다.
class token은 patch의 맨 앞부분에 추가시키기 때문에 \(z_{0}\)이라 명칭 하고, 이미지의 채널은 D + 1 이 됩니다.
patch 간의 position information을 추가하기 위해 position embedding 값을 더합니다.
1차원이어도 성능이 충분하다고 합니다.
embedding을 통과한 뒤 encoder는 다음과 같은 구성이라고 합니다.
Fig 1. 의 오른쪽 그림을 참조하시면 됩니다.
(1)은 embedding 과정이고, (2)~(4)는 transformer와 같은 구조로 보시면 됩니다.
(4)에서 y(class)는 class token만 마지막에 첫 번째 patch만 input으로 들어갑니다.
Inductive bias
CNN은 모델 전체에 걸쳐 inductive biases를 유지합니다. 하지만 ViT는 MLP layer는 그러한 특성을 유지하지만, self-attention layer는 글로벌한 특징을 감지한다고 하네요. 따라서 ViT는 대량의 데이터셋이 필요하게 됩니다.
3.2 Fine-Tuning and Higher Resolution
ViT는 우선 큰 데이터셋으로 학습을 한 뒤, 사용할 데이터를 가지고 맨 뒤에 head 부분만 fine-tuning을 하는 방식으로 학습을 했다고 합니다.
또한 patch size를 바꿀 수 있지만 재학습이 필요하게 되기 때문에 이미지를 interpolation 하여 patch size를 고정시키는 방향으로 진행했다고 합니다.
4. Experiment
4.2 Comparing to SOTA
JFT는 3억 장짜리 데이터셋입니다.(조금 반칙 수준의 데이터...)
JFT로 pre-trained 한 후 ImageNet으로 fine-tuning을 한 ViT-H/14 모델이 가장 높은 정확도를 보였네요.
Noisy Student는 JFT 데이터셋을 통해 Semi-supervised learning을 한 모델입니다. 이것도 상당히 높습니다.
여하튼 엄청 큰 dataset으로 pretraining 한 뒤 fine-tuning을 하면 ViT가 가장 높은 정확도를 보였습니다.
그리고 무엇보다 학습 속도가 엄청 빠릅니다.
데이터셋 종류에 따른 정확도인데, 각종 데이터셋에서 모두 가장 높은 정확도를 보였습니다.
4.3 Pre-training Data Requirement
위에서부터 큰 데이터셋이 필요하다고 계속 언급해왔습니다.
작은 데이터셋으로 학습할 때의 결과는 아래와 같습니다.
역시 ImageNet을 사용했을 때는 기존의 모델들에 비해 약간 성능이 떨어지네요.
그리고 데이터셋이 증가할수록 ViT의 성능이 상승하기 시작합니다.
4.4 Scaling Study
Appendix D
4.5 Inspecting Vision Transformer
ViT를 이해하기 위해선 ViT의 과정을 다시 한번 살펴봐야 합니다.
처음 layer에서 각각의 patch를 embedding 합니다.
Fig 7을 보면 각 patch마다 세밀한 구조들에 대한 기본 정보를 나타낸다고 하네요.
embedding을 하고 나서, position embedding을 더합니다.
이때 patch마다의 거리 정보를 입력받게 됩니다.
Fig 7의 가운데 이미지를 참조 바랍니다.
Self-attention이 가장 낮은(앞단의?) layer 정보도 참조할 수 있게 한다고 하네요.
Fig 7의 오른쪽을 보시면 attention distance는 앞에서 변동이 크고 뒤에선 일정합니다.
자세한 내용은 Appendix D에 있습니다.
4.6 Self-Supervision
Imagenet 기준으로 supervised와 4% 밖에 차이가 안 나게 되는 결과를 얻었습니다.
5. Conclusion
NLP task처럼 image를 sequence화 시켜서 Transformer를 도입하였습니다.
이때 CNN의 특성을 벗어나서 inductive biases가 부족할 수 있지만, large dataset을 이용하면 이러한 부분을 극복할 수 있습니다.
그리고 self-supervised 연구에 이바지하는 결과를 보여줬습니다.
code
참조 : https://github.com/jeonsworld/ViT-pytorch
class VisionTransformer(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head
self.classifier = config.classifier
self.transformer = Transformer(config, img_size, vis)
self.head = Linear(config.hidden_size, num_classes)
def forward(self, x, labels=None):
x = self.transformer(x)
logits = self.head(x[:, 0])
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
return loss
else:
return logits, attn_weights
모델은 VisionTransformer라는 클래스로 시작합니다.
forward
x(이미지)는 바로 Transformer를 통과하고, 그 뒤에 class token인 head 부분만 Linear layer를 통과하여 logit 값을 산출합니다.
x.shape은 (batch size, channel, height, width) = (128, 3, 224, 224)의 모양을 가진 것으로 가정하겠습니다.
Transformer는 아래와 같습니다.
class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)
def forward(self, input_ids):
embedding_output = self.embeddings(input_ids)
encoded = self.encoder(embedding_output)
return encoded
Transformer는 Embeddings와 Encoder로 이루어져 있습니다.
먼저 Embeddings입니다.
class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
img_size = _pair(img_size)
patch_size = _pair(config.patches["size"])
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.dropout = Dropout(config.transformer["dropout_rate"])
def forward(self, x):
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.hybrid:
x = self.hybrid_model(x)
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
x = torch.cat((cls_tokens, x), dim=1)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
이런저런 옵션들을 빼고 나면 다음과 같이 간단합니다.
(hidden size=768, patch size=16)
forward 함수
1) x = self.patch_embedding(x)
먼저 patch embedding입니다.
(224, 224) 이미지에서 (16, 16) kernel, 16 stride로 이루어진 convolution layer를 통과하게 되면 (14, 14) 이미지가 나오게 됩니다. 이때 (14, 14)의 각 픽셀은 하나의 patch라고 볼 수 있습니다.
또한 output channel은 hidden size이기 때문에 768 channel을 가진 이미지가 됩니다.
최종적으로 patch embedding을 통과한 x는 (batch size, channel, height, width) = (128, 768, 14, 14) 입니다.
2) x = x.flatten(2)
그 다음 x.flatten(2)를 하면 patch 부분만 flatten이 되어서 (128, 768, 196)이 됩니다.
3) x = x.transpose(-1, -2)
transpose 하여 (128, 196, 768) 모양이 됩니다.
4) x = torch.cat((cls_tokens, x), dim=1)
class token은 nn.Parameter()를 통해 0으로 이루어진 parameter를 생성합니다.
class token은 embedding 값과 concat을 해야 하니 embedding과 shape이 같아야 합니다.
class token의 shape은 하나의 patch인 모양이 돼야 하므로 (batch size, 1, hidden size) = (128, 1, 768) 입니다.
class token이 맨 앞으로 오게 concat합니다. (128, 197, 768) 모양이 됩니다.
5) embeddings = x + self.position_embeddings
position embedding 값을 더합니다.
self.position_embeddings는 전체 이미지에 더해져야 하는 값이기 때문에 (1, 197, 768) 의 모양입니다.
덧셈을 하기 때문에 shape의 모양 변화는 없습니다.
최종적으로 Embeddings(x)의 shape은 (128, 197, 768)이 됩니다.
그다음 Encoder입니다.
class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
for _ in range(config.transformer["num_layers"]):
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))
def forward(self, hidden_states):
for layer_block in self.layer:
hidden_states = layer_block(hidden_states)
encoded = self.encoder_norm(hidden_states)
return encoded
forward
1) hidden_states, weights = layer_block(hidden_states)
layer_block은 위의 그림에서 나온 Transformer 내부 구조입니다.(그냥 transformer입니다.)
자세한 내용은 참조한 github 코드에서 modeling.py의 Block 클래스를 보기 바랍니다.
2) encoded = self.encoder_norm(hidden_states)
마지막에 normalization을 한번 더 합니다. (모든 normalization은 LayerNorm입니다.)
다시 ViT 모델인 VisionTransformer입니다.
class VisionTransformer(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head
self.classifier = config.classifier
self.transformer = Transformer(config, img_size, vis)
self.head = Linear(config.hidden_size, num_classes)
def forward(self, x, labels=None):
x = self.transformer(x)
logits = self.head(x[:, 0])
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
return loss
else:
return logits, attn_weights
forward
x는 이제 transformer를 통과한 x입니다.
logits = self.head(x[:, 0])
맨 앞의 채널만 Linear layer에 통과시켜 class를 뽑습니다.
그 후에 CrossEntropy로 loss를 산출합니다.