照片由 Christopher Campbell 拍摄,来源 Unsplash
在人工智能的世界里,很少有创新能像生成对抗网络(GANs)那样吸引研究人员和创作者的想象力。2014年由Ian Goodfellow提出的GANs彻底改变了机器生成数据的方法,从超逼真的图像、合成音乐,甚至是深度伪造的视频。GANs的核心在于生成器和判别器这两个神经网络之间的精彩互动,在一场类似于博弈论的较量中,两者相互竞争,试图胜过对方。
如果你曾经好奇过AI是如何创造栩栩如生的艺术作品或把简单的草图变成照片般逼真的风景,这就是你所看到的GANs的力量。在这篇博客中,我们将会揭示GANs背后的秘密,并指导你如何使用PyTorch(一种最受欢迎的深度学习框架)实现GANs。无论你是AI爱好者还是刚入门的开发者,这个循序渐进的过程将帮助你掌握构建自己生成模型的知识。
参考来源:BBC Science Focus GANs的运作原理
GANs主要由两个神经网络构成:一个生成器(Generator)和一个判别器(Discriminator),这两个网络以类似博弈论的方式一起工作。
- 生成器:从随机噪声生成数据。它学习模仿真实数据的分布。
- 鉴别器:作为批评者,区分真实训练集中的数据和生成器生成的假数据。
这两个网络在竞争。
- 生成器改进,生成的数据看起来更真实。
- 判别器更擅长分辨真假。
这种对抗过程促使两个网络相互改进,最终生成器可以生成非常逼真的结果。
我们来看看这在人的脸上会怎样。
- 数据集:
你需要一个包含人脸的数据集,例如CelebA数据集,该数据集包含数千张名人的面部图像。
2. 初始化:
- 生成器 从生成随机噪声(例如,一组随机数值)开始。
- 判别器 用真实人脸(来自数据集)和生成器生成的假人脸来训练。
3. 培训:
- 生成器从噪声生成一张脸,然后将其传递给判别器。
- 判别器评估这张脸是真实的还是假的,并给出一个概率评分来反馈。
- 两个网络都会调整各自的参数:
- 生成器学习生成更难被识破的脸。
- 判别器提高其识别假脸的能力。
4. 对抗性游戏:
经过多次迭代,生成器逐渐变得越来越擅长生成逼真的面孔,而判别器则变得越来越严格。这个“游戏”一直进行,直到生成器生成的面孔与真实的面孔无异。
模型会逐渐学会数据集中的那些复杂的模式,比如:
- 形状:人脸的结构。
- 细节:眼睛、鼻子、嘴巴等。
- 质感和光影:肤色、发质和阴影效果。
例如,在训练初期,生成的面部可能会模糊不清或特征变形。随着训练的深入,这些面部会更加清晰逼真,反映出训练数据集中的多样性和特征。
GAN(生成对抗网络)广泛应用于:
- Deepfakes :深度伪造。
- 艺术创作与设计 :生成艺术品或虚拟形象。
- 数据扩充 :扩充训练其他模型的数据集。
- 修复与复原 :修复受损或低分辨率的照片。
通过用一个人脸数据集训练GAN,你可以生成全新且逼真的脸部,这些脸不属于任何真实存在的个体——这标志着生成模型的一大进步。
首先,让我们为项目新建一个环境,让我们开始吧!
导入相关库
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
然后创建一个类来加载数据集,
class CelebADataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
Args:
root_dir (string): 包含所有图片的文件夹。
transform (callable, 可选): 应用于图片的可选变换。
"""
self.root_dir = root_dir
self.transform = transform
# 获取所有图像文件路径
self.image_paths = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 加载图片
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
# 如果有变换,则应用该变换
if self.transform:
image = self.transform(image)
return image
这段代码定义了一个自定义的数据集类 CelebADataset
,用于加载和处理图像,尤其是特别在使用 PyTorch 提供的数据加载工具如 DataLoader
时非常有用。
# 定义转换(调整大小,裁剪,转换为张量,归一化)
transform = transforms.Compose([
transforms.Resize(64), # 将图像调整至64x64
transforms.CenterCrop(64), # 从中心裁剪至64x64
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化至[-1, 1]
])
# 从指定路径加载CelebA数据集(CelebA dataset)
dataset_path = r'C:\Users\Harish\Documents\Github\GAN\GAN_Tutorial\img_align_celeba'
dataset = CelebADataset(root_dir=dataset_path, transform=transform)
# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# 查看总图像数量
print(f"总图像数量: {len(dataset)}")
接下来,我们要创建生成模型和鉴别模型。
# 生成器和判别器类(与之前所述相同)
class Generator(nn.Module): # 生成器类,用于生成图像
def __初始化__(self, z_dim=100, img_channels=3): # 初始化生成器
super(Generator, self).__初始化__()
self.model = nn.Sequential( # 顺序结构
nn.Linear(z_dim, 256),
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, img_channels * 64 * 64),
nn.Tanh() # 使用双曲正切激活函数
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 3, 64, 64) # 重塑为图像形式
return img
class Discriminator(nn.Module): # 判别器类,用于判别图像
def __初始化__(self, img_channels=3): # 初始化判别器
super(Discriminator, self).__初始化__()
self.model = nn.Sequential( # 顺序结构
nn.Flatten(), # 展平层
nn.Linear(img_channels * 64 * 64, 1024),
nn.LeakyReLU(0.2, inplace=True), # (原地操作,节省内存)
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True), # (原地操作,节省内存)
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True), # (原地操作,节省内存)
nn.Linear(256, 1),
nn.Sigmoid() # 使用Sigmoid激活函数
)
def forward(self, img):
return self.model(img) # 返回判别结果
1. 生成器类
生成器负责从随机噪声生成合成图像(例如,生成假图像)。在这种情况下是指图像。生成器就是从乱七八糟的数据里头变出假图像的。
不重要特点- 输入:一个大小为
z_dim
的随机生成的噪声向量z
。 - 输出:一个尺寸为3x64x64的合成图像,包含RGB三个通道,像素大小为64x64。
**nn.Linear(z_dim, 256)**
:
全连接层将噪声向量(维度为z_dim)映射至256维空间。
**nn.ReLU(True)**
:
使用ReLU激活函数来引入非线性,使网络能够模拟复杂的模式。
True
参数启用了原地操作,从而节省了内存。
**nn.Linear(256, 512)**
→**nn.Linear(512, 1024)**
→**nn.Linear(1024, img_channels * 64 * 64)**
:
全连接层逐步增加维度,将噪声向量转换为更适合生成图像的高维空间。
**nn.Tanh()**
:
最后的激活函数将输出图像的像素值调整到范围 [−1,1],这对于归一化的图像数据来说是常见的。
向前方法
**img = self.model(z)**
:
将噪声向量(z
)输入网络,得到代表像素值的扁平化输出结果。
**img = img.view(img.size(0), 3, 64, 64)**
:
将输出重塑为图像格式(批量大小,3,64,64),其中批量大小是输入的批次大小。
**img.size(0)**
:指的是批大小。
**3, 64, 64**
: 所需的图像尺寸(RGB格式,64x64像素)。
这样的合成图像的张量。
2. 判别器类.判别器通过输出概率将真实图像标记为1,将假图像标记为0,从而辨别真假图像。
主要特点-
输入:一个形状为(3,64,64)的图像张量。
- 输出:为每张图像输出一个概率值。
__init__
:初始化方法
目的是定义鉴别器模型的各个层次。
**nn.Flatten()**
:
将输入的图像张量(3,64,64)展平为大小为3×64×64(即12,288个值)的向量。
这样便于全连接层处理。
**nn.Linear(img_channels * 64 * 64, 1024)**
:
这里是一个全连接层,它将输入维度从12,288压缩到1,024。
**nn.LeakyReLU(0.2, inplace=True)**
:
LeakyReLU激活函数引入了非线性特性,允许负输入以0.2的斜率传递小梯度,以避免神经元因梯度消失而死亡。就地操作可以节省内存资源。
接下来的层次:
- 1024→512→256→1\text{1024} \to \text{512} \to \text{256} \to \text{1},维度逐步减少:全连接层逐步降低维度。
- 目的如下:捕捉多层次特征来区分真假。
**nn.Sigmoid()**
:
最后的激活函数输出一个介于 0 到 1 之间的概率,接近 1 的概率表示图像为真,接近 0 的概率表示假图像。
向前 函数:
目的说明:定义图像如何被判定为真实或虚假。
**self.model(img)**
:
处理输入图像并通过网络为每个图像生成概率值。
概率张量 ∈[0,1] 表示批次数据
它们是怎么工作的训练流程如下:
- 生成器 从随机噪声生成合成图片。
- 判别器 会判断真实图片和合成图片,并猜测哪张是真哪张是假的。
对抗性目标:
- 生成器学会生成更好的输出以“糊弄”辨别器。
- 辨别器学会更准确地辨别假图像。
损失函数 :
生成器:优化以使鉴别器预测假图是'真实'的概率最大化。
判别器:优化以正确判断输入的真假。
# 损失函数和优化器
adversarial_loss = nn.BCELoss()
generator = Generator(z_dim=100) # 噪声维度为100
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 生成器优化器
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 判别器优化器
# 定义设备(硬件)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 如果GPU可用,使用GPU,否则使用CPU
generator = generator.to(device) # 将生成器移动到指定设备
discriminator = discriminator.to(device) # 将判别器移动到指定设备
下面的代码片段设置了训练生成对抗网络(GAN)模型的损失函数、优化器,包括设备设置。
二元交叉熵损失函数用于比较预测的概率与实际的二元标签之间的差异。
def 保存生成的图片(generator, epoch, device, num_images=16):
z = torch.randn(num_images, 100).to(device)
生成的图片 = generator(z).detach().cpu()
网格 = torchvision.utils.make_grid(生成的图片, nrow=4, normalize=True)
plt.imshow(np.transpose(网格, (1, 2, 0)))
plt.title(f"Epoch {epoch}")
plt.axis('off')
plt.show()
# 训练过程
def 训练过程(generator, discriminator, dataloader, epochs=5):
for epoch in range(epochs):
for i, 图像 in enumerate(dataloader):
真实图像 = 图像.to(device)
批量大小 = 真实图像.size(0)
真实标签 = torch.ones(批量大小, 1).to(device)
假标签 = torch.zeros(批量大小, 1).to(device)
# 训练判别器
判别器优化器清零梯度()
真实样本损失 = adversarial_loss(discriminator(真实图像), 真实标签)
假样本损失 = adversarial_loss(discriminator(generator(torch.randn(批量大小, 100).to(device)).detach()), 假标签)
判别器损失 = (真实样本损失 + 假样本损失) / 2
判别器损失.backward()
判别器优化器更新梯度()
# 训练生成器
生成器优化器清零梯度()
生成器损失 = adversarial_loss(discriminator(generator(torch.randn(批量大小, 100).to(device))), 真实标签)
生成器损失.backward()
生成器优化器更新梯度()
if i % 50 == 0:
打印(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [判别器损失: {判别器损失.item()}] [生成器损失: {生成器损失.item()}]")
# 可选地,在每个epoch保存生成的图片
保存生成的图片(generator, epoch, device)
训练过程中,交替训练 Discriminator 和 Generator。Discriminator 通过最小化真实图像(带有真实标签)和假图像(带有假的标签)的损失来区分真实图像和假图像。Generator 通过让 Discriminator 错误地将假图像识别为真实的图像来生成逼真的假图像,从而最小化损失值。
这个对抗过程会被重复多个轮次,让两个模型使用各自的优化器进行更新。可以定期保存生成的图像来跟踪生成器的进步。
生成的图片:
第1纪元
第10纪元
我们需要训练多个epoch,让模型能够理解细节并像原图像那样再现。对于这篇文章,我停在了10个epoch。
如果你想支持这个作者,记得做以下事情:-
点赞,分享,关注我
- 👏 给这篇文章点 50 个赞,让它有机会被推荐
- 关注我 (Medium)
- 📰 在我的 Medium 个人主页上查看更多内容
- 🔔 在 LinkedIn 和 GitHub 上关注我