pix2pix-GAN


Photo-to-cartoon image translation

import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt
import warnings
import glob
warnings.filterwarnings("ignore")
%matplotlib inline
print("TensorFlow Version: {}".format(tf.__version__))
TensorFlow Version: 1.8.0

Building the model

  • inputs (no explicit input function appears in the code below; see the placeholder sketch after this list)
  • generator
  • discriminator
  • loss
  • optimizer
  • train
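The notebook never defines a standalone inputs step; it wires queue-produced batch tensors straight into the model instead. For orientation only, a minimal placeholder-based sketch of what that step would look like (the name get_inputs and the fixed 256*256*3 shape are assumptions matching the shapes used below):

def get_inputs(image_size=256, channels=3):
    # hypothetical sketch: placeholders for paired photo/cartoon batches
    inputs_real = tf.placeholder(tf.float32, (None, image_size, image_size, channels), name='inputs_real')
    inputs_cartoon = tf.placeholder(tf.float32, (None, image_size, image_size, channels), name='inputs_cartoon')
    return inputs_real, inputs_cartoon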

Generator

def generator(inputs_real, is_train=True):
    # input: 256*256*3
    with tf.variable_scope("generator", reuse=(not is_train)):
        # 128*128*64
        conv1 = tf.layers.conv2d(inputs_real, 64, (3,3), padding='same')
        conv1 = tf.nn.relu(conv1)
        conv1 = tf.layers.max_pooling2d(conv1, (2,2), (2,2), padding='same')
        # 64*64*128
        conv2 = tf.layers.conv2d(conv1, 128, (3,3), padding='same')
        conv2 = tf.nn.relu(conv2)
        conv2 = tf.layers.max_pooling2d(conv2, (2,2), (2,2), padding='same')
        # 32*32*256
        conv3 = tf.layers.conv2d(conv2, 256, (3,3), padding='same')
        conv3 = tf.nn.relu(conv3)
        conv3 = tf.layers.max_pooling2d(conv3, (2,2), (2,2), padding='same')
        # 16*16*512
        conv4 = tf.layers.conv2d(conv3, 512, (3,3), padding='same')
        conv4 = tf.nn.relu(conv4)
        conv4 = tf.layers.max_pooling2d(conv4, (2,2), (2,2), padding='same')
        # 8*8*512
        conv5 = tf.layers.conv2d(conv4, 512, (3,3), padding='same')
        conv5 = tf.nn.relu(conv5)
        conv5 = tf.layers.max_pooling2d(conv5, (2,2), (2,2), padding='same')
        # 4*4*512
        conv6 = tf.layers.conv2d(conv5, 512, (3,3), padding='same')
        conv6 = tf.nn.relu(conv6)
        conv6 = tf.layers.max_pooling2d(conv6, (2,2), (2,2), padding='same')
        # 2*2*512
        conv7 = tf.layers.conv2d(conv6, 512, (3,3), padding='same')
        conv7 = tf.nn.relu(conv7)
        conv7 = tf.layers.max_pooling2d(conv7, (2,2), (2,2), padding='same')
        # 1*1*512
        conv8 = tf.layers.conv2d(conv7, 512, (3,3), padding='same')
        conv8 = tf.nn.relu(conv8)
        conv8 = tf.layers.max_pooling2d(conv8, (2,2), (2,2), padding='same')
        
        
        # 2*2*512
        conv9 = tf.layers.conv2d_transpose(conv8, 512, 3, strides=2, padding='same')
        conv9 = tf.layers.batch_normalization(conv9, training=is_train)
        conv9 = tf.nn.relu(conv9)
        conv9 = tf.nn.dropout(conv9, keep_prob=0.5)
        # 4*4*512
        conv10 = tf.concat([conv9,conv7], 3)
        conv10 = tf.layers.conv2d_transpose(conv10, 512, 3, strides=2, padding='same')
        conv10 = tf.layers.batch_normalization(conv10, training=is_train)
        conv10 = tf.nn.relu(conv10)
        conv10 = tf.nn.dropout(conv10, keep_prob=0.5)
        # 8*8*512
        conv11 = tf.concat([conv10,conv6], 3)
        conv11 = tf.layers.conv2d_transpose(conv11, 512, 3, strides=2, padding='same')
        conv11 = tf.layers.batch_normalization(conv11, training=is_train)
        conv11 = tf.nn.relu(conv11)
        conv11 = tf.nn.dropout(conv11, keep_prob=0.5)
        # 16*16*512
        conv12 = tf.concat([conv11,conv5], 3)
        conv12 = tf.layers.conv2d_transpose(conv12, 512, 3, strides=2, padding='same')
        conv12 = tf.layers.batch_normalization(conv12, training=is_train)
        conv12 = tf.nn.relu(conv12)
        # 32*32*256
        conv13 = tf.concat([conv12,conv4], 3)
        conv13 = tf.layers.conv2d_transpose(conv13, 256, 3, strides=2, padding='same')
        conv13 = tf.layers.batch_normalization(conv13, training=is_train)
        conv13 = tf.nn.relu(conv13)
        # 64*64*128
        conv14 = tf.concat([conv13,conv3], 3)
        conv14 = tf.layers.conv2d_transpose(conv14, 128, 3, strides=2, padding='same')
        conv14 = tf.layers.batch_normalization(conv14, training=is_train)
        conv14 = tf.nn.relu(conv14)
        # 128*128*64
        conv15 = tf.concat([conv14,conv2], 3)
        conv15 = tf.layers.conv2d_transpose(conv15, 64, 3, strides=2, padding='same')
        conv15 = tf.layers.batch_normalization(conv15, training=is_train)
        conv15 = tf.nn.relu(conv15)
        # 256*256*3
        conv16 = tf.concat([conv15,conv1], 3)
        conv16 = tf.layers.conv2d_transpose(conv16, 3, 3, strides=2, padding='same')
        conv16 = tf.layers.batch_normalization(conv16, training=is_train)
    
        # squash outputs to [-1, 1]
        outputs = tf.nn.tanh(conv16)
        
        return outputs
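Every skip connection concatenates an encoder feature map with a decoder feature map, so the two resolutions must line up exactly. A quick shape sanity check, as a sketch (throwaway graph; batch size 1 is an arbitrary choice):

# sanity-check sketch: the U-Net should map 256*256*3 back to 256*256*3
with tf.Graph().as_default():
    dummy = tf.placeholder(tf.float32, (1, 256, 256, 3))
    out = generator(dummy, is_train=True)
    print(out.get_shape())  # expect (1, 256, 256, 3)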

Discriminator

def discriminator(inputs_real, inputs_cartoon, reuse=False, alpha=0.01):
    
    with tf.variable_scope("discriminator", reuse=reuse):
        
        layer0 = tf.concat([inputs_real, inputs_cartoon], 3)

        layer1 = tf.layers.conv2d(layer0, 64, 3, strides=2, padding='same')
        layer1 = tf.maximum(alpha * layer1, layer1)
        
        layer2 = tf.layers.conv2d(layer1, 128, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha * layer2, layer2)
        
        layer3 = tf.layers.conv2d(layer2, 256, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha * layer3, layer3)
        
        layer4 = tf.layers.conv2d(layer3, 512, 3, strides=2, padding='same')
        layer4 = tf.layers.batch_normalization(layer4, training=True)
        layer4 = tf.maximum(alpha * layer4, layer4)
        
        flatten = tf.reshape(layer4, (-1, 16*16*512))
        logits = tf.layers.dense(flatten, 1)
        outputs = tf.sigmoid(logits)
        
        return logits, outputs
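Flattening layer4 into a single dense unit scores each image pair as a whole. The original pix2pix paper instead uses a PatchGAN head that emits one real/fake logit per spatial patch; a hedged sketch of that alternative (not what this notebook trains):

def patch_head(layer4):
    # PatchGAN-style sketch: a 1-channel conv yields a 16*16 grid of logits,
    # one per receptive-field patch, instead of a single logit per image
    logits = tf.layers.conv2d(layer4, 1, 3, strides=1, padding='same')
    outputs = tf.sigmoid(logits)
    return logits, outputs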

Loss

def get_loss(inputs_images, inputs_cartoons, smooth=0.1):

    g_outputs = generator(inputs_images, is_train=True)
    d_logits_real, d_outputs_real = discriminator(inputs_images, inputs_cartoons)
    d_logits_fake, d_outputs_fake = discriminator(inputs_images, g_outputs, reuse=True)
    
    # discriminator loss on real and generated pairs
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                         labels=tf.ones_like(d_outputs_real)*(1-smooth)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                         labels=tf.zeros_like(d_outputs_fake)))
    
    # generator adversarial loss
    g_loss_gan = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, 
                                                                     labels=tf.ones_like(d_outputs_fake)*(1-smooth)))
    
    # L1 reconstruction loss (an unweighted sum over every pixel in the batch)
    g_outputs_flat = tf.reshape(g_outputs, [-1, 256*256*3])
    inputs_cartoons_flat = tf.reshape(inputs_cartoons, [-1, 256*256*3])
    g_loss_l1 = tf.reduce_mean(tf.reduce_sum(tf.abs(g_outputs_flat - inputs_cartoons_flat)))
    

    # combine the loss terms
    g_loss = tf.add(g_loss_gan, g_loss_l1)
    d_loss = tf.add(d_loss_real, d_loss_fake)
    
    return g_loss, d_loss
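Note that g_loss_l1 sums absolute differences over every pixel in the batch with no weighting, which is why the generator losses printed during training sit in the hundreds of thousands while the discriminator losses stay near 1. The pix2pix paper instead averages the L1 term and scales it with a weight λ (100 in the paper); a sketch of that variant, assuming lambda_l1 = 100:

# paper-style weighted objective (sketch; lambda_l1 = 100 follows the paper)
lambda_l1 = 100
g_loss_l1 = tf.reduce_mean(tf.abs(g_outputs - inputs_cartoons))
g_loss = g_loss_gan + lambda_l1 * g_loss_l1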

Optimizer

def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001):
    
    train_vars = tf.trainable_variables()
    
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    
    # apply batch-norm moving-average updates (UPDATE_OPS) before each optimizer step
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    
    return g_opt, d_opt

Helper functions for displaying generated images during training

def plot_images(samples):
    # map the tanh output range [-1, 1] to [0, 1] for display
    samples = (samples + 1) / 2
    fig, axes = plt.subplots(nrows=1, ncols=5, sharex=True, sharey=True, figsize=(10,2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape((250, 200, 3)))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
def show_generator_output(sess, samp_images):
    # is_train=False makes the generator reuse the trained variables
    samples = sess.run(generator(samp_images, False))
    # crop/pad to the 250*200 size plot_images expects
    samples = sess.run(tf.image.resize_image_with_crop_or_pad(samples, 250, 200))
    return samples

Train

# hyperparameters
learning_rate = 0.001
beta1 = 0.4
def train():
    
    # store losses recorded during training
    losses = []
    steps = 300
        
    image_filenames = glob.glob('./training_photos/*.jpg')
    cartoon_filenames = glob.glob('./training_sketches/*.jpg')
    
    image_que = tf.train.slice_input_producer([image_filenames, cartoon_filenames], shuffle=True)
    
    image_ = tf.read_file(image_que[0])
    image = tf.image.decode_jpeg(image_, channels=3)
    image = tf.image.resize_image_with_crop_or_pad(image, 256, 256)
    new_img = tf.image.per_image_standardization(image)
        
    cartoon_ = tf.read_file(image_que[1])
    cartoon = tf.image.decode_jpeg(cartoon_, channels=3)
    cartoon = tf.image.resize_image_with_crop_or_pad(cartoon, 256, 256)
    new_cartoon = tf.image.per_image_standardization(cartoon)
    
    batch_size = 5
    capacity = 3 + 2 * batch_size
          
    image_batch, cartoon_batch = tf.train.batch([new_img, new_cartoon], batch_size=batch_size, capacity=capacity)
    
    g_loss, d_loss = get_loss(image_batch, cartoon_batch)
    g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)
    
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        
        # training loop
        for e in range(steps):
            # one optimizer step each for the generator and the discriminator
            _ = sess.run(g_train_opt)
            _ = sess.run(d_train_opt)
                
            if e % 50 == 0:
                saver.save(sess, './less8', global_step=e)
                train_loss_d = d_loss.eval()
                train_loss_g = g_loss.eval()
                losses.append((train_loss_d, train_loss_g))
                # display generated samples
                samples = show_generator_output(sess, image_batch)
                plot_images(samples)
                print("Epoch {}/{}....".format(e+1, steps),
                      "Discriminator Loss: {:.4f}....".format(train_loss_d),
                      "Generator Loss: {:.4f}....".format(train_loss_g))
        saver.save(sess, './less8', global_step=steps)
        coord.request_stop()
        coord.join(threads)                  
with tf.Graph().as_default():
    train()
Epoch 1/300.... Discriminator Loss: 1.4783.... Generator Loss: 853222.2500....
Epoch 51/300.... Discriminator Loss: 1.0385.... Generator Loss: 295154.4688....
Epoch 101/300.... Discriminator Loss: 2.4878.... Generator Loss: 266202.3750....
Epoch 151/300.... Discriminator Loss: 2.2916.... Generator Loss: 283186.2188....
Epoch 201/300.... Discriminator Loss: 1.1796.... Generator Loss: 271667.4062....
Epoch 251/300.... Discriminator Loss: 0.8088.... Generator Loss: 262577.5000....
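
train() appends a (d_loss, g_loss) pair to losses every 50 steps but never visualizes them. A minimal plotting sketch, assuming train() is modified to end with return losses:

# sketch: assumes `losses = train()` after adding `return losses` to train()
losses_arr = np.array(losses)
fig, ax = plt.subplots()
ax.plot(losses_arr[:, 0], label='discriminator')
ax.plot(losses_arr[:, 1], label='generator')
ax.set_xlabel('checkpoint (every 50 steps)')
ax.set_ylabel('loss')
ax.legend()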

Test

image_filenames = glob.glob('./testing_photos/*.jpg')
cartoon_filenames = glob.glob('./testing_sketches/*.jpg')
    
image_que = tf.train.slice_input_producer([image_filenames, cartoon_filenames], shuffle=True)
    
image_ = tf.read_file(image_que[0])
image = tf.image.decode_jpeg(image_, channels=3)
image = tf.image.resize_image_with_crop_or_pad(image, 256, 256)
new_img = tf.image.per_image_standardization(image)
        
cartoon_ = tf.read_file(image_que[1])
cartoon = tf.image.decode_jpeg(cartoon_, channels=3)
cartoon = tf.image.resize_image_with_crop_or_pad(cartoon, 256, 256)
new_cartoon = tf.image.per_image_standardization(cartoon)
    
batch_size = 5
capacity = 3 + 2 * batch_size
          
image_batch, cartoon_batch = tf.train.batch([new_img, new_cartoon], batch_size=batch_size, capacity=capacity)
    
g_loss, d_loss = get_loss(image_batch, cartoon_batch)
g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)
saver = tf.train.Saver()
sess = tf.Session()
model_file=tf.train.latest_checkpoint('./')
saver.restore(sess, model_file)
INFO:tensorflow:Restoring parameters from ./less8-300
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord,sess=sess)
for i in range(10):
    samples = show_generator_output(sess, image_batch)
    plot_images(samples)
coord.request_stop()
coord.join(threads)