经典GAN中的对抗loss
GAN中包含两种类型的网络G和D
cGAN中的条件对抗loss
pix2pix中的loss
最终的模型和CGAN有所不同,但它是一个CGAN,只不过输入只有一个,这个输入就是条件信息。原始的CGAN需要输入随机噪声,以及条件。这里之所有没有输入噪声信息,是因为在实际实验中,如果输入噪声和条件,噪声往往被淹没在条件C当中,所以这里直接省去了。
但是!!
高斯分布噪声可以为生成器带来生成的随机性,不然会确定性输出。在pix2pix中,实验发现z输入会淹没在确定性输出中,但是仍以dropout的形式加了随机性,在训练和测试时在生成器的的几层上加。但是还是只能观察到输出的轻微随机性。
代码部分
判别器的代码:
"""Calculate GAN loss for the discriminator"""
fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
pred_fake = self.netD(fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
real_AB = torch.cat((self.real_A, self.real_B), 1)
pred_real = self.netD(real_AB)
self.loss_D_real = self.criterionGAN(pred_real, True)
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
self.loss_D.backward()
判别器是条件判别器,fake_AB和real_AB都是结合的形式。
criterionGAN是选择了经典的vanilla GAN loss(the cross-entropy objective used in the orignal GAN paper)
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
self.loss = None
生成器的代码:
"""Calculate GAN and L1 loss for the generator"""
fake_AB = torch.cat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
self.loss_G = self.loss_G_GAN + self.loss_G_L1
self.loss_G.backward()
- 生成器是看G由real_A生成的fake_B的分数和True(即是1)的差距。
- 分数是由生成的fake_B和real_A结合起来给条件判别器打的分。
CycleGAN:
cycleGAN实现的是非配对的双域转换,即是数据集不需要是完全配准的数据集。
CycleGAN成功的原因在于它分离了风格(Style)和内容(content)。人工设计这种分离的算法是很难的,但有了神经网络,我们很容易让它学习者去自动保持内容而改变风格。
CycleGAN其实就是一个A→B单向GAN加上一个B→A单向GAN。两个GAN共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。
论文中的原图描述两个生成器G、F和两个判别器D_x和D_y的关系
- 分布X,分布Y,分别来自不同的domain
- G: X->Y ,生成器G来实现从X到Y的迁移
- F: Y->X ,生成器F来实现从Y到X的迁移
- Dx判别X与 F(y) ,判别器Dx判别到底是真X还是F根据Y生成的与X同分布的数据
- Dy判别Y与 G(x),判别器Dy判别到底是真Y还是G根据X生成的与Y同分布的数据
其过程包含了两种loss:
adversarial losses(对抗loss): 尽可能让生成器生成的数据分布接近于真实的数据分布
cycle consistency losses(循环一致性): 防止生成器G与F相互矛盾,即两个生成器生成数据之后还能变换回来近似看成X->Y->X
Adversarial loss
- 其中y表示域Y内的样本,x表示域X内的样本
- DY(y)表示真实的样本Y在判别器DY之中的评分,越接近1则判别器认为此样本越真
- G(x)为生成器根据x生成的与Y同分布的样本
- DY( G(x))为判别器根据生成的样本得到的评分
对于生成器G而言,相关联的只有DY( G(x))这一项,生成器的目标是希望生成的样本被判别器判以高分,即DY( G(x))这一项越大越好,但对于整个公式而言就是[1-DY( G(x))]越小越好。所以生成器会尽量最小化此loss,因此为minG。
对于判别器D而言,相关联的有DY(y)和DY( G(x))两项,判别器的目标是希望真实的样本y判高分,生成的样本G(x)判低分,即希望DY(y)越大越好,DY( G(x))越小越好,对于整个公式而言就是越大越好。所以判别器会尽量最大化此loss,因此为maxD。
即是这个最经典的对抗loss的形式:
cycle consistency loss
用于让两个生成器生成的样本之间不要相互矛盾。
Consistency loss 源域X中的图像x,经过其中一个生成器生成图像 G(x),作为另一个生成器的输入生成回来 F(G(x)),尽可能与原来图像接近
identity loss
在实验应用部分提到,论文原理部分中没有提,并且代码之中有涉及,是idt loss。
Identity loss 用于保证生成图像的连续性,一个图像y,经过其中一个生成器生成图像 G(y),尽可能与原来图像接近。就是让G生成的永远在目标域,因为y是目标域的(我们想要的),G是从源域到目标域的生成器,让这二者无限靠近就是说让G这个生成器生成的永远在目标域里面。
Consistency loss、identity loss 用于保证生成图像尽量保留源图像的信息。
total loss
手绘理解:
代码中的loss表示:
输出显示的图片:
代码部分
- G_A实现domainA到domainB迁移
- G_B实现domainB到domainA迁移
- D_A实现判别GA生成的数据fake_B or真实的B数据
- D_B实现判别GB生成的数据fake_A or真实的A数据
判别器的代码:
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
Parameters:
netD (network) -- the discriminator D
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Return the discriminator loss.
We also call loss_D.backward() to calculate the gradients.
"""
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
return loss_D
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
从代码就可以看出DA判断的是real_B和fake_B; DB判断的是real_A, fake_A。
就是和我们下意识认为的不一样,要注意!!!
作者可能就是为了好看统一。本来应该是G_A(A)生成fake_B,应该是判别器D_B判别,但是统一成D_A(G_A(A))更好看吧。
判别器是经典GAN的判别器
pred_real = netD(real) 和 pred_fake = netD(fake.detach()) 都是单输入的
criterionGAN是选择了lsgan loss(the MSE loss)
生成器的代码:
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
if lambda_idt > 0:
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()
与pix2pix的异同
不同点
- 数据集
pix2pix是–dataset_mode aligned;
cycleGAN是–dataset_mode unaligned - 生成器G
pix2pix是 ‘–netG unet256’ U-Net generator;输入输出可以共享底层信息,而且可以有一部分信息不经过瓶颈层
cycleGAN是’–netG resnet_9blocks’ ResNet generator - 计算对抗loss的方式
pix2pix是’–gan_mode’ vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper);即nn.BCEWithLogitsLoss()
cycleGAN是 least-square GANs objective (’–gan_mode lsgan’); 即nn.MSELoss() - Dropout
pix2pix使用了dropout,因为生成器只使用了条件信息不能生成多样性的结果,因此pix2pix在训练和测试时都使用了dropout,这样可以生成多样性的结果
cycleGAN是Dropout is not used in the original CycleGAN paper. - 归一化层
pix2pix使用了batch normalization
pix2pix在测试时和常规操作的区别:
cycleGAN是instance normalization
- image buffer
pix2pix没使用
cycleGAN防止模型抖动,使用了ImagePool(opt.pool_size)#create image buffer to store previously generated images,在图片池中存了50张生成器之前生成的图片
相同点
- ‘–netD basic’ discriminator (PatchGAN)
参考文献
- https:///gdymind/article/details/826981
- https:///leviopku/article/details/81292192
- https:///weixin_374809/article/details/004841#1.2%20GAN%E5%88%86%E7%B1%BB
- https:///weixin_374809/article/details/88778213
- https:///weixin_374809/article/details/885136