Learning GAN Basics and Writing a Simple Model


Introduction to GANs

Since the first GAN paper was published in 2014, interest in GANs has grown rapidly. They have been widely applied to generating video, text, and image data, with striking results. The image below gives an intuitive sense of what GANs can do.


GAN stands for Generative Adversarial Network; the name comes from how the algorithm works (explained below). Within the machine learning family, GANs belong to the generative-model branch of unsupervised learning. The goal of a generative model is to learn the underlying distribution of the training data in order to generate new data. This differs from supervised learning, where labels are given (for example, in a classification problem the classes are specified in advance, and each class is a label), so a supervised model learns a conditional probability distribution, i.e. the probability of a label given the data.

A website that collects many interesting GAN applications


Overview of the GAN Model

First, a classic diagram of the GAN architecture:



Every GAN consists of two models:

  • Generator: generates data (the desired output) from a given input, analogous to a counterfeiter producing fake banknotes.
  • Discriminator: judges whether its input was produced by the Generator (a fake note) or comes from the original dataset (a real note).

Training a GAN is a contest between these two: the Generator tries to produce counterfeits good enough to pass as real, while the Discriminator tries to tell real from fake accurately. Each improves through this ongoing contest, which is where the name "generative adversarial network" comes from.

Mathematical Principles of GANs

Note

The GAN framework applies to any pair of models that can serve as Generator and Discriminator; a neural network is just one such choice. Neural networks are used here only to illustrate the math.

The two networks

  • Generator: G(z, theta1), where z is the network input and theta1 are the generator network's parameters.
  • Discriminator: D(x, theta2), where x is the network input and theta2 are the discriminator network's parameters; the output is a probability in (0, 1).

Optimization objectives

Following the overview above:

  • The generator network's objective: increase D(G(z)).
  • The discriminator network's objective: increase D(x) and decrease D(G(z)).

Cost function:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

The log is used because, under gradient-based optimization, it imposes a very large "penalty" on the network when it classifies incorrectly. See the figure below.
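To make the effect of the log concrete, here is a quick numeric check (plain Python; the probability values are arbitrary):

```python
import math

# -log(p) is the penalty when the model assigns probability p to the
# correct answer: near-certain correct answers cost almost nothing,
# while confident mistakes are punished steeply.
for p in (0.99, 0.9, 0.5, 0.1, 0.01):
    print(f"p = {p:<5} penalty = {-math.log(p):.3f}")
```

Note how the penalty explodes as p approaches 0, which is exactly the strong gradient signal we want on misclassifications.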


Training a GAN

If you are not familiar with MBGD (mini-batch gradient descent), see this link

Here is the algorithm figure, taken directly from the paper:


In words:

  1. Sample m inputs {z1,......,zm} for the generator to generate data from
  2. Sample m real examples
  3. Update the discriminator
  4. Repeat steps 1-3 k times
  5. Sample another m inputs {z1,......,zm} for the generator
  6. Update the generator
  7. Go back to step 1
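The steps above can be sketched in PyTorch. The toy networks, "real" data distribution, and the value of k below are hypothetical stand-ins; the point is only the k-to-1 update schedule:

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

k, m = 3, 8                                       # k discriminator steps per generator step
G = nn.Linear(2, 2)                               # toy generator
D = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # toy discriminator
opt_d = optim.SGD(D.parameters(), lr=0.1)
opt_g = optim.SGD(G.parameters(), lr=0.1)
bce = nn.BCELoss()

d_steps = g_steps = 0
for _ in range(5):                      # outer training iterations
    for _ in range(k):                  # steps 1-3, repeated k times
        z = torch.randn(m, 2)           # 1. noise for the generator
        x = torch.randn(m, 2) + 4.0     # 2. "real" samples (a shifted Gaussian)
        opt_d.zero_grad()
        d_loss = bce(D(x), torch.ones(m, 1)) + \
                 bce(D(G(z).detach()), torch.zeros(m, 1))
        d_loss.backward()               # 3. update the discriminator only
        opt_d.step()
        d_steps += 1
    z = torch.randn(m, 2)               # 5. fresh noise
    opt_g.zero_grad()
    g_loss = bce(D(G(z)), torch.ones(m, 1))
    g_loss.backward()                   # 6. update the generator
    opt_g.step()
    g_steps += 1

print(d_steps, g_steps)                 # 15 discriminator steps, 5 generator steps
```

The `.detach()` on the fake data in the discriminator step keeps gradients from flowing back into the generator during steps 1-3, mirroring the separation in the algorithm.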

Writing a GAN

1. Load the data

The data is preprocessed: converted to three channels, then normalized to (-1, 1)

def minist_data():
    compose = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ]
    )
    output_dir = './dataset'
    return datasets.MNIST(root=output_dir,train=True,transform=compose,download=True)

#load data
data = minist_data()

#create loader with data,so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data,batch_size=100,shuffle=True)

2. Build the discriminator

The input is an image (28*28*3) and the output is a probability; the output layer's activation is a sigmoid

#num_batches
num_batches = len(data_loader)

class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self):
        super(DiscriminatorNet,self).__init__()
        n_features = 784*3
        n_out = 1

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(1024,512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.out = nn.Sequential(
            nn.Linear(512,n_out),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.out(x)
        return x

discrimator = DiscriminatorNet()

3. Two helper functions that convert between images and flattened image vectors; they will be used later

def images_to_vectors(images):
    return images.view(images.size(0),784*3)
def vectors_to_images(vectors):
    return vectors.view(vectors.size(0),3,28,28)

4. Build the generator and its input

The generator's input is a 1*100-dimensional vector

class GeneratorNet(torch.nn.Module):
    def __init__(self):
        super(GeneratorNet,self).__init__()
        n_features = 100
        n_out = 784*3

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,256),
            nn.LeakyReLU(0.2),
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(256,512),
            nn.LeakyReLU(0.2)
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(512,1024),
            nn.LeakyReLU(0.2)
        )

        self.out = nn.Sequential(
            nn.Linear(1024,n_out),
            nn.Tanh()
        )

    def forward(self,x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

generator = GeneratorNet()

def noise(size):
    n= Variable(torch.randn(size,100))
    return n

5. Build the optimizers and the loss function

d_optimizer = optim.Adam(discrimator.parameters(),lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(),lr=0.0002)

loss = nn.BCELoss()

6. Build the two targets (real and fake labels)

def ones_target(size):
    data = Variable(torch.ones(size,1))
    return data

def zeros_target(size):
    data = Variable(torch.zeros(size,1))
    return data

Recall that the discriminator's loss function is

$$-\big[\log D(x) + \log(1 - D(G(z)))\big]$$

while the loss we actually use is BCE, i.e. binary cross-entropy:

$$\ell_i = -w_i\big[y_i \log v_i + (1 - y_i)\log(1 - v_i)\big]$$

So we can view the discriminator's loss as two parts: the part before the plus sign is the BCE with w_i = 1, v_i = D(x_i), y_i = 1, and the part after it is the BCE with w_i = 1, v_i = D(G(z_i)), y_i = 0. The actual loss is the sum of these two BCE terms.
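A quick numeric check of this decomposition (the two probabilities below are arbitrary made-up values for D(x) and D(G(z))):

```python
import math
import torch
from torch import nn

bce = nn.BCELoss()
d_x = torch.tensor([[0.9]])    # hypothetical D(x) for a real sample
d_gz = torch.tensor([[0.2]])   # hypothetical D(G(z)) for a fake sample

real_term = bce(d_x, torch.ones(1, 1))    # target y=1: reduces to -log D(x)
fake_term = bce(d_gz, torch.zeros(1, 1))  # target y=0: reduces to -log(1 - D(G(z)))

# the two BCE terms match the formula above
assert abs(real_term.item() - (-math.log(0.9))) < 1e-5
assert abs(fake_term.item() - (-math.log(1 - 0.2))) < 1e-5
```

This is why `train_discriminator` below can simply sum the loss against `ones_target` on real data and against `zeros_target` on fake data.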

7. The discriminator training routine

def train_discriminator(optimizer,real_data,fake_data):
    N = real_data.size(0)

    optimizer.zero_grad()

    prediction_real = discrimator(real_data)
    error_real = loss(prediction_real,ones_target(N))
    error_real.backward()

    prediction_fake = discrimator(fake_data)
    error_fake = loss(prediction_fake,zeros_target(N))
    error_fake.backward()

    optimizer.step()

    return error_fake+error_real,prediction_real,prediction_fake

8. The generator training routine

def train_genreator(optimizer,fake_data):
    N = fake_data.size(0)

    optimizer.zero_grad()
    predicton = discrimator(fake_data)
    error = loss(predicton,ones_target(N))
    error.backward()
    optimizer.step()
    return error

9. Run the training

The Logger class used below is a custom visualization helper; its full code appears at the end of this post. Just copy it into a file in the same directory.

num_test_samples = 16  # used later to visually monitor training progress
test_noise = noise(num_test_samples)

num_epoches = 200
logger = Logger(model_name="VGAN",data_name="MNIST")

for epoch in range(num_epoches):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        N = real_batch.size(0)
        real_data = Variable(images_to_vectors(real_batch))
        fake_data = generator(noise(N)).detach()

        d_error,d_pred_real,d_pred_fake = \
            train_discriminator(d_optimizer,real_data,fake_data)

        fake_data = generator(noise(N))
        g_error = train_genreator(g_optimizer,fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display Progress every few batches
        if (n_batch) % 100 == 0:
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            );
            # Display status Logs
            logger.display_status(
                epoch, num_epoches, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

10. Logger code (save as utils.py)

import os
import numpy as np
import errno
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
import torch

'''
    TensorBoard Data will be stored in './runs' path
'''


class Logger:

    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):

        # var_class = torch.autograd.variable.Variable
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()

        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(
            '{}/D_error'.format(self.comment), d_error, step)
        self.writer.add_scalar(
            '{}/G_error'.format(self.comment), g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        '''
        input images are expected in format (NCHW)
        '''
        if type(images) == np.ndarray:
            images = torch.from_numpy(images)

        if format == 'NHWC':
            images = images.transpose(1, 3)

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images{}'.format(self.comment, '')

        # Make horizontal grid from image tensor
        horizontal_grid = vutils.make_grid(
            images, normalize=normalize, scale_each=True)
        # Make vertical grid from image tensor
        nrows = int(np.sqrt(num_images))
        grid = vutils.make_grid(
            images, nrow=nrows, normalize=True, scale_each=True)

        # Add horizontal images to tensorboard
        self.writer.add_image(img_name, horizontal_grid, step)

        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)

        # Plot and save horizontal
        fig = plt.figure(figsize=(16, 16))
        plt.imshow(np.moveaxis(horizontal_grid.numpy(), 0, -1))
        plt.axis('off')
        if plot_horizontal:
            plt.show(plt.gcf())
        self._save_images(fig, epoch, n_batch, 'hori')
        plt.close()

        # Save squared
        fig = plt.figure()
        plt.imshow(np.moveaxis(grid.numpy(), 0, -1))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close()

    def _save_images(self, fig, epoch, n_batch, comment=''):
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir,
                                                         comment, epoch, n_batch))

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):

        # var_class = torch.autograd.variable.Variable
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()
        if isinstance(d_pred_real, torch.autograd.Variable):
            d_pred_real = d_pred_real.data
        if isinstance(d_pred_fake, torch.autograd.Variable):
            d_pred_fake = d_pred_fake.data

        print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(
            epoch, num_epochs, n_batch, num_batches)
        )
        print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))
        print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))

    def save_models(self, generator, discriminator, epoch):
        out_dir = './data/models/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        torch.save(generator.state_dict(),
                   '{}/G_epoch_{}'.format(out_dir, epoch))
        torch.save(discriminator.state_dict(),
                   '{}/D_epoch_{}'.format(out_dir, epoch))

    def close(self):
        self.writer.close()

    # Private Functionality

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

11. Complete code

import torch
from torch import nn,optim
import torchvision
from torch.autograd.variable import Variable
from torchvision import transforms,datasets
from utils import Logger

def minist_data():
    compose = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ]
    )
    output_dir = './dataset'
    return datasets.MNIST(root=output_dir,train=True,transform=compose,download=True)

#load data
data = minist_data()

#create loader with data,so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data,batch_size=100,shuffle=True)

#num_batches
num_batches = len(data_loader)

class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self):
        super(DiscriminatorNet,self).__init__()
        n_features = 784*3
        n_out = 1

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(1024,512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.out = nn.Sequential(
            nn.Linear(512,n_out),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.out(x)
        return x

discrimator = DiscriminatorNet()

def images_to_vectors(images):
    return images.view(images.size(0),784*3)
def vectors_to_images(vectors):
    return vectors.view(vectors.size(0),3,28,28)


class GeneratorNet(torch.nn.Module):
    def __init__(self):
        super(GeneratorNet,self).__init__()
        n_features = 100
        n_out = 784*3

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,256),
            nn.LeakyReLU(0.2),
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(256,512),
            nn.LeakyReLU(0.2)
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(512,1024),
            nn.LeakyReLU(0.2)
        )

        self.out = nn.Sequential(
            nn.Linear(1024,n_out),
            nn.Tanh()
        )

    def forward(self,x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

generator = GeneratorNet()

def noise(size):
    n= Variable(torch.randn(size,100))
    return n

d_optimizer = optim.Adam(discrimator.parameters(),lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(),lr=0.0002)

loss = nn.BCELoss()

def ones_target(size):
    data = Variable(torch.ones(size,1))
    return data

def zeros_target(size):
    data = Variable(torch.zeros(size,1))
    return data


def train_discriminator(optimizer,real_data,fake_data):
    N = real_data.size(0)

    optimizer.zero_grad()

    prediction_real = discrimator(real_data)
    error_real = loss(prediction_real,ones_target(N))
    error_real.backward()

    prediction_fake = discrimator(fake_data)
    error_fake = loss(prediction_fake,zeros_target(N))
    error_fake.backward()

    optimizer.step()

    return error_fake+error_real,prediction_real,prediction_fake


def train_genreator(optimizer,fake_data):
    N = fake_data.size(0)

    optimizer.zero_grad()
    predicton = discrimator(fake_data)
    error = loss(predicton,ones_target(N))
    error.backward()
    optimizer.step()
    return error

num_test_samples = 16
test_noise = noise(num_test_samples)

num_epoches = 200
logger = Logger(model_name="VGAN",data_name="MNIST")

for epoch in range(num_epoches):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        N = real_batch.size(0)
        real_data = Variable(images_to_vectors(real_batch))
        fake_data = generator(noise(N)).detach()

        d_error,d_pred_real,d_pred_fake = \
            train_discriminator(d_optimizer,real_data,fake_data)

        fake_data = generator(noise(N))
        g_error = train_genreator(g_optimizer,fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display Progress every few batches
        if (n_batch) % 100 == 0:
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            );
            # Display status Logs
            logger.display_status(
                epoch, num_epoches, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )






12. Training-process visualizations

At the very beginning:


Improving over time:


The final result:


Generator error over time:


Discriminator error over time:


As you can see, by the end the discriminator's error is relatively high while the generator's error is low.