上一次修改时间:2018-09-01 01:50:19

GAN的简单实现-MNIST数据集

说明

使用简单的二层神经网络以及MNIST数据集以及Tensorflow来实现GAN

python的版本为:3.5

tensorflow的版本为:1.10.0

1.加载数据集

import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
import os
import math
%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../data/mnist")
#显示数据集的一张图片
sample_image = mnist.train.next_batch(1)[0]
print(sample_image.shape)

sample_image = sample_image.reshape([28, 28])
#plt.imshow(sample_image, cmap='Greys')#显示灰度图
plt.imshow(sample_image)

image.png


2.生成器网络

生成器网络结构说明:生成器网络总共有两层(输入层不算在内);

输入层有100个神经元,中间层有128个神经元,输出层为784个神经元;

各层网络之间采用全连接的方式,因此权值的总量为100*128 + 128*784=113152,

偏置值的总量为128+784=912,参数总量为113152 + 912 = 114064个;

def xavier_init(n_input,n_output,constant=1):
    """
    Xavier初始化器 让权重被初始化调整合理的分布 mean=0 std=2/(n_input+n_output)
    :param n_input:输入节点数量
    :param n_output:输出节点数量
    """
    low=-constant * np.sqrt(6.0/(n_input+n_output))
    high=constant * np.sqrt(6.0/(n_input+n_output))
    return tf.random_uniform((n_input,n_output), minval=low, maxval=high, 
                             dtype=tf.float32)

# 为生成器生成随机噪声
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')

# 生成器参数设置
G_W1 = tf.Variable(xavier_init(100, 128), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(xavier_init(128, 784), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]#需要通过训练数据确定的四个参数向量

# 生成器网络
#这里的z是G(z)的先验。通过这种方式可以学习到先验空间和pdata(真实数据分布)空间之间的映射
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    
    return G_prob

3.判别器网络

判别器网络结构说明:判别器网络的输入为真实数据或生成器生成的数据;

网络层数也为两层,不算输入层,但输出层计算在内;连接方式也为全连接

输入层为784个神经元,中间层为128个神经元,输出层为1个神经元;

权值的总量为784*128 + 128*1 = 100480个,偏置值为128 + 1 = 129个;

参数总量为:100480 + 129 = 100609个;

# 为判别器准备的MNIST图像输入设置
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
# 判别器参数设置
D_W1 = tf.Variable(xavier_init(784, 128), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(xavier_init(128, 1), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name='D_b2')
theta_D = [D_W1, D_W2, D_b1, D_b2]

# 判别器网络
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2    
    D_prob = tf.nn.sigmoid(D_logit)
    
    return D_prob, D_logit

4.GAN训练

生成器生成的数据和真实数据经过判别器网络后会得到一个0-1之间的数,

损失函数则是用该数来定义的;

mb_size = 100
Z_dim = 100

G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

'''
TensorFlow中的优化器只能做最小化,因此,为了最大化损失函数,
我们在上面的代码中给损失加上了一个负号。与此同时,根据论文的伪代码算法,
我们最好最大化tf.reduce_mean(tf.log(D_fake))而不是最小化tf.reduce_mean(1-tf.log(D_fake))。
'''
# GAN原始论文中的损失函数
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

#损失函数训练
# 仅更新D(X)的参数,var_list=theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)

# 仅更新G(X)的参数,var_list=theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

def sample_Z(m, n): 
    '''Uniform prior for G(Z)'''
    
    return np.random.uniform(-1., 1., size=[m, n])

'''
我们以随机的噪声开始进行训练,G(Z)不断向pdata趋近。
可以通过观察G(Z)生成的样本和原始MNIST图像的区别来证明这件事。
'''
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

iterationNum = 100001
for it in range(iterationNum):
    X_mb, _ = mnist.train.next_batch(mb_size)    
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})    
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
    
    #生成迭代中间结果100个
    if it % (math.floor(iterationNum / 100)) == 0:
        print("迭代次数为:", it)
        z_batch = np.random.uniform(-1., 1., size=[1, Z_dim])
        generated_images = generator(Z)
        images = sess.run(generated_images, {Z: z_batch})
        plt.imshow(images[0].reshape([28, 28]))
        plt.show()
        
    #保存模型参数
    if it % (math.floor(iterationNum / 2)) == 0:
        checkpoint_path = os.path.join('./data', 'MINIST_GAN_model.ckpt')
        saver.save(sess, checkpoint_path, global_step = it+1)

以下为训练过程中的部分中间结果:

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

5.用训练好的模型生成图片

# 每行显示5张,总共生成5列
rows = 5
cols = 5
fig1 , ax1 = plt.subplots(rows ,cols ,figsize=(13 , 12))
 
# 标签字体
#fontdict =  {'fontsize': 20,'fontweight' : 6,'verticalalignment': 'baseline','horizontalalignment': 'center'}
#reshape函数中的28*28=784,为本数据的维度
for j in range(rows):
    for i in range(cols):
        z_batch = np.random.uniform(-1., 1., size=[1, Z_dim])
        generated_images = generator(Z)
        images = sess.run(generated_images, {Z: z_batch})
        ax1[j][i].imshow(images[0].reshape([28, 28]))
        ax1[j][i].axis('off')

image.png

6.使用保存的模型生成图片

上面的模型保存方法只保存了模型的参数,使用时需要重建网络的结构;

但这种只保存的参数的方法,可以随时载入保存的参数,并在此参数的基础上继续训练;

#生成器网络
def xavier_init(n_input,n_output,constant=1):
    """
    Xavier初始化器 让权重被初始化调整合理的分布 mean=0 std=2/(n_input+n_output)
    :param n_input:输入节点数量
    :param n_output:输出节点数量
    """
    low=-constant * np.sqrt(6.0/(n_input+n_output))
    high=constant * np.sqrt(6.0/(n_input+n_output))
    return tf.random_uniform((n_input,n_output), minval=low, maxval=high, 
                             dtype=tf.float32)

# 为生成器生成随机噪声
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')

# 生成器参数设置
G_W1 = tf.Variable(xavier_init(100, 128), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(xavier_init(128, 784), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]#需要通过训练数据确定的四个参数向量

# 生成器网络
#这里的z是G(z)的先验。通过这种方式可以学习到先验空间和pdata(真实数据分布)空间之间的映射
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    
    return G_prob



#判别器网络
# 为判别器准备的MNIST图像输入设置
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
# 判别器参数设置
D_W1 = tf.Variable(xavier_init(784, 128), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(xavier_init(128, 1), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name='D_b2')
theta_D = [D_W1, D_W2, D_b1, D_b2]

# 判别器网络
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2    
    D_prob = tf.nn.sigmoid(D_logit)
    
    return D_prob, D_logit



#训练相关
mb_size = 100
Z_dim = 100

G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

'''
TensorFlow中的优化器只能做最小化,因此,为了最大化损失函数,
我们在上面的代码中给损失加上了一个负号。与此同时,根据论文的伪代码算法,
我们最好最大化tf.reduce_mean(tf.log(D_fake))而不是最小化tf.reduce_mean(1-tf.log(D_fake))。
'''
# GAN原始论文中的损失函数
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

#损失函数训练
# 仅更新D(X)的参数,var_list=theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)

# 仅更新G(X)的参数,var_list=theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

def sample_Z(m, n): 
    '''Uniform prior for G(Z)'''
    
    return np.random.uniform(-1., 1., size=[m, n])

'''
我们以随机的噪声开始进行训练,G(Z)不断向pdata趋近。
可以通过观察G(Z)生成的样本和原始MNIST图像的区别来证明这件事。
'''
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#saver = tf.train.Saver()

#正式载入保存的模型参数
saver = tf.train.import_meta_graph('./data/MINIST_GAN_model.ckpt-100001.meta')
saver.restore(sess, tf.train.latest_checkpoint('./data/'))

#用保存的模型恢复参数
print(D_W1)
graph = tf.get_default_graph()
D_W1 = graph.get_tensor_by_name("D_W1:0")
D_W2 = graph.get_tensor_by_name("D_W2:0")
D_b1 = graph.get_tensor_by_name("D_b1:0")
D_b2 = graph.get_tensor_by_name("D_b2:0")
print(D_W1)

G_W1 = graph.get_tensor_by_name("G_W1:0")
G_W2 = graph.get_tensor_by_name("G_W2:0")
G_b1 = graph.get_tensor_by_name("G_b1:0")
G_b2 = graph.get_tensor_by_name("G_b2:0")



# 每行显示5张,总共生成5列,生成图片
rows = 5
cols = 5
fig1 , ax1 = plt.subplots(rows ,cols ,figsize=(13 , 12))
 
# 标签字体
#fontdict =  {'fontsize': 20,'fontweight' : 6,'verticalalignment': 'baseline','horizontalalignment': 'center'}
#reshape函数中的28*28=784,为本数据的维度
for j in range(rows):
    for i in range(cols):
        z_batch = np.random.uniform(-1., 1., size=[1, Z_dim])
        generated_images = generator(Z)
        images = sess.run(generated_images, {Z: z_batch})
        ax1[j][i].imshow(images[0].reshape([28, 28]))
        ax1[j][i].axis('off')

image.png

image.png