GAN 介绍
自2014年第一篇关于GAN的论文提出,几年来,GAN的火爆程度不断增长,被广泛应用于视频、文本、图片的数据生成,取得了惊人的效果。下面这张图可以直观展示GAN可以干什么。
GAN模型原理概述
先上一张很经典的图
每个GAN包括两个模型
- Generator生成器。目的在于通过给定的数据(输入)进行数据生成(期望的数据),类似于在生产假币。
- Discriminator判别器。目的在于判断输入数据究竟是Generator生成的(假币)还是属于原始数据集(真币)。
GAN的训练就是这两个模型之间的对抗:Generator总想生成足以乱真的假币,Discriminator则要准确判断出到底是真币还是假币。两者在不断对抗中提升自己,这也是生成对抗网络名字的由来。
GAN数学原理
注意
GAN可以用到任何包含Generator和Discriminator的模型中,神经网络只是其中之一,这里只是采用神经网络来说明数学原理。
两个网络
- Generator: G(z, theta1),z是网络输入,theta1是生成器对应网络的参数
- Discriminator:D(x,theta2) x是网络输入,theta2是判别器对应网络的参数,输出为概率,区间(0,1)
优化目的
根据上面原理概述所讲,
- 生成器对应网络优化的目的:提高D(G(z))
- 判别器对应网络优化的目的:提高D(x) ,降低D(G(z))
代价函数:
这里使用log的原因是,对数可以使得在使用梯度优化时,网络在分类错误的情况下获得很大的“惩罚”。参见下图,
GAN的训练
如果你对MBGD不太了解,参见这个链接
这里直接把论文中的图拿了过来
- 取出m个输入参数,{z1,......,zm}用于生成器生成数据
- 取出m个真实样本
- 更新判别器
- 将1-4重复k次
- 继续取出m个输入参数,{z1,......,zm}用于生成器生成数据
- 更新生成器
- 从1重复
编写一个GAN
1. 导入数据
数据经过预处理,转换成三个channel,然后归一化到(-1,1)
def minist_data():
    """Return the MNIST training set, with each grayscale image repeated to
    3 channels and pixel values normalized from [0, 1] into [-1, 1]."""
    pipeline = [
        transforms.ToTensor(),
        # Duplicate the single grayscale channel three times -> 3-channel image.
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return datasets.MNIST(root='./dataset', train=True,
                          transform=transforms.Compose(pipeline), download=True)
# Load the MNIST training data (downloads on first run).
data = minist_data()
# Create a loader so we can iterate over shuffled mini-batches of 100.
data_loader = torch.utils.data.DataLoader(data,batch_size=100,shuffle=True)
2.构造判别器
输入是一个图像(28*28*3,展平后为2352维向量),输出是一个概率,输出层激活函数是sigmoid
# Number of mini-batches per epoch (used by the logger's global step counter).
num_batches = len(data_loader)
class DiscriminatorNet(torch.nn.Module):
    """Discriminator MLP.

    Maps a flattened 28*28*3 image (2352 features) to a single probability
    in (0, 1) that the input came from the real dataset.
    """

    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        in_features = 784 * 3
        out_features = 1
        # Two LeakyReLU + Dropout hidden layers, then a sigmoid output layer.
        self.hidden0 = nn.Sequential(
            nn.Linear(in_features, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3))
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3))
        self.out = nn.Sequential(
            nn.Linear(512, out_features), nn.Sigmoid())

    def forward(self, x):
        """Return (batch, 1) probabilities for a (batch, 2352) input."""
        for stage in (self.hidden0, self.hidden1, self.out):
            x = stage(x)
        return x
# Instantiate the discriminator ("discrimator" is a typo, kept because all later code uses it).
discrimator = DiscriminatorNet()
3 写两个函数,进行图片和flatten images的转换,之后会用
def images_to_vectors(images):
    """Flatten a (batch, 3, 28, 28) image tensor into (batch, 2352) vectors."""
    batch_size = images.size(0)
    return images.view(batch_size, 784 * 3)
def vectors_to_images(vectors):
    """Reshape (batch, 2352) vectors back into (batch, 3, 28, 28) images."""
    batch_size = vectors.size(0)
    return vectors.view(batch_size, 3, 28, 28)
4 构造生成器以及生成器输入
这里输入的数据为1*100的维度
class GeneratorNet(torch.nn.Module):
    """Generator MLP.

    Maps a 100-dimensional noise vector to a flattened 28*28*3 image whose
    values lie in (-1, 1), matching the normalized training images.
    """

    def __init__(self):
        super(GeneratorNet, self).__init__()
        latent_dim = 100
        image_dim = 784 * 3
        # Three LeakyReLU hidden layers of increasing width, tanh output.
        self.hidden0 = nn.Sequential(nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2))
        self.hidden1 = nn.Sequential(nn.Linear(256, 512), nn.LeakyReLU(0.2))
        self.hidden2 = nn.Sequential(nn.Linear(512, 1024), nn.LeakyReLU(0.2))
        self.out = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

    def forward(self, x):
        """Return (batch, 2352) image vectors for a (batch, 100) noise input."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
# Instantiate the generator.
generator = GeneratorNet()
def noise(size):
    """Return `size` latent vectors sampled from N(0, 1), shape (size, 100)."""
    return Variable(torch.randn(size, 100))
5 构造优化器和损失函数
# Adam optimizers with learning rate 0.0002 for both networks.
d_optimizer = optim.Adam(discrimator.parameters(),lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(),lr=0.0002)
# Binary cross-entropy loss; both GAN loss terms below are built from it.
loss = nn.BCELoss()
6 构造两个target
def ones_target(size):
    """All-ones target of shape (size, 1): the BCE label for 'real' samples."""
    return Variable(torch.ones(size, 1))
def zeros_target(size):
    """All-zeros target of shape (size, 1): the BCE label for 'fake' samples."""
    return Variable(torch.zeros(size, 1))
因为我们知道判别器的损失函数是
而我们用的损失函数是BCE也就是二分类交叉熵
这里我们将判别器的损失函数看成两个部分,第一部分就是加号前面的,作为Wi=1,vi = D(Xi),yi=1时的BCE,后面部分作为Wi=1,vi=D(G(Zi)),yi=0时的BCE,所以实际的loss看成是两个部分loss的和
7 训练分类器的算法
def train_discriminator(optimizer, real_data, fake_data):
    """One discriminator update: push D(real) toward 1 and D(fake) toward 0.

    Returns (total_error, predictions_on_real, predictions_on_fake).
    """
    n = real_data.size(0)
    optimizer.zero_grad()
    # Real samples should be classified as 1.
    prediction_real = discrimator(real_data)
    error_real = loss(prediction_real, ones_target(n))
    error_real.backward()
    # Generated samples should be classified as 0.
    prediction_fake = discrimator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(n))
    error_fake.backward()
    # Gradients from both halves accumulate before the single step.
    optimizer.step()
    return error_real + error_fake, prediction_real, prediction_fake
8 训练生成器的算法
def train_genreator(optimizer, fake_data):
    """One generator update: push D(fake) toward 1, i.e. fool the discriminator.

    (The misspelled name is kept: the training loop calls it by this name.)
    """
    n = fake_data.size(0)
    optimizer.zero_grad()
    prediction = discrimator(fake_data)
    # The generator wants its samples labeled as real (target 1).
    error = loss(prediction, ones_target(n))
    error.backward()
    optimizer.step()
    return error
9 开始训练
里面用到Logger类是一个自定义的可视化类,具体代码参见文末,直接复制到同级代码即可
# Fixed noise batch: reused every epoch so generated samples are comparable over time.
num_test_samples = 16  # used below to visualize training progress
test_noise = noise(num_test_samples)
num_epoches = 200
logger = Logger(model_name="VGAN",data_name="MNIST")
for epoch in range(num_epoches):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        N = real_batch.size(0)
        real_data = Variable(images_to_vectors(real_batch))
        # detach(): keep discriminator gradients from flowing into the generator.
        fake_data = generator(noise(N)).detach()
        d_error,d_pred_real,d_pred_fake = \
            train_discriminator(d_optimizer,real_data,fake_data)
        # Fresh, non-detached fake batch so generator gradients can flow.
        fake_data = generator(noise(N))
        g_error = train_genreator(g_optimizer,fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display progress every 100 batches
        if (n_batch) % 100 == 0:
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            );
            # Display status logs
            logger.display_status(
                epoch, num_epoches, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )
10 Logger代码,文件名utils.py
import os
import numpy as np
import errno
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
import torch
'''
TensorBoard Data will be stored in './runs' path
'''
class Logger:
    """TensorBoard + matplotlib logging helper for GAN training runs."""

    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name
        # Run tag used in TensorBoard scalar/image names.
        self.comment = '{}_{}'.format(model_name, data_name)
        # Sub-directory for images/models saved to disk.
        self.data_subdir = '{}/{}'.format(model_name, data_name)
        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):
        """Record discriminator/generator losses as TensorBoard scalars."""
        # var_class = torch.autograd.variable.Variable
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()
        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(
            '{}/D_error'.format(self.comment), d_error, step)
        self.writer.add_scalar(
            '{}/G_error'.format(self.comment), g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        '''
        Write image grids to TensorBoard and save them to disk.

        Input images are expected in format (NCHW); pass format='NHWC' to
        have them transposed first.
        '''
        if type(images) == np.ndarray:
            images = torch.from_numpy(images)
        if format == 'NHWC':
            images = images.transpose(1, 3)
        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images{}'.format(self.comment, '')
        # Make horizontal grid from image tensor
        horizontal_grid = vutils.make_grid(
            images, normalize=normalize, scale_each=True)
        # Make roughly square (nrows x nrows) grid from image tensor
        nrows = int(np.sqrt(num_images))
        grid = vutils.make_grid(
            images, nrow=nrows, normalize=True, scale_each=True)
        # Add horizontal images to tensorboard
        self.writer.add_image(img_name, horizontal_grid, step)
        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        """Render both grids with matplotlib and save them as PNG files."""
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        # Plot and save horizontal
        fig = plt.figure(figsize=(16, 16))
        # moveaxis: CHW -> HWC, the layout matplotlib's imshow expects.
        plt.imshow(np.moveaxis(horizontal_grid.numpy(), 0, -1))
        plt.axis('off')
        if plot_horizontal:
            # NOTE(review): plt.show() takes no figure argument in current
            # matplotlib releases — confirm this call behaves as intended.
            plt.show(plt.gcf())
        self._save_images(fig, epoch, n_batch, 'hori')
        plt.close()
        # Save squared
        fig = plt.figure()
        plt.imshow(np.moveaxis(grid.numpy(), 0, -1))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close()

    def _save_images(self, fig, epoch, n_batch, comment=''):
        """Save a matplotlib figure under ./data/images/<model>/<data>/."""
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir,
                    comment, epoch, n_batch))

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):
        """Print a human-readable progress summary to stdout."""
        # var_class = torch.autograd.variable.Variable
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()
        if isinstance(d_pred_real, torch.autograd.Variable):
            d_pred_real = d_pred_real.data
        if isinstance(d_pred_fake, torch.autograd.Variable):
            d_pred_fake = d_pred_fake.data
        print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(
            epoch, num_epochs, n_batch, num_batches)
        )
        print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))
        print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))

    def save_models(self, generator, discriminator, epoch):
        """Checkpoint both networks' state dicts under ./data/models/."""
        out_dir = './data/models/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        torch.save(generator.state_dict(),
                   '{}/G_epoch_{}'.format(out_dir, epoch))
        torch.save(discriminator.state_dict(),
                   '{}/D_epoch_{}'.format(out_dir, epoch))

    def close(self):
        """Flush and close the TensorBoard writer."""
        self.writer.close()

    # Private Functionality

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        # Global step index: a single monotonically increasing counter across epochs.
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        # Create the directory if missing; ignore "already exists", re-raise anything else.
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
11 完整代码
import torch
from torch import nn,optim
import torchvision
from torch.autograd.variable import Variable
from torchvision import transforms,datasets
from utils import Logger
def minist_data():
    """Return the MNIST training set: grayscale images repeated to 3 channels
    and normalized from [0, 1] into [-1, 1]."""
    compose = transforms.Compose(
        [
            transforms.ToTensor(),
            # Duplicate the single grayscale channel three times -> 3 channels.
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ]
    )
    output_dir = './dataset'
    return datasets.MNIST(root=output_dir,train=True,transform=compose,download=True)
# Load the MNIST training data (downloads on first run).
data = minist_data()
# Create a loader so we can iterate over shuffled mini-batches of 100.
data_loader = torch.utils.data.DataLoader(data,batch_size=100,shuffle=True)
# Number of mini-batches per epoch (used by the logger's global step counter).
num_batches = len(data_loader)
class DiscriminatorNet(torch.nn.Module):
    """Discriminator MLP: flattened 28*28*3 image -> probability the input is real.

    Two LeakyReLU+Dropout hidden layers followed by a sigmoid output layer.
    """
    def __init__(self):
        super(DiscriminatorNet,self).__init__()
        n_features = 784*3  # flattened 3x28x28 input image
        n_out = 1           # single probability output
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024,512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            nn.Linear(512,n_out),
            nn.Sigmoid()  # squash the logit into (0, 1)
        )
    def forward(self, x):
        """Map a (batch, 2352) tensor to (batch, 1) probabilities."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.out(x)
        return x
# Instantiate the discriminator ("discrimator" is a typo, kept because all later code uses it).
discrimator = DiscriminatorNet()
def images_to_vectors(images):
    """Flatten a (batch, 3, 28, 28) image tensor into (batch, 2352) vectors."""
    return images.view(images.size(0),784*3)
def vectors_to_images(vectors):
    """Reshape (batch, 2352) vectors back into (batch, 3, 28, 28) images."""
    return vectors.view(vectors.size(0),3,28,28)
class GeneratorNet(torch.nn.Module):
    """Generator MLP: 100-dim noise -> flattened 28*28*3 image in (-1, 1).

    Three LeakyReLU hidden layers of increasing width, tanh output to match
    the normalized training images.
    """
    def __init__(self):
        super(GeneratorNet,self).__init__()
        n_features = 100    # latent (noise) dimension
        n_out = 784*3       # flattened 3x28x28 output image
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,256),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256,512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512,1024),
            nn.LeakyReLU(0.2)
        )
        self.out = nn.Sequential(
            nn.Linear(1024,n_out),
            nn.Tanh()  # outputs in (-1, 1), same range as normalized data
        )
    def forward(self,x):
        """Map a (batch, 100) noise tensor to (batch, 2352) image vectors."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
# Instantiate the generator.
generator = GeneratorNet()
def noise(size):
    """Return `size` latent vectors sampled from N(0, 1), shape (size, 100)."""
    n= Variable(torch.randn(size,100))
    return n
# Adam optimizers with learning rate 0.0002 for both networks.
d_optimizer = optim.Adam(discrimator.parameters(),lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(),lr=0.0002)
# Binary cross-entropy loss; both GAN loss terms below are built from it.
loss = nn.BCELoss()
def ones_target(size):
    """All-ones target of shape (size, 1): the BCE label for 'real' samples."""
    data = Variable(torch.ones(size,1))
    return data
def zeros_target(size):
    """All-zeros target of shape (size, 1): the BCE label for 'fake' samples."""
    data = Variable(torch.zeros(size,1))
    return data
def train_discriminator(optimizer,real_data,fake_data):
    """One discriminator update: push D(real) toward 1 and D(fake) toward 0.

    Returns (total_error, predictions_on_real, predictions_on_fake).
    """
    N = real_data.size(0)
    optimizer.zero_grad()
    # Real samples should be classified as 1.
    prediction_real = discrimator(real_data)
    error_real = loss(prediction_real,ones_target(N))
    error_real.backward()
    # Generated samples should be classified as 0.
    prediction_fake = discrimator(fake_data)
    error_fake = loss(prediction_fake,zeros_target(N))
    error_fake.backward()
    # Gradients from both halves accumulate before the single step.
    optimizer.step()
    return error_fake+error_real,prediction_real,prediction_fake
def train_genreator(optimizer,fake_data):
    """One generator update: push D(fake) toward 1, i.e. fool the discriminator."""
    N = fake_data.size(0)
    optimizer.zero_grad()
    predicton = discrimator(fake_data)
    # The generator wants its samples labeled as real (target 1).
    error = loss(predicton,ones_target(N))
    error.backward()
    optimizer.step()
    return error
# Fixed noise batch: reused every epoch so generated samples are comparable over time.
num_test_samples = 16
test_noise = noise(num_test_samples)
num_epoches = 200
logger = Logger(model_name="VGAN",data_name="MNIST")
for epoch in range(num_epoches):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        N = real_batch.size(0)
        real_data = Variable(images_to_vectors(real_batch))
        # detach(): keep discriminator gradients from flowing into the generator.
        fake_data = generator(noise(N)).detach()
        d_error,d_pred_real,d_pred_fake = \
            train_discriminator(d_optimizer,real_data,fake_data)
        # Fresh, non-detached fake batch so generator gradients can flow.
        fake_data = generator(noise(N))
        g_error = train_genreator(g_optimizer,fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display progress every 100 batches
        if (n_batch) % 100 == 0:
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            );
            # Display status logs
            logger.display_status(
                epoch, num_epoches, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )
12 训练过程可视化图
最开始
之后,不断改进
最后的结果
生成器错误率随时间变换图
判别器错误率随时间变换图
可以看到,最后判别器的错误率较高,生成器错误率较低。