本文档介绍了 labml_nn 仓库中生成对抗网络(GAN)的实现。GAN 是一类深度学习框架,其中两个神经网络——生成器和判别器——在一个最小最大博弈中相互竞争。本页面提供了 GAN 架构的技术概述,以及仓库中各种 GAN 实现,包括原始 GAN、DCGAN、Cycle GAN、Wasserstein GAN、带梯度惩罚的 Wasserstein GAN 和 StyleGAN 2。
有关扩散模型等其他生成模型的信息,请参阅扩散模型,有关 SketchRNN 的信息,请参阅SketchRNN。
来源: docs/gan/index.html
来源: docs/gan/index.html
该仓库包含几种关键 GAN 变体的实现,每种变体都解决了原始 GAN 公式中的不同挑战
来源: docs/gan/index.html
原始 GAN 实现采用对抗训练范式,其中生成器和判别器同时进行训练。生成器试图生成逼真的样本以欺骗判别器,而判别器则试图区分真实样本和生成样本。
主要功能
DCGAN 通过为生成器和判别器使用深度卷积神经网络来扩展原始 GAN,使其更适用于图像生成任务。
主要功能
Cycle GAN 实现了跨域的非配对图像到图像转换(例如,马到斑马),无需配对训练样本。
主要功能
Wasserstein GAN (WGAN) 通过使用 Wasserstein 距离(地球移动距离)而非原始 GAN 中使用的 Jensen-Shannon 散度来提高训练稳定性。
主要功能
带梯度惩罚的 WGAN (WGAN-GP) 通过用梯度惩罚代替权重裁剪来进一步改进 WGAN,以强制执行 Lipschitz 约束。
主要功能
StyleGAN 2 是一种先进的 GAN 架构,用于生成具有可控风格的高质量图像。
主要功能
来源: docs/gan/index.html
仓库中的所有 GAN 实现都共享某些通用元素
| 组件 | 目的 | 常见实现 |
|---|---|---|
| 生成器 (Generator) | 从噪声中创建样本 | 带上采样层的神经网络 |
| 判别器 | 区分真实与伪造 | 带下采样层的神经网络 |
| 损失函数 | 指导对抗训练 | 交叉熵或 Wasserstein 距离 |
| 训练循环 | 交替训练步骤 | 先 D,后 G,使用不同的损失函数 |
| 样本可视化 | 监控生成质量 | 训练期间生成的样本网格 |
每个 GAN 实现都对标准架构有特定的修改
仓库中的 GAN 可以与其他模块结合使用
来源: docs/gan/index.html
本仓库中的 GAN 实现提供了来自研究文献的关键 GAN 架构的带注释、可读的实现。这些实现可作为教育资源,帮助理解 GAN 的工作原理以及它们如何随时间演变,以解决训练稳定性、样本质量和特定应用(如非配对图像转换)中的挑战。