Keras中的生成对抗网络(GAN)实践

时尚捕手 2019-05-13 ⋅ 22 阅读

介绍

生成对抗网络(GAN)是一种由两个模型组成的神经网络架构,它们相互竞争并相互改进。GAN由一个生成器模型和一个判别器模型组成。生成器模型用于生成伪造的数据,而判别器模型则用于区分真实数据和伪造数据。GAN模型可以用于许多应用,如图像生成、文本生成等。

在本篇博客中,我们将使用Keras库来实现一个简单的图像生成的GAN模型。我们将使用MNIST数据集(手写数字图像集)作为我们的训练数据。

准备工作

首先,我们需要导入一些必要的库。确保你已经安装了Keras和TensorFlow。

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LeakyReLU
from keras.optimizers import Adam
from keras.datasets import mnist

接下来,我们需要加载MNIST数据集。MNIST数据集包含了60000个训练样本和10000个测试样本,每个样本都是28x28的灰度图像。

(X_train, _), (_, _) = mnist.load_data()

# 数据预处理
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = X_train.reshape(X_train.shape[0], 784)

我们将每个像素的值映射到[-1,1]的范围内,将数据转换为一维向量的形式。

生成器模型

生成器模型将随机噪声作为输入,并输出生成的图像。我们将使用一个简单的全连接神经网络作为生成器模型。

def build_generator():
    generator = Sequential()
    generator.add(Dense(256, input_dim=100))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
    return generator

生成器模型接受一个100维的噪声向量作为输入,并通过一系列全连接层生成输出。我们使用ReLU激活函数来避免梯度消失的问题,并使用tanh激活函数来将输出值映射到[-1,1]范围内。最后,我们使用二元交叉熵作为损失函数,并使用Adam优化器来训练模型。

判别器模型

判别器模型将真实图像和生成图像作为输入,并输出一个概率值来表示输入图像是真实的还是伪造的。同样,我们使用一个简单的全连接神经网络作为判别器模型。

def build_discriminator():
    discriminator = Sequential()
    discriminator.add(Dense(512, input_dim=784))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
    return discriminator

判别器模型接受一个784维的图像向量作为输入,并通过一系列全连接层生成输出。我们使用ReLU激活函数来避免梯度消失的问题,并使用sigmoid激活函数来将输出值映射到[0,1]范围内,表示概率。最后,我们使用二元交叉熵作为损失函数,并使用Adam优化器来训练模型。

构建GAN模型

接下来,我们将构建一个整体的GAN模型,将生成器和判别器连接在一起。

def build_gan(generator, discriminator):
    discriminator.trainable = False
    gan = Sequential()
    gan.add(generator)
    gan.add(discriminator)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
    return gan

在构建GAN模型之前,我们需要将判别器的trainable属性设置为False,以确保在训练GAN模型时只更新生成器的权重。然后,我们将生成器和判别器连接在一起,并使用二元交叉熵作为损失函数,并使用Adam优化器来训练模型。

训练GAN模型

我们已经准备好构建和训练GAN模型了。

epochs = 100
batch_size = 128

generator = build_generator()
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)

for epoch in range(epochs):
    for batch in range(X_train.shape[0] // batch_size):
        # 训练判别器
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        fake_images = generator.predict(noise)

        real_images = X_train[batch * batch_size: (batch + 1) * batch_size]

        X = np.concatenate([real_images, fake_images])
        y = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])

        discriminator_loss = discriminator.train_on_batch(X, y)

        # 训练生成器
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        y = np.ones((batch_size, 1))

        generator_loss = gan.train_on_batch(noise, y)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}/{epochs}: Discriminator Loss: {discriminator_loss}, Generator Loss: {generator_loss}")

在每个epoch中,我们将实际图像和生成图像送入判别器进行训练。首先,我们生成随机噪声样本,并通过生成器模型生成伪造图像。然后,我们从训练数据中随机选择一批真实图像。我们将真实图像和伪造图像合并在一起,并创建相应的标签,以区分真实图像和伪造图像。然后,我们使用这些训练样本进行判别器的训练。接下来,我们生成新的随机噪声样本,并将它们的标签设置为1,表示它们是真实的图像。最后,我们使用这些样本进行生成器的训练。

结果展示

训练完模型后,我们可以使用生成器模型生成一些新的图像。


def generate_images(generator, num_images):
    noise = np.random.normal(0, 1, size=(num_images, 100))
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(generated_images.shape[0], 28, 28)
    generated_images = 0.5 * generated_images + 0.5

    fig, axes = plt.subplots(1, num_images, figsize=(20, 4))
    for i, image in enumerate(generated_images):
        axes[i].imshow(image, cmap='gray')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

generate_images(generator, 10)

通过给生成器输入一些随机噪声,我们可以生成一组新的手写数字图像。

结论

在本篇博客中,我们使用Keras库实现了一个简单的生成对抗网络(GAN)模型。我们使用MNIST数据集训练模型,并成功生成了一组新的手写数字图像。GAN模型是一个强大的工具,可以应用于许多领域,如图像生成、文本生成等。希望本篇博客对你理解GAN模型的实践应用有所帮助。


全部评论: 0

    我有话说: