저자 : Tamar Rott Shaham, Michaël Gharbi, Richard Zhang, Eli Shechtman, Tomer Michaeli
출간 : CVPR 2021
주제 : Supervised Image to Image (Image translation)
Framework : Pytorch
출처 : https://arxiv.org/pdf/2012.02992.pdf
Github : https://github.com/tamarott/ASAPNet
GitHub - tamarott/ASAPNet
Contribute to tamarott/ASAPNet development by creating an account on GitHub.
github.com
- 요약
ASAPNet 은 Generator의 신경망들을 Low-resolution에서 진행하여 모델의 속도를 매우 빠르게 만든 기술입니다. 본래 Feature map을 downsampling 할수록 이미지의 정보를 더 잃게 되지만, 본 논문에서는 HDRNet, Fourier-CPPNs 기술을 이용하여 low-resolution과 high-resolution을 매칭 시킴으로써 정보의 손실을 최소화합니다. 따라서 low-resolution을 generator에 통과시킴으로써 속도가 매우 빠르면서도, 막간에 high-resolution의 정보를 되찾으면서 high-resolution 이미지를 생성하게 됩니다. 결과적으로 pix2pixHD, SPADE 등의 모델들과 비슷한 품질의 이미지를 만들지만, 속도는 수 배가 더 빠른 모델입니다.
- 본문
1) Introduction
시작은 간단하게 GAN 모델의 속도가 실제 상황에 쓰여지기 힘들다는 내용들이 나오는데 가볍게 넘어갑니다. 다음으로 pixel-wise 기법을 사용하여 효율적인 모델을 만들었다는 내용이 나옵니다.
첫째로 MLP의 파라미터들은 spatial한 특징들을 함수를 통해 효율적으로 관리할 수 있고, 둘째로 그 함수가 CNN으로 downsample된 low-resolution과 연관이 있다고 하네요. 셋째로 MLP는 sinusoidal encoding을 통과한 값들을 사용한다고 합니다. CPPN에 관한 내용은 다른 글에서 찾아보기 바랍니다. (Fourier CPPNs 글 쓰고 나면 여기 링크)
다음 문단은 "어떤 종류의 이미지에서도 우리 모델은 효율이 좋다!" 입니다. 넘어가구요.
2) Related work
Convolutional image-to-image translation
요즘에 이미지 품질을 좋아지게 하는 방법들은 나오지만 속도에 대한 개선은 나오지 않는다고 하네요.
Adaptive neural networks
관련된 네트워크들을 여럿 소개를 합니다. 핵심은 본 논문에서 low-resolution을 통과시키는 CNN은 hypernetwork의 한 종류이며, 이전에 소개된 hypernetwork와 달리 더 지역적이고 속도가 최적화되었다고 합니다. 논문에서 제시된 hypernetwork 논문을 읽어보진 않았지만, CNN block에 instance normalization을 넣은 것이 큰 차이점일 것으로 추측됩니다. 마지막에는 encoding 합니다.
Trainable efficient image transformations
ASAPNet의 핵심은 무거운 연산은 low-resolution에서 진행하고, high-resolution에서 가벼운 연산을 통해 low-resolution과 통합하는 것입니다. 과거에 비슷한 연구를 소개하네요.
Functional image representations
feature mapping을 CPPNs를 이용하여 합니다. 기존의 CPPN과 다르게 픽셀 좌표값을 sinusoids로 인코딩하며, 이는 Fourier CPPNs와 상당히 흡사합니다.(그러나 Fourier Transform과 조금 다릅니다.)
Network optimization and compression
low-resolution에선 같은 acceleration을 사용하고, full-resolution에선 가벼우면서 highly parallelizable이라고 하는데 나중에 코드에서 보죠.
3) Method
이 기술은 input image가 들어오면 source image와 같은 형태로 바꾸어주는 image translation task입니다. (edge to iamge, segments to image 등등) HDRNet에서 영감을 받아 모델을 만들었다고 합니다. 모델은 크게 두 부분으로 나뉩니다. '3.1 Spatially Adaptive Pixelwise Networks' 은 high-resolution은 lightweight and highly parallelizable 한 layer를 사용하는 것에 대해 설명하고, '3.2 Predicting Pixelwise network parameters from a low-resolution input' 은 coarse resolution(low-resolution)에서 incompressible heavy computation을 설명한다고 합니다.
3.1) Spatially Adaptive Pixelwise Networks
high-resolution에서의 계산은 비싸므로 이를 줄이는 방법으로 spatially-varying pointwise nonlinear transformation(\(f _{p}\), p is pixel position)이란 방법을 사용합니다. (spatially-varying이란 말이 굉장히 자주 나오는데 딱 뭐다라고 말하긴 힘든 것 같네요. 대충 공간적이라는 말로 생각을 하면 될 것 같습니다. 저도 공부가 더 필요한 것 같네요.) pointwise의 장점으로 연산량을 줄이려 했지만, spatial 정보의 부재는 모델의 표현력을 떨어뜨리기에 본 논문은 이를 보완하는 기술들을 소개합니다. ASAPNet에선 spatial 한 정보를 보존하기 위해 두 가지 단계를 넣습니다. 각각의 pixelwise function인 \(f_p\) 는 pixel의 좌표 (p)와 색깔 값 \(x_p\)를 input으로 받습니다. 그다음 spatially-varying parameters \(\phi_{p}\)를 가진 MLP 5개를 통과합니다.
$$ f_{p}(x_{p}, p) = f(x_{p},p,;\phi_{p}) =: y_p $$
이를 통해 pointwise의 단점을 완화시킨 이미지가 산출됩니다. (좌표를 구하는 방법은 아래에서)
3.2) Predicting Pixelwise network parameters from a low-resolution input
\(\phi_p\)를 각각의 픽셀마다 독립적으로 구하는 건 상당히 많은 연산량을 필요로 합니다. 따라서 G(Generator)는 low-resolution 이미지를 input으로 갖습니다. 또 G의 output은 모든 high-resolution pixel에 대해 맞춰지기 위해 nearest neighbor interpolation으로 upsample 된 grid입니다.
$$ \phi_{p} = [G(x_{l}]_{[p/S]} $$
low-resolution은 원래 이미지보다 1/S배 작습니다. S는 \(S_1\)과 \(S_2\)로 나눠져 있습니다. \(S_1\)은 bilinear downsampling을 위한 인자이고, \(S_2\)는 convolution downsampling을 위한 인자입니다. 여기서 conv는 가로 세로 따로 나누어서 합니다. (연산량이 더 낮아지죠.)
3.3) Synthesizing high-resolution details using positional encodings
low-resolution에서 만들어진 파라미터로는 디테일을 살리기가 어렵습니다. 이를 보완하기 위해 Fourier CPPNs와 비슷한 기술을 사용하여 pixel을 encoding 합니다. (Fourier CPPNs 논문 리뷰) 기존의 input인 \(x_{p}\)채널에 (\(sin(2\pi p_{x}/2^{k}),cos(2\pi p_{x}/2^{k})\)) 채널을 추가로 넣습니다. ( for k = 1, ... \(\log_{2}(S)\), similiarly for \(p_{y}\) ) 이를 통해 이미지의 품질을 높입니다.
3.4) Training and implementation details
본 모델은 pix2pixHD에서 제안된 multi-scale patch discriminator를 사용합니다. 또 SPADE(gauGan)에서 나온 adversarial hinge-loss, perceptual loss, discriminator feature matching loss를 사용합니다. discriminator의 개수는 2개가 기본으로 설정되어있고, 각 discriminator는 5개의 conv를 갖고 있습니다. ( num_D = 2, num_intermediate_output=5)
adversarial hinge loss : \(\mathcal{L}_{adv} = \sum_{i=0}^{1} hinge(D_{i4}(G(x)))\) ( Discriminator(5개의 conv)를 통과한 마지막 output만 씁니다.)
feature matching loss : \(\mathcal{L}_{feat} = \sum_{i=0}^{1} \sum_{j=0}^{4} \lambda_{feat} \mid D_{ij}(G(x)), D_{ij}(x)\mid _{1}\)
perceptual loss (VGG loss) : \(\mathcal{L}_{per} = \sum_{i=0}^{4} \lambda_{i} \mid VGG_{i}(G(x)), VGG_{i}(x)\mid _{1} \)
(\(\lambda\) = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0])
$$ \mathcal{L}_{G} = \mathcal{L}_{adv} + \lambda_{feat} \mathcal{L}_{feat} + \lambda_{per} \mathcal{L}_{per} $$
$$ (\lambda_{adv} = \lambda_{feat} = 10) $$
$$ \mathcal{L}_{D} = \sum_{i=0}^{1} ( hinge(-D_{i4}(G(x)))) + hinge(D_{i4}(x)) ) $$
(논문에 나와있지 않아서 제가 코드를 직접 보고 쓴거라 틀린 부분이 있을 수 있습니다.)
4) Experimental
속도는 빠른데 정확도는 거의 비슷합니다.
4.2) Analysis
Beyond label - SPADE는 segmentic layout to image에만 최적화된 기술입니다. 그러나 ASAPNet은 다른 task에도 적용이 가능합니다.
Model ablations - 자기네들이 사용한 방법들이 훌륭하다네요.
Limitations - 아무래도 아주 작은 이미지에는 약합니다. 이미지의 3%보다 작은 사이즈는 잡기가 어렵다고 하네요.
5) Conclusion
이미지 품질은 현 기술들과 맞먹지만 연산량은 훨씬 줄어들었습니다. (나중에 리뷰하겠지만 Cocosnet-v2이란 새로운 모델이 본 논문에서 비교된 기술들보다 품질이 훨씬 뛰어나기 때문에 이젠 이미지 품질이 뛰어나다고 하긴 어렵겠네요.)
Code review
code는 기본적으로 SPADE의 구조를 받아왔습니다. cocosnet도 그렇고 spade의 구조가 supervised GAN모델에서는 많이 쓰이나봐요.
먼저 모델의 핵심인 models/pix2pix_model.py의 forward 부분으로 들어갑니다.
def forward(self, data, mode):
input_semantics, real_image = self.preprocess_input(data)
if mode == 'generator':
g_loss, generated = self.compute_generator_loss(
input_semantics, real_image)
return g_loss, generated
elif mode == 'discriminator':
d_loss = self.compute_discriminator_loss(
input_semantics, real_image)
return d_loss
elif mode == 'encode_only':
z, mu, logvar = self.encode_z(real_image)
return mu, logvar
elif mode == 'inference':
with torch.no_grad():
fake_image, _, _ = self.generate_fake(input_semantics, real_image)
return fake_image
else:
raise ValueError("|mode| is invalid")
compute_~~_loss 가 핵심이네요.
(VAE 옵션(encode_z_는 기본적으로 사용 안하는걸로 되어있습니다.)
가장 대표적인 compute_generator_loss를 들어가겠습니다.
def compute_generator_loss(self, input_semantics, real_image):
G_losses = {}
fake_image, lr_features, KLD_loss = self.generate_fake(
input_semantics, real_image, compute_kld_loss=self.opt.use_vae)
if self.opt.use_vae:
G_losses['KLD'] = KLD_loss
pred_fake, pred_real = self.discriminate(
input_semantics, fake_image, real_image)
if not self.opt.no_adv_loss:
G_losses['GAN'] = self.criterionGAN(pred_fake, True, for_discriminator=False)
if not self.opt.no_ganFeat_loss:
num_D = len(pred_fake)
GAN_Feat_loss = self.FloatTensor(1).fill_(0)
for i in range(num_D): # for each discriminator
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs): # for each layer output
unweighted_loss = self.criterionFeat(
pred_fake[i][j], pred_real[i][j].detach())
GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
G_losses['GAN_Feat'] = GAN_Feat_loss
if not self.opt.no_vgg_loss:
G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \
* self.opt.lambda_vgg
return G_losses, fake_image
default에서 안쓰는 아랫부분은 짤랐어요. vae옵션도 기본적으론 안씁니다.
G_losses에 GAN, GAN_Feat, VGG 3개 있네요.
generate_fake는 다음과 같습니다.
def generate_fake(self, input_semantics, real_image, compute_kld_loss=False):
z = None
KLD_loss = None
if self.opt.use_vae:
z, mu, logvar = self.encode_z(real_image)
if compute_kld_loss:
KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
fake_image, lr_features = self.netG(input_semantics, z=z)
assert (not compute_kld_loss) or self.opt.use_vae, \
"You cannot compute KLD loss if opt.use_vae == False"
return fake_image, lr_features, KLD_loss
generate_fake 함수에서 netG에 input 집어넣습니다. netG는 models/networks/generator.py에서 ASAPNetsGenerator class를 불러온겁니다.
ASAPNetsGenerator class에서 forward를 보죠.
def forward(self, highres, z=None):
lowres = self.get_lowres(highres)
lr_features = self.lowres_stream(lowres)
output = self.highres_stream(highres, lr_features)
return output, lr_features#, lowres
1번째 줄 : low-resolution 으로 downsample
2번째 줄 : low-resolution -> Generator -> lr_features(논문에서 \(\phi_{p}\))
3번째 줄 : high-resolution 과 generator를 통과한 low-resolution을 mapping
쓰다보니 끝이 없을 것 같네요.
MLP 만들고 lr stream이랑 parameter 개수 맞춰서 계산하는 부분이 조금 어렵습니다.
앞으로 코드는 안해야겠어요.. 영상이 아니면 힘드네요.
지우기 아까우니까 그냥 남겨놔야징 ㅋㅋ
논문에서 정확한 과정이 나오지 않은 _get_coords(Fourier CPPNs과 비슷한 기술)는 이런 과정입니다.
def _get_coords(bs, h, w, device, ds, coords_type):
"""Creates the position encoding for the pixel-wise MLPs"""
if coords_type == 'cosine':
f0 = ds
f = f0
while f > 1:
x = th.arange(0, w).float()
y = th.arange(0, h).float()
xcos = th.cos((2 * pi * th.remainder(x, f).float() / f).float())
xsin = th.sin((2 * pi * th.remainder(x, f).float() / f).float())
ycos = th.cos((2 * pi * th.remainder(y, f).float() / f).float())
ysin = th.sin((2 * pi * th.remainder(y, f).float() / f).float())
xcos = xcos.view(1, 1, 1, w).repeat(bs, 1, h, 1)
xsin = xsin.view(1, 1, 1, w).repeat(bs, 1, h, 1)
ycos = ycos.view(1, 1, h, 1).repeat(bs, 1, 1, w)
ysin = ysin.view(1, 1, h, 1).repeat(bs, 1, 1, w)
coords_cur = th.cat([xcos, xsin, ycos, ysin], 1).to(device)
if f < f0:
coords = th.cat([coords, coords_cur], 1).to(device)
else:
coords = coords_cur
f = f//2
else:
raise NotImplementedError()
return coords.to(device)
이 과정은 나중에 Fourier-CPPNs를 읽고 리뷰해보겠습니다.