Generative Adversarial Networks (GANs) are a deep learning model proposed by Ian Goodfellow et al. in 2014. Built on an adversarial idea, with two networks playing a game against each other, GANs generate realistic data and are widely used for image generation, data augmentation, text generation, and more.
Core Idea
The core idea of a GAN is adversarial learning between two networks, with the goal of producing generated data that is indistinguishable from real data.
Generator (G): takes random noise as input and produces fake samples, trying to make them look like real data.
Discriminator (D): takes a sample as input and outputs the probability that it is real, learning to tell real data from generated data.
Workflow
The GAN training process can be viewed as a game:
The generator produces fake samples and tries to fool the discriminator. The discriminator learns to tell real data from fake, sharpening its judgment. The two networks are trained alternately until they reach a Nash equilibrium: the distribution of the generated data is very close to the real data distribution, and the discriminator can no longer tell them apart (its accuracy hovers around 50%).
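This game is formalized in the original 2014 paper as a minimax objective, where $D(x)$ is the discriminator's probability that $x$ is real and $G(z)$ is the generator's output for noise $z$:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] +
  \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
```

The discriminator maximizes $V$ by classifying correctly, while the generator minimizes it by making $D(G(z))$ large.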
Characteristics and Advantages
Advantages: GANs can produce sharp, realistic samples without an explicit likelihood model, and the adversarial setup requires no labeled data.
Challenges: training is often unstable (mode collapse, vanishing gradients), and it is hard to keep the generator and discriminator balanced; convergence is difficult to diagnose.
Application Scenarios
Image generation: synthesizing realistic images such as faces, scenes, and artwork.
Data augmentation: creating extra training samples when real data is scarce.
Style transfer: converting images from one style to another, for example photos to paintings.
Text generation and translation: producing or transforming text with adversarially trained models.
Video generation: synthesizing video frames or predicting future frames.
Advanced Topics and Variants
DCGAN (Deep Convolutional GAN): replaces fully connected layers with convolutions, which stabilizes training on image data.
WGAN (Wasserstein GAN): replaces the original loss with the Wasserstein distance, improving training stability and reducing mode collapse.
Case Study: Generating MNIST Handwritten Digits
Generating MNIST handwritten digits with a GAN is a classic exercise that shows off the core mechanics and implementation of a GAN. The sections below walk through it in detail.
1. The Dataset: MNIST
MNIST contains 70,000 grayscale images of handwritten digits (0 through 9), each 28 x 28 pixels: 60,000 for training and 10,000 for testing.
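The dataset ships with Keras, so its shape and value range can be checked directly:

```python
import tensorflow as tf

# Download (on first use) and load the standard MNIST split
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

print(x_train.shape)  # (60000, 28, 28)
print(x_test.shape)   # (10000, 28, 28)
print(x_train.min(), x_train.max())  # 0 255 -- raw pixel values
```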
2. Components of the GAN
A GAN has two parts, the generator (G) and the discriminator (D). Trained in competition with each other, they eventually yield a generator that produces realistic handwritten digits.
3. GAN Workflow
Initialization: build the generator and the discriminator, both starting from random weights.
Training: each iteration has two steps:
Train the discriminator: show it a batch of real images labeled 1 and a batch of generated images labeled 0.
Train the generator: update it through the combined model so that the discriminator assigns its fakes the "real" label.
Iterate: repeat these two steps for many epochs until the generated digits look convincing.
4. Code Walkthrough
```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
```
(1) Load the MNIST dataset

```python
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0  # normalize pixel values to [0, 1]
```
(2) Generator model

```python
def build_generator(latent_dim):
    model = Sequential([
        Dense(128, input_dim=latent_dim),
        LeakyReLU(0.2),
        Dense(784, activation='sigmoid'),  # outputs in [0, 1], matching the normalized data
        Reshape((28, 28))
    ])
    return model
```
(3) Discriminator model

```python
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(128),
        LeakyReLU(0.2),
        Dense(1, activation='sigmoid')  # probability that the input is real
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
```
(4) The combined GAN model

```python
def build_gan(generator, discriminator):
    discriminator.trainable = False  # freeze the discriminator inside the combined model
    model = Sequential([generator, discriminator])
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model
```
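Because the discriminator was already compiled on its own, `discriminator.train_on_batch` still updates it, while the combined model, compiled after `trainable` was set to False, leaves its weights untouched. A minimal check of this behavior, with toy layer sizes chosen purely for illustration:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# Toy generator and discriminator (sizes are illustrative only)
gen = Sequential([tf.keras.Input(shape=(3,)), Dense(2)])
disc = Sequential([tf.keras.Input(shape=(2,)), Dense(1, activation='sigmoid')])
disc.compile(optimizer='adam', loss='binary_crossentropy')

disc.trainable = False  # freeze before compiling the combined model
gan = Sequential([gen, disc])
gan.compile(optimizer='adam', loss='binary_crossentropy')

before = [w.copy() for w in disc.get_weights()]
gan.train_on_batch(np.random.normal(size=(4, 3)), np.ones((4, 1)))
after = disc.get_weights()

# The combined-model update changed the generator but not the discriminator
assert all(np.allclose(b, a) for b, a in zip(before, after))
```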
(5) Training loop

```python
for epoch in range(epochs):
    # Sample a batch of real images
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    # Generate a batch of fake images from random noise
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)

    # Train the discriminator: real images labeled 1, fakes labeled 0
    d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))

    # Train the generator through the combined model, asking for label 1
    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    if epoch % 1000 == 0:
        print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")
```
(6) Visualize generated digits

```python
noise = np.random.normal(0, 1, (10, latent_dim))
generated_images = generator.predict(noise, verbose=0)

for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()
```
5. Training and Result Analysis
Loss behavior: early on the discriminator dominates; as training progresses the two losses should settle into a rough balance, with discriminator accuracy drifting toward 50%.
Generated samples: after a few thousand epochs the generator begins to produce recognizable digit shapes, and further training sharpens them.
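To inspect the loss trend, one can append each reported loss to a list during training and plot the curves afterwards; a minimal sketch (the loss values below are made-up placeholders, not real training output):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the script also runs without a display
import matplotlib.pyplot as plt

# Placeholder loss histories; in practice, append d_loss/g_loss every epoch
d_losses = [0.90, 0.72, 0.66, 0.68, 0.67]
g_losses = [1.50, 1.20, 1.02, 0.95, 0.92]

plt.plot(d_losses, label="discriminator loss")
plt.plot(g_losses, label="generator loss")
plt.xlabel("checkpoint")
plt.ylabel("loss")
plt.legend()
plt.savefig("gan_losses.png")
```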
Putting it all together, here is the complete script:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

# Generator model
def build_generator(latent_dim):
    model = Sequential([
        Dense(128, input_dim=latent_dim),
        LeakyReLU(0.2),
        Dense(784, activation='sigmoid'),
        Reshape((28, 28))
    ])
    return model

# Discriminator model
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(128),
        LeakyReLU(0.2),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Combined GAN model
def build_gan(generator, discriminator):
    discriminator.trainable = False  # freeze the discriminator inside the combined model
    model = Sequential([generator, discriminator])
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model

# Load data
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0  # normalize to [0, 1]

# Hyperparameters
latent_dim = 100
batch_size = 128
epochs = 10000

generator = build_generator(latent_dim)
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)

# Training loop
for epoch in range(epochs):
    # Sample real images
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    # Generate fake images
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)

    # Train the discriminator
    d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))

    # Train the generator
    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    # Report losses
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")

# Generate and display sample images
noise = np.random.normal(0, 1, (10, latent_dim))
generated_images = generator.predict(noise, verbose=0)
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()
```
Published by 323AI导航网
© Copyright: this article belongs to its author; please do not reproduce it without permission.