之前的章节中,我们介绍了多个图像分类实践,这次,我们PyTorch构建生成对抗网络进行图像生成。请从这里下载本章中需要用的动漫人脸数据集。
本篇也是PyTorch系列教程的最后一篇,在尾声之际,我希望可以用一个例子来总结PyTorch建模的“套路”。之所以用“套路”来形容PyTorch建模的过程,是因为PyTorch建模训练整一个完整的过程在我看来都是按照既定的流程进行:定义dataset-定义dataloader-定义网络模型-是否使用GPU-定义优化器、损失函数-迭代轮次-在每个轮次中:梯度清零-前向传播-计算loss-反向传播-测试集测试性能。这一系列所有章节都在围绕这个套路进行,无论是dataset还是优化器,还是性能评估,都是这个套路中的一环。我认为这一系列文章,最重要的,就是让这个“套路”深入每个读者的脑海,其次就是在实现这个“套路”时,能够对一起讨论过环节有些印象:定义数据及、定义网络结构、损失函数、迁移学习……这是这一系列内容的意义。最后一篇里,我思来想去,想介绍生成对抗网络,这个网络略有特殊,我就是想说明,就算是这种特殊的建模任务,也遵循着这个“套路”。
生成对抗网络(Generative Adversarial Networks,简称GAN)是最近几年一个非常热门的深度生成模型框架。GAN中的生成器(Generator)可以从随机噪声中生成假的数据,而判别器(Discriminator)则尝试区分生成的数据和真实数据的差异,这其实是一个分类模型,用于区分真实数据和生成数据。这两个网络相互对抗、一直在改进,最终生成器可以输出极为逼真的假数据。
GAN的思想源自“对抗”,其原理类似于两个人的博弈游戏。假设有一位伪造钞票的小偷,想要轻易骗过警察的眼睛;而另一方面,警察也在不断学习怎么样识破各种假钞票的特征。两者不断对抗,小偷的伪钞技术也在这样的一个过程中慢慢的提升,直到製作出极为逼真的假钞欺骗警察。GAN就用这种对抗的思想,让生成器和判别器一直在改进技术,最终得到高质量的生成结果。
具体来说,GAN包含两个神经网络:生成器和判别器。生成器的输入为随机噪声,输出为生成的数据。判别器的输入则为真实数据与生成数据,输出为每个输入数据的真伪概率。在训练中,首先固定生成器的参数,改进判别器的能力,使其尽可能将生成数据判断为假,将真实数据判断为真。然后固定判别器,改进生成器的参数,使其输出的数据可以欺骗判别器判断为真实数据。这样不断轮流训练,提升两个网络的对抗能力,生成器生成数据的质量也随之不断改善。
GAN没明确的损失函数,而是通过对抗的方式来促使模型训练。另一个独特之处是GAN的训练不需要任何标注或监督数据,能够最终靠大量未标注数据来进行训练。GAN已在多个领域取得了非常好的结果,如生成高清人脸和风景图片,图像超分辨率,风格迁移等。
接下来实现一个GAN来生成新的动漫脸,我希望可以通过这个例子更加深入、多角度地展示PyTorch建模“套路”。如果你对GAN网络思想原理还是不够理解,没关系,本篇的本来目标也只是借助GAN生成动漫人脸来实践PyTorch建模。
动漫人脸头像数据存放在当前目录下的“data/AnimeFaces/0”目录下,之所以存放在子目录“0”内,是为使用ImageFolder方便,避免自定义dataset。数据集内包含动漫人脸头像21551张。图像大小都是64*64.
img = img * std[:, None, None] + mean[:, None, None]images, labels = next(iter(dataloader))images = [unnormalize(img) for img in images]show_image([(img, ) for img in images[:10]]) # 只展示前10张
device = torch.device(cuda) if gpu else torch.device(cpu)
与普通图像分类任务中的网络模型不同,生成对抗网络指的是两个网络:生成器和判别器,也即是说,生成对抗网络中我们要定义两个网络,生成器用于生成虚假人脸,判别器用于对人脸图像真假进行甄别。两个网络的结构并没太多特殊的地方,如果非要说有,那就是生成器中要使用到反卷积层ConvTranspose2d,当然这个反卷积层在PyTorch也有提供,我们直接用即可。