菜单

生成模型

相关源文件

简介与概述

本文档介绍了labml_nn仓库中的生成模型实现,概述了生成与训练样本相似的新数据的不同方法。生成模型学习数据的底层分布,使其能够合成新颖的样本,这与只学习类别边界的判别模型不同。

该仓库实现了四种主要类型的生成模型

  1. 扩散模型 - 特别是去噪扩散概率模型(DDPM)
  2. 生成对抗网络(GANs) - 各种架构和改进
  3. SketchRNN - 专门用于素描绘制生成
  4. 胶囊网络 - 一种不同的特征表示方法

有关微调技术,请参阅微调技术

来源: labml_nn/diffusion/ddpm/__init__.py docs/gan/index.html labml_nn/sketch_rnn/__init__.py labml_nn/capsule_networks/__init__.py

扩散模型

扩散模型是一类生成模型,其工作原理是逐步向数据添加噪声,然后学习如何逆转此过程。该仓库实现了基于论文《去噪扩散概率模型》的去噪扩散概率模型(DDPM)。

DDPM算法

DDPM过程包括

  1. 前向过程:在T个时间步长内逐步向数据添加噪声
  2. 逆向过程:学习预测并去除这些噪声以生成新的样本
  3. 训练:使用简化损失优化噪声预测网络

来源:labml_nn/diffusion/ddpm/__init__.py172-287 labml_nn/diffusion/ddpm/experiment.py34-159

关键组件

DDPM实现包括

  1. DenoiseDiffusion类 - 管理前向和逆向过程的核心算法

    • q_xt_x0() - 获取时间 t 处噪声数据的分布
    • q_sample() - 通过添加噪声采样噪声数据
    • p_sample() - 通过预测噪声从逆向过程采样
    • loss() - 计算训练的简化损失
  2. UNet模型 - 用于噪声预测的神经网络架构

    • 时间嵌入以条件化扩散步骤
    • 带有残差连接的下/上采样路径
    • 特定分辨率下的自注意力层
  3. 训练配置

    • 支持MNIST和CelebA数据集
    • 可配置的超参数,例如步数、学习率
    • 训练期间生成样本的可视化

来源:labml_nn/diffusion/ddpm/__init__.py172-287 labml_nn/diffusion/ddpm/experiment.py80-156

生成对抗网络(GANs)

该仓库实现了几种GAN变体,每种都在原始对抗训练方法的基础上有所改进。

GAN架构

GANs由两个相互竞争的网络组成

  1. 生成器 - 从随机噪声中创建假样本
  2. 判别器 - 尝试区分真实样本和假样本

这些网络在对抗过程中进行训练,生成器在创建逼真样本方面的能力会随着判别器在检测方面的改进而提高。

来源:docs/gan/index.html74-79

已实现的GAN变体

该仓库包括以下GAN架构

  1. 原始GAN - 带有MLP的基础实现
  2. DCGAN - 使用卷积网络进行更好的图像生成
  3. CycleGAN - 用于带有循环一致性的非配对图像到图像转换
  4. Wasserstein GAN - 使用Wasserstein距离实现更稳定的训练
  5. 带有梯度惩罚的WGAN - 通过梯度惩罚改进WGAN
  6. StyleGAN 2 - 使用基于风格的生成器进行高分辨率图像生成

每个实现都包含详细的注释,解释了其架构和训练方法。

来源:docs/gan/index.html74-79

SketchRNN

SketchRNN是论文《A Neural Representation of Sketch Drawings》的实现。它是一个序列到序列的变分自编码器,通过预测一系列笔画来学习生成素描图。

架构

该模型由以下部分组成

  1. 编码器 - 处理输入笔画的双向LSTM
  2. 潜在空间 - VAE采样以创建潜在表示
  3. 解码器 - 将笔画预测为二元高斯混合的LSTM解码器

来源:labml_nn/sketch_rnn/__init__.py50-126 labml_nn/sketch_rnn/__init__.py197-238 labml_nn/sketch_rnn/__init__.py241-307

实现细节

SketchRNN实现包括

  1. StrokesDataset - 加载并预处理基于笔画的图画

    • 处理(Δx, Δy, pen_state)形式的序列数据
    • 缩放并标准化笔画数据
  2. EncoderRNN - 处理输入笔画

    • 双向LSTM编码器
    • 为潜在空间生成均值 (μ) 和对数方差 (σ̂)
  3. DecoderRNN - 生成新笔画

    • 以潜在向量z为条件的LSTM解码器
    • 输出笔画坐标的高斯混合分布
    • 预测笔画状态概率(笔下、笔上、序列结束)
  4. BivariateGaussianMixture - 对笔画坐标建模

    • 二元高斯混合分布
    • 采样时的温度调整
  5. 损失函数

    • 笔画坐标和笔画状态的重构损失
    • 潜在空间正则化的KL散度损失
  6. 采样器 - 从模型生成新素描

来源:labml_nn/sketch_rnn/__init__.py50-126 labml_nn/sketch_rnn/__init__.py129-178 labml_nn/sketch_rnn/__init__.py197-307 labml_nn/sketch_rnn/__init__.py312-361 labml_nn/sketch_rnn/__init__.py370-455

胶囊网络

胶囊网络代表了一种不同的特征表示方法,它使用神经元向量(胶囊)而不是标量神经元。虽然它们在传统意义上并非严格的生成模型,但它们包含一个作为正则化器使用的重建网络。

架构

胶囊网络的实现包括

  1. 挤压函数 - 保留向量方向的非线性激活函数
  2. 路由 - 胶囊层之间的动态路由算法
  3. 边际损失 - 胶囊分类的损失函数
  4. 重建网络 - 从胶囊输出重建输入的解码器

来源:labml_nn/capsule_networks/__init__.py39-71 labml_nn/capsule_networks/__init__.py73-134 labml_nn/capsule_networks/__init__.py136-186 labml_nn/capsule_networks/mnist.py29-97

关键组件

  1. 挤压函数 - 在保留向量方向的同时归一化向量长度

    v_j = (||s_j||² / (1 + ||s_j||²)) * (s_j / ||s_j||)
    
  2. 路由 - 实现胶囊间的动态路由

    • 计算初始预测 û_j|i = W_ij u_i
    • 迭代更新耦合系数 c_ij
    • 计算加权和 s_j = Σ_i c_ij û_j|i
    • 应用挤压函数得到 v_j
  3. 边际损失 - 分类损失函数

    L_k = T_k max(0, m⁺ - ||v_k||)² + λ(1-T_k)max(0, ||v_k|| - m⁻)²
    
  4. MNIST模型 - MNIST数字分类的完整实现

    • 用于提取初始特征的卷积层
    • 从卷积特征创建主胶囊
    • 通过动态路由连接的数字胶囊
    • 用于正则化的重建网络

来源:labml_nn/capsule_networks/__init__.py39-71 labml_nn/capsule_networks/__init__.py73-134 labml_nn/capsule_networks/__init__.py136-186 labml_nn/capsule_networks/mnist.py29-97

训练基础设施

每个生成模型实现都由仓库的训练基础设施支持,其中包括

  1. 配置类 - 实验的配置管理

    • 超参数设置
    • 数据集加载和预处理
    • 模型初始化
    • 训练循环管理
  2. 损失函数 - 模型特定的损失函数

    • 重建损失
    • KL散度损失
    • 对抗损失
  3. 可视化 - 用于可视化生成样本的工具

    • 训练期间的图像生成
    • 样本绘图和可视化
    • 潜在空间插值
  4. 实验跟踪 - 与labml集成以进行实验跟踪

    • 损失跟踪
    • 样本可视化
    • 模型检查点

来源:labml_nn/sketch_rnn/__init__.py458-611 labml_nn/diffusion/ddpm/experiment.py34-159 labml_nn/capsule_networks/mnist.py100-159

总结

该仓库中的生成模型代表了数据生成的多种方法

  • 扩散模型(DDPM) - 渐进式去噪方法
  • GANs - 生成器和判别器之间的对抗训练
  • SketchRNN - 用于素描生成的基于序列的VAE
  • 胶囊网络 - 具有重建能力的替代神经网络架构

每个实现都提供了详细的注释,解释了理论、架构和训练过程,使其成为理解和扩展生成模型的宝贵资源。