生成对抗网络‌ 开发也能看懂的大模型:GAN

默认分类2天前发布 admin
1,000 0
ChatGPT国内版

生成对抗网络(GAN, )是由 Ian 等人在 2014 年提出的一种深度学习模型。它采用对抗思想,通过两个网络的博弈,生成逼真的数据,广泛应用于图像生成、数据增强、文本生成等领域。

核心思想

GAN 的核心思想是通过两个网络的对抗学习,达到生成数据与真实数据难以区分的目标。

生成器() :

判别器() :

工作流程

GAN 的训练过程可以看作一个“博弈”过程:

生成器生成假样本试图欺骗判别器。判别器学习识别真假数据,提高判断能力。两个网络交替训练,最终达到纳什均衡:生成器生成的数据分布与真实数据分布非常接近,判别器无法区分真假数据(判别准确率约为 50%)。

特点与优势

优点:

挑战:

应用场景

图像生成:

数据增强:

风格转换:

文本生成与翻译:

视频生成:

进阶与变体

DCGAN(Deep GAN) :

WGAN( GAN) :

案例:生成 MNIST 手写数字

使用生成对抗网络 (GAN) 生成 MNIST 手写数字是一种经典的案例,能够很好地展示 GAN 的核心机制和实现方法。以下是详细的解释与分解。

1. 数据集简介:MNIST

2. GAN 的组成

GAN 由两个部分组成,分别是生成器()和判别器()。它们通过相互竞争的方式训练,最终生成逼真的手写数字。

3. GAN 的工作流程

初始化:

训练过程: GAN 的训练分为两步:

训练生成器:

迭代训练:

4. 代码解读

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

(1) 加载 MNIST 数据集

(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0  # 归一化

(2) 生成器模型

def build_generator(latent_dim):
    model = Sequential([
        Dense(128, activation=LeakyReLU(0.2), input_dim=latent_dim),
        Dense(784, activation='sigmoid'),
        Reshape((28, 28))
    ])
    return model

(3) 判别器模型

def build_discriminator():
    model = Sequential([

生成对抗网络‌ 开发也能看懂的大模型:GAN

Flatten(input_shape=(28, 28)), Dense(128, activation=LeakyReLU(0.2)), Dense(1, activation='sigmoid') ]) model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) return model

(4) GAN 模型

def build_gan(generator, discriminator):
    discriminator.trainable = False  # 固定判别器权重
    model = Sequential([generator, discriminator])
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model

(5) 训练过程

for epoch in range(epochs):
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise)
    d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")

(6) 可视化生成结果

noise = np.random.normal(0, 1, (10, latent_dim))
generated_images = generator.predict(noise)
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()

5. 训练与结果分析

损失变化:

生成效果:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
# 生成器模型
def build_generator(latent_dim):
    model = Sequential([
        Dense(128, activation=LeakyReLU(0.2), input_dim=latent_dim),
        Dense(784, activation='sigmoid'),
        Reshape((28, 28))
    ])
    return model
# 判别器模型

生成对抗网络‌ 开发也能看懂的大模型:GAN

def build_discriminator(): model = Sequential([ Flatten(input_shape=(28, 28)), Dense(128, activation=LeakyReLU(0.2)), Dense(1, activation='sigmoid') ]) model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) return model # GAN 模型 def build_gan(generator, discriminator): discriminator.trainable = False model = Sequential([generator, discriminator]) model.compile(optimizer='adam', loss='binary_crossentropy') return model # 加载数据 (x_train, _), (_, _) = tf.keras.datasets.mnist.load_data() x_train = x_train / 255.0 # 参数 latent_dim = 100 batch_size = 128 epochs = 10000 generator = build_generator(latent_dim) discriminator = build_discriminator() gan = build_gan(generator, discriminator) # 训练 for epoch in range(epochs): # 随机选择真实样本 idx = np.random.randint(0, x_train.shape[0], batch_size) real_images = x_train[idx] # 生成假样本 noise = np.random.normal(0, 1, (batch_size, latent_dim)) fake_images = generator.predict(noise) # 训练判别器 d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1))) d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1))) # 训练生成器 g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1))) # 打印损失 if epoch % 1000 == 0: print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}") # 生成图像 noise = np.random.normal(0, 1, (10, latent_dim)) generated_images = generator.predict(noise) for i in range(10): plt.subplot(1, 10, i + 1) plt.imshow(generated_images[i], cmap='gray') plt.axis('off') plt.show()

323AI导航网发布

© 版权声明
广告也精彩

相关文章

暂无评论

暂无评论...