谦瑞
2023-08-22 03:23:25浏览 2868
1. 数据集
- CelebA数据集是一种用于人脸属性分析的大型数据集。该数据集包含10,177位名人的202,599张人脸图像(超过20万张),每张图像都带有40个二值属性标签,例如是否年轻、性别、是否微笑等。
- CelebA数据集由香港中文大学(CUHK)多媒体实验室(MMLab)创建,被广泛用于人脸识别、人脸属性分析、人脸合成等相关研究领域。该数据集中的人脸图像来自互联网上的名人照片,包括电影明星、音乐家、运动员等。
- CelebA数据集中的人脸图像具有较大的变化,如姿势、表情、光照和背景等。这使得该数据集对于研究人脸属性分析的鲁棒性和准确性非常有价值。
- 由于规模较大,CelebA数据集提供了大量的图像样本和属性标签,适用于深度学习中的大规模训练和评估任务。
2. 重温DCGAN的结构
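- DCGAN的生成器和判别器可以看作两个相反的过程:生成器先用全连接层把100维噪声向量投影成1024×4×4的特征图,再通过4个步长为2的转置卷积逐步上采样为3×64×64的图像(输出经Tanh映射到[-1, 1]);判别器则通过5个步长为2的卷积把64×64的图像逐步下采样为256×2×2的特征图,最后经全连接层和Sigmoid输出该图像为真实图像的概率。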
3. 程序实现
class Hyperparameters:
    """Central configuration for the DCGAN face experiment.

    Every setting is a class attribute; the other modules read them through
    the shared ``HP`` instance created below.
    """

    # ---- runtime ---------------------------------------------------------
    device: str = 'cpu'          # torch device string for models and tensors
    seed: int = 1234             # seeds the torch / cuda / random / numpy RNGs
    # ---- data ------------------------------------------------------------
    data_root: str = 'D:/data'   # ImageFolder root directory
    image_size: int = 64         # resize + center-crop target size
    data_channels: int = 3       # RGB images
    batch_size: int = 64
    n_workers: int = 2           # DataLoader worker processes
    # ---- model -----------------------------------------------------------
    z_dim: int = 100             # latent vector length fed to the generator
    # ---- optimisation ----------------------------------------------------
    init_lr: float = 0.0002      # Adam learning rate for both G and D
    beta: float = 0.5            # Adam beta1
    epochs: int = 1000
    # ---- logging / checkpoints -------------------------------------------
    verbose_step: int = 250      # log a preview grid every N steps
    save_step: int = 1000        # checkpoint G and D every N steps


HP = Hyperparameters()
from Gface.log.config import HP
from torchvision import transforms as T
import torchvision.datasets as TD
from torch.utils.data import DataLoader
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen with some MKL/PyTorch installs on Windows — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Resize the short side to image_size, center-crop to a square, convert to a
# [0, 1] tensor, then map every RGB channel to [-1, 1] so real images share the
# range of the generator's Tanh output.
_face_transform = T.Compose([
    T.Resize(HP.image_size),
    T.CenterCrop(HP.image_size),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

data_face = TD.ImageFolder(root=HP.data_root, transform=_face_transform)

face_loader = DataLoader(
    data_face,
    batch_size=HP.batch_size,
    shuffle=True,
    num_workers=HP.n_workers,
)

# Inverse of the Normalize above for display: first y / 2, then + 0.5,
# i.e. y * 0.5 + 0.5 maps [-1, 1] back to [0, 1].
invTrans = T.Compose([
    T.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    T.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import torchvision.utils as vutils

    # Preview only the first batch: print its shape and show it as an 8-wide grid.
    first_batch = next(iter(face_loader), None)
    if first_batch is not None:
        images, _labels = first_batch
        print(images.size())
        preview = vutils.make_grid(images, nrow=8)
        # Undo the [-1, 1] normalisation and move channels last for imshow.
        plt.imshow(invTrans(preview).permute(1, 2, 0))
        plt.show()
import torch
from torch import nn
from Gface.log.config import HP
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an image in [-1, 1].

    A linear layer projects ``z`` (``HP.z_dim`` features) to a 1024x4x4 feature
    map. Four transposed convolutions (kernel 4, stride 2, padding 1) each
    double the spatial size, 4 -> 8 -> 16 -> 32 -> 64, while the channels go
    1024 -> 512 -> 256 -> 128 -> ``HP.data_channels``. The output is always
    64x64, independent of ``HP.image_size``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Attribute names and layer order define the state_dict keys used by
        # saved checkpoints — keep them stable.
        self.projection_layer = nn.Linear(HP.z_dim, 4 * 4 * 1024)
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024,
                               out_channels=512,
                               kernel_size=(4, 4),
                               stride=(2, 2),
                               padding=(1, 1),
                               bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=512,
                               out_channels=256,
                               kernel_size=(4, 4),
                               stride=(2, 2),
                               padding=(1, 1),
                               bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=256,
                               out_channels=128,
                               kernel_size=(4, 4),
                               stride=(2, 2),
                               padding=(1, 1),
                               bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=128,
                               out_channels=HP.data_channels,
                               kernel_size=(4, 4),
                               stride=(2, 2),
                               padding=(1, 1),
                               bias=False),
            nn.Tanh()  # squashes pixels to [-1, 1], matching the data normalisation
        )

    def forward(self, latent_z):
        """Generate images from ``latent_z`` (last dim ``HP.z_dim``).

        Returns a tensor of shape (N, HP.data_channels, 64, 64) with values in [-1, 1].
        """
        z = self.projection_layer(latent_z)
        z_projected = z.view(-1, 1024, 4, 4)
        return self.generator(z_projected)

    @staticmethod
    def weights_init(layer):
        """DCGAN initialisation, meant for ``model.apply(model.weights_init)``.

        Conv / ConvTranspose weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02)
        and BatchNorm biases = 0. Any other layer (e.g. the Linear projection)
        keeps PyTorch's default initialisation.
        """
        layer_class_name = layer.__class__.__name__
        if 'Conv' in layer_class_name:
            nn.init.normal_(layer.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in layer_class_name:
            # affine=False BatchNorm layers have no learnable weight/bias.
            if layer.weight is not None:
                nn.init.normal_(layer.weight.data, 1.0, 0.02)
            if layer.bias is not None:
                # Zero bias; normal_(bias, 0.) would draw N(0, 1) noise instead.
                nn.init.constant_(layer.bias.data, 0.)
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import torchvision.utils as vutils
    from Gface.log.dataset_face import invTrans

    # Smoke test: 64 latent vectors should become a (64, 3, 64, 64) image batch.
    z = torch.randn(size=(64, HP.z_dim))
    G = Generator()
    # no_grad: a preview needs no autograd graph, and a grad-tracking tensor
    # cannot be converted to NumPy by matplotlib (imshow would raise).
    with torch.no_grad():
        g_out = G(z)
    print(g_out.size())
    grid = vutils.make_grid(g_out, nrow=8)
    # Map [-1, 1] back to [0, 1] and move channels last for imshow.
    plt.imshow(invTrans(grid).permute(1, 2, 0))
    plt.show()
import torch
from torch import nn
from Gface.log.config import HP
class Discriminator(nn.Module):
    """DCGAN discriminator: scores each image with the probability that it is real.

    Five 3x3 stride-2 convolutions (padding 1) take channels
    ``HP.data_channels`` -> 16 -> 32 -> 64 -> 128 -> 256 and, for a 64x64 input,
    halve the spatial size 64 -> 32 -> 16 -> 8 -> 4 -> 2. The first layer has no
    BatchNorm. The 256x2x2 features then go through a Linear layer and a Sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Attribute names and layer order define the state_dict keys used by
        # saved checkpoints — keep them stable.
        self.discriminator = nn.Sequential(
            nn.Conv2d(in_channels=HP.data_channels,
                      out_channels=16,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1),
                      bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=16,
                      out_channels=32,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1),
                      bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=32,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1),
                      bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1),
                      bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1),
                      bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        self.linear = nn.Linear(256 * 2 * 2, 1)
        self.out_ac = nn.Sigmoid()

    def forward(self, image):
        """Return real-probabilities of shape (N, 1) for an (N, C, 64, 64) batch."""
        out_d = self.discriminator(image)
        # Flatten per sample. Keeping the batch dimension explicit makes a
        # non-64x64 input fail loudly in self.linear, instead of view(-1, 1024)
        # silently regrouping the features into a different batch size.
        out_d = out_d.view(out_d.size(0), -1)
        return self.out_ac(self.linear(out_d))

    @staticmethod
    def weights_init(layer):
        """DCGAN initialisation, meant for ``model.apply(model.weights_init)``.

        Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02) and BatchNorm
        biases = 0. Any other layer (e.g. the final Linear) keeps PyTorch's
        default initialisation.
        """
        layer_class_name = layer.__class__.__name__
        if 'Conv' in layer_class_name:
            nn.init.normal_(layer.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in layer_class_name:
            # affine=False BatchNorm layers have no learnable weight/bias.
            if layer.weight is not None:
                nn.init.normal_(layer.weight.data, 1.0, 0.02)
            if layer.bias is not None:
                # Zero bias; normal_(bias, 0.) would draw N(0, 1) noise instead.
                nn.init.constant_(layer.bias.data, 0.)
if __name__ == '__main__':
    # Smoke test: a batch of 64 random 3x64x64 "images" should give one score each.
    dummy_images = torch.randn(size=(64, 3, 64, 64))
    discriminator = Discriminator()
    scores = discriminator(dummy_images)
    print(scores.size())
import os
from argparse import ArgumentParser
import torch.optim as optim
import torch
import random
import numpy as np
import torch.nn as nn
from tensorboardX import SummaryWriter
from Gface.log.generator import Generator
from Gface.log.discriminator import Discriminator
import torchvision.utils as vutils
from Gface.log.config import HP
from Gface.log.dataset_face import face_loader, invTrans
# TensorBoard writer; event files are written under ./log.
logger = SummaryWriter('./log')

# Seed the torch, CUDA, Python and NumPy random generators from HP.seed.
for _seed_fn in (torch.random.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(HP.seed)
del _seed_fn
def save_checkpoint(model_, epoch_, optm, checkpoint_path):
    """Serialise a model/optimizer pair to ``checkpoint_path`` with ``torch.save``.

    The file holds a dict with keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``, the layout read back when resuming training.
    Missing parent directories are created first, so the first save into e.g.
    ``model_save/`` does not fail with FileNotFoundError.

    Args:
        model_: module whose ``state_dict()`` is saved.
        epoch_: epoch number stored alongside the weights.
        optm: optimizer whose ``state_dict()`` is saved.
        checkpoint_path: destination file path.
    """
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_dict = {
        'epoch': epoch_,
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict()
    }
    torch.save(save_dict, checkpoint_path)
def _load_checkpoint(checkpoint_path, model, optimizer):
    """Restore ``model`` and ``optimizer`` from a ``save_checkpoint`` file; return its epoch."""
    # map_location lets a checkpoint written on one device resume on HP.device.
    checkpoint = torch.load(checkpoint_path, map_location=HP.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']


def train():
    """Train the DCGAN generator/discriminator pair on ``face_loader``.

    Command line:
        --c G_PATH~D_PATH   resume from a generator checkpoint and a
                            discriminator checkpoint joined by '~';
                            omit to train from scratch.

    Every ``HP.verbose_step`` steps, a grid generated from a fixed latent batch
    is logged to TensorBoard. Every ``HP.save_step`` steps, both models are
    checkpointed into ``model_save/``.
    """
    parser = ArgumentParser(description='Model Training')
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='training from scratch or resume training'
    )
    args = parser.parse_args()

    G = Generator()
    G.apply(G.weights_init)
    D = Discriminator()
    D.apply(D.weights_init)
    G.to(HP.device)
    D.to(HP.device)

    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(G.parameters(), lr=HP.init_lr, betas=(HP.beta, 0.999))
    optimizer_d = optim.Adam(D.parameters(), lr=HP.init_lr, betas=(HP.beta, 0.999))

    # NOTE: checkpoints do not store the global step, so `step` (used for
    # TensorBoard and checkpoint names) restarts at 0 when resuming.
    start_epoch, step = 0, 0
    if args.c:
        paths = args.c.split('~')
        if len(paths) < 2:
            parser.error('--c expects "<generator_checkpoint>~<discriminator_checkpoint>"')
        start_epoch_gc = _load_checkpoint(paths[0], G, optimizer_g)
        start_epoch_dc = _load_checkpoint(paths[1], D, optimizer_d)
        # Resume from the older checkpoint so neither network skips an epoch.
        start_epoch = min(start_epoch_gc, start_epoch_dc)
        print('Resume Training From Epoch: %d' % start_epoch)
    else:
        print('Training From Scratch!')

    # torch.save does not create directories; the first save happens at step 0.
    os.makedirs('model_save', exist_ok=True)

    G.train()
    D.train()
    # Fixed noise so the logged preview grids are comparable across steps.
    fixed_latent_z = torch.randn(size=(64, HP.z_dim), device=HP.device)

    for epoch in range(start_epoch, HP.epochs):
        print('Start Epoch: %d, Steps: %d' % (epoch, len(face_loader)))
        for batch, _ in face_loader:
            b_size = batch.size(0)

            # ---- Discriminator step (smoothed targets: real 0.9, fake 0.1) ----
            optimizer_d.zero_grad()
            labels_gt = torch.full(size=(b_size,), fill_value=0.9, dtype=torch.float, device=HP.device)
            # view(-1) rather than squeeze(): squeeze() turns a final batch of
            # size 1 into a 0-d tensor, which BCELoss rejects against (1,) labels.
            predict_labels_gt = D(batch.to(HP.device)).view(-1)
            loss_d_of_gt = criterion(predict_labels_gt, labels_gt)

            labels_fake = torch.full(size=(b_size,), fill_value=0.1, dtype=torch.float, device=HP.device)
            latent_z = torch.randn(size=(b_size, HP.z_dim), device=HP.device)
            # detach(): the discriminator update must not backpropagate into G.
            predict_labels_fake = D(G(latent_z).detach()).view(-1)
            loss_d_of_fake = criterion(predict_labels_fake, labels_fake)

            loss_D = loss_d_of_gt + loss_d_of_fake
            loss_D.backward()
            optimizer_d.step()
            logger.add_scalar('Loss/Discriminator', loss_D.item(), step)

            # ---- Generator step: push D(G(z)) towards the "real" target ----
            optimizer_g.zero_grad()
            latent_z = torch.randn(size=(b_size, HP.z_dim), device=HP.device)
            labels_for_g = torch.full(size=(b_size,), fill_value=0.9, dtype=torch.float, device=HP.device)
            predict_labels_from_g = D(G(latent_z)).view(-1)
            loss_G = criterion(predict_labels_from_g, labels_for_g)
            loss_G.backward()
            optimizer_g.step()
            logger.add_scalar('Loss/Generator', loss_G.item(), step)

            if not step % HP.verbose_step:
                # eval(): the preview uses BatchNorm running statistics and does
                # not update them with the fixed latent batch.
                G.eval()
                with torch.no_grad():
                    fake_image_dev = G(fixed_latent_z)
                G.train()
                logger.add_image('Generator Faces', invTrans(vutils.make_grid(fake_image_dev.cpu(), nrow=8)), step)

            if not step % HP.save_step:
                model_path = 'model_g_%d_%d.pth' % (epoch, step)
                save_checkpoint(G, epoch, optimizer_g, os.path.join('model_save', model_path))
                model_path = 'model_d_%d_%d.pth' % (epoch, step)
                save_checkpoint(D, epoch, optimizer_d, os.path.join('model_save', model_path))

            step += 1

        logger.flush()
        print('Epoch: [%d/%d], step: %d G loss: %.3f, D loss %.3f' %
              (epoch, HP.epochs, step, loss_G.item(), loss_D.item()))
    logger.close()


if __name__ == '__main__':
    train()
import torch
from Gface.log.dataset_face import invTrans
from Gface.log.generator import Generator
from Gface.log.config import HP
import matplotlib.pyplot as plt
import torchvision.utils as vutils
# Load a trained generator checkpoint (tensors mapped to CPU first, then moved).
G = Generator()
checkpoint = torch.load('./model_save/model_g_71_225000.pth', map_location='cpu')
G.load_state_dict(checkpoint['model_state_dict'])
G.to(HP.device)
# eval(): BatchNorm uses the running statistics accumulated during training.
G.eval()

# Show a fresh grid of HP.batch_size faces each time Enter is pressed (Ctrl+C to quit).
while True:
    latent_z = torch.randn(size=(HP.batch_size, HP.z_dim), device=HP.device)
    # no_grad: sampling needs no autograd graph, and a grad-tracking tensor cannot
    # be converted to NumPy by matplotlib. .cpu() in case HP.device is a GPU.
    with torch.no_grad():
        fake_faces = G(latent_z).cpu()
    grid = vutils.make_grid(fake_faces, nrow=8)
    # Map [-1, 1] back to [0, 1] and move channels last for imshow.
    plt.imshow(invTrans(grid).permute(1, 2, 0))
    plt.show()
    input()
- 至此,我们完成了生成器和判别器的训练,并实现了人脸图像的生成。