출간 : ECCV 2020
저자 : Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
주제 : Unsupervised Image Translation (GAN)
paper : 2007.15651.pdf (arxiv.org)
- Abstract
I2I translation task에서 input의 content는 도메인에 대해 독립적입니다.
본 논문에선 contrastive learning을 기반으로 하여 mutual information을 maximizing 합니다.(서로 최대한 dependent하게 만듭니다.)
이미지의 비교는 patch-based로 이루어지며, NCE loss에 사용되는 negative label들은 다른 데이터를 끌어오지 않고, input data 내에서 모두 추출하여 학습을 합니다.
본 기술은 one-sided I2I translation으로 학습 시간을 줄이면서 성능은 제고시켰습니다.
* mutual information
\(I(x, y) = KL( p(x,y) || p(x)p(y) ) = H(x) - H(x|y) = H(y) - H(y|x)\)
: x와 y가 independent 할 수록 p(x,y)와 p(x)p(y)의 거리가 같아지기 때문에 정보량이 줄어듭니다.
: x와 y가 dependent 할 수록 p(x,y)와 p(x)p(y)의 거리가 커지기 때문에 정보량이 커집니다.
(서로 독립적인 x, y 일 때, p(x,y) = p(x)p(y) 입니다.)
1. Introduction
GAN의 제일 대표적인 이미지가 말과 얼룩말입니다. I2I translation task에서 우리는 기존 말의 형태를 유지하면서 무늬만 얼룩말 무늬로 바뀌길 원합니다. 이를 위해 cycle-consistency를 가정하여 학습하지만 cycle-consistency는 두 이미지가 bijective해야 한다는 강력한 제약 조건이 있습니다.(아래 사진처럼 같은 종류의 객체, 같은 포즈, 같은 크기인 경우는 극히 드물죠.) 이를 극복하기 위해 본 논문에선 patch-based로 input과 output의 mutual information을 maximizing하는 방법을 사용합니다.
InfoNCE loss를 사용하여 얼룩말(target)의 머리 부분과 비슷한 부분을 말(source)의 사진에서 찾아내어 연관시키고(positive), 관련 없는 부분은 negative 부분으로 하여 학습을 하게 됩니다.
negatives는 다른 이미지에서 끌어오지 않고 input image에서 추출합니다. 그래야 더 input image의 정보를 잘 보존할 수 있다고 하네요.
2. Related Works
cycle-consistency는 기본적으로 X와 F(G(X))가 서로 같아지도록 만드는 방법으로 현재 상당히 많이 쓰이는 방법입니다. 또 다양한 방법으로 변형해서 사용되고 있습니다. 그러나 bijection이라는 가정이 너무 제한적이기 때문에 완전한 reconstruct가 이뤄지기 힘들다고 하네요.
input image와 output image가 닮았다는 가정으로 학습을 하는 방법이 있습니다. 그러나 지금까지의 방법들은 전체 이미지를 다 사용하였고, predefined 된 distance들을 사용했습니다. 본 논문에선 특정 patches들로 닮은 정도를 찾아내고, predefined 된 distance가 아닌, information maximization을 하게 됩니다. 이를 cross-domain similarity function이라 부르네요.
pre-trained 된 모델(VGG 등등)을 활용해서 loss를 산출하는 perceptual loss라는 방법이 있습니다. pixel 단위로 loss를 구하는 것이 아니라 사람이 인식하는 것과 가까운 feature 단위로(pre-trained model의 output) loss를 구하는 방법입니다. 본 논문에선 특정 patch에서 cross-domain similarity function을 사용합니다.
Maximizing mutual information은 주로 NCE(noise contrastive estimation)를 사용한다고 합니다. pos와 neg를 나눠서 학습하는 방법입니다. 다양한 방법들이 있다네요.
3. Methods
Contrastive Unpaired Translation (CUT) 모델은 inverse auxiliary generator, discriminator가 없이 한 방향으로만 학습되는 방법입니다. (위에서 말한대로 cycle-consistency loss를 사용하지 않죠.) 따라서 모델 학습 시간이 줄어들게 됩니다.
일단 설명이 필요 없는 당연한 loss
그 다음 Mutual information maximization에 대한 설명입니다.
(이 아래부터가 본 논문의 핵심적인 Patch NCE loss에 대한 내용들입니다.)
mutual information을 maximize하기 위해 NCE(noise contrastive estimation)을 사용했습니다.
\(G_{enc}(x)\)의 feature에서 임의로 N개를 골라냅니다. (output 채널이 3이었다면 (N, 3)의 shape이 됩니다.)
그리고 \(G_{enc}(\hat{y})\)의 feature에서 위와 동일한 id에 있는 N개의 요소를 추출합니다.
추출한 요소들은 MLP( : \(H_{l})\)를 통과하게 되며,
\(H_{l}(G_{enc}(x))\)와 \(H_{l}(G_{enc}(\hat{y}))\)의 correspondence를 이용하여 positive와 negatives를 만들고, loss를 계산하게 됩니다.(아래 그림 참조)
correspondence를 증가시키고, non-correspondence를 감소시키면 두 이미지가 서로 의존적(dependent)이게 되니 mutual information은 커지게 됩니다.
Multilayer, patchwise contrastive learning.
여러 개의 layer를 뽑아내어 위에서 말한 maximize mutual information을 한다고 합니다.
input과 output의 correspondence를 비교해서 다리는 다리와 일치시킬 수 있고, 배경과 개체를 구분한다고 합니다.
\(L\) : multilayer의 개수
\(S_{l}\) : 임의로 골라낸 patch의 개수
\(C_{l}\) : MLP output channel
\(\{z_{l}\}_{L} = {H_{l}(G^{l}_{enc}(x))}_{L}\)
\(\{\hat{z_{l}}\}_{L} = {H_{l}(G^{l}_{enc}(G(x)))}_{L}\)
그래서 Loss 모양은 다음과 같지만 그냥 말 그림 보는게 더 직관적인 것 같네요.
- 원본 이미지와 만든 이미지를 Encoder에 통과시킨다.
- 몇 개의 레이어에서 나온 feature 값에서 sampling을 한다.
- sampling된 요소들을 MLP를 통과시킨다.
- MLP를 통과한 두 요소끼리 비교하여 pos와 neg를 만든다.
- cross entropy 계산.
Encoder는 domain-invariant concepts을 파악하고, Decoder는 domain-specific features를 파악한다고 합니다.
그리고 여러 이미지에서 negatives를 가져오는 것 보다 단일 이미지에서만 비교를 하는게 좋다고 하네요.
최종 목적함수입니다.
idt loss가 있는 경우와 없는 경우의 lambda의 수치가 다릅니다.
여느 모델보다 훨씬 간단한 loss 구조를 갖고 있습니다.
최종 정리.
0. Generator의 Encoder를 통과하는 과정에서 총 5개의 feature를 뽑아냅니다.
1. X에서 N(default = 256) 개의 요소를 sampling, G(X)에서 동일한 index N개를 sampling
2. 5개의 feature 각각 MLP H(X)를 통과합니다. (default output channel C = 256 --> \(H(G_{enc}(x))\)의 shape : (B, C, N)
3. positive patch : 두 matrix를 \(H(q)^{T}\) * H(k) 하여 각 차원의 correspondence를 추출. ( shape : (256,1,256) * (256, 256, 1) = (256, 1) )
4. negative patches : 두 matrix를 그대로 곱함. (256, 256) 차원이 됨. diagonal에 -10을 주어 correspondence 부분의 역할을 없애서 순수한 negative를 만듦.
5. 둘이 concat하여 (256,257)을 만듦.
6. cross-entropy loss (N+1 classfication)
# MLP
def forward(self, feats, num_patches=64, patch_ids=None):
### 0. Encoding한 input들
# feats : G(x)_l
return_ids = []
return_feats = []
if self.use_mlp and not self.mlp_init:
self.create_mlp(feats)
for feat_id, feat in enumerate(feats):
### 1. sampling
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
if num_patches > 0:
if patch_ids is not None:
patch_id = patch_ids[feat_id]
else:
patch_id = np.random.permutation(feat_reshape.shape[1])
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
else:
x_sample = feat_reshape
patch_id = []
### 2. MLP
if self.use_mlp:
mlp = getattr(self, 'mlp_%d' % feat_id)
x_sample = mlp(x_sample)
return_ids.append(patch_id)
x_sample = self.l2norm(x_sample)
if num_patches == 0:
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
return_feats.append(x_sample)
return return_feats, return_ids
# Patch NCE Loss
def forward(self, feat_q, feat_k):
num_patches = feat_q.shape[0] # num_patches = 256
dim = feat_q.shape[1] # dim = 256
feat_k = feat_k.detach()
### 3. pos logit
# (256, 1, 256) * (256, 256, 1)
l_pos = torch.bmm(
feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
l_pos = l_pos.view(num_patches, 1)
### 4. neg logit
batch_dim_for_bmm = self.opt.batch_size # batch_size = 1
# reshape features to batch size
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
npatches = feat_q.size(1)
# (1, 256, 256) * (1, 256, 256)
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
# diagonal 값 무의미하게 만들기
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
l_neg_curbatch.masked_fill_(diagonal, -10.0)
l_neg = l_neg_curbatch.view(-1, npatches)
### 5. concat
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
### 6. loss
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
device=feat_q.device))
return loss
4. Experiments
4.1 Unpaired image translation
기본적으로 single image 단위, multi layer 학습, 그리고 unpaired dataset으로 학습했습니다.
Loss가 간단하기 때문에 속도와 메모리의 경량성이 우수합니다.
4.2 Ablation study and analysis
Ablation은 크게 3가지를 구분했습니다.
(a) negative sampling을 external에서 하는지 internal에서 하는지
(b) multi layer의 유무
(c) idt loss의 유무
결과
(a) internal에서만 sampling하는 것이 좋고,
(b) multi layer에서 학습하는 것이 좋고(차이가 매우 큼),
(c) idt loss가 있는 경우에 더 안정적인 학습이 가능합니다.(아래 그림 참조)
그리고 correpondence로 학습을 하는 것에 대한 증명입니다.
보시다시피 corresponding patches는 확실히 두 이미지간의 공통된 부분을 잘 찾아내고 있습니다.
4.3 High-resolution single image translation
고화질의 이미지를 16개의 작은 사이즈로 crob한 뒤 학습을 하며, discriminator는 특히 한번 더 사이즈를 crob해서 학습하도록 합니다. 이를 sinCUT 방법이라고 부르며 다른 방법들에 비해 좋은 결과를 보이네요.
5. Conclusion
핵심 : Maximize mutual information.
어떻게? corresponding patch를 학습시킨다.
external data를 사용하는 것 보다 이미지 자체에서 negatives를 가져오는 것이 더 효율적입니다.
다른 similarity loss를 계산하지 않기 때문에 간단함.
물론 2020년에 훌륭한 결과를 만들었습니다.
다음 논문은 DCLGAN.