本文档介绍了labml_nn仓库中的生成模型实现,概述了生成与训练样本相似的新数据的不同方法。生成模型学习数据的底层分布,使其能够合成新颖的样本,这与只学习类别边界的判别模型不同。
该仓库实现了四种主要类型的生成模型
有关微调技术,请参阅微调技术。
来源: labml_nn/diffusion/ddpm/__init__.py docs/gan/index.html labml_nn/sketch_rnn/__init__.py labml_nn/capsule_networks/__init__.py
扩散模型是一类生成模型,其工作原理是逐步向数据添加噪声,然后学习如何逆转此过程。该仓库实现了基于论文《去噪扩散概率模型》的去噪扩散概率模型(DDPM)。
DDPM过程包括
来源:labml_nn/diffusion/ddpm/__init__.py172-287 labml_nn/diffusion/ddpm/experiment.py34-159
DDPM实现包括
DenoiseDiffusion类 - 管理前向和逆向过程的核心算法
q_xt_x0() - 获取时间 t 处噪声数据的分布q_sample() - 通过添加噪声采样噪声数据p_sample() - 通过预测噪声从逆向过程采样loss() - 计算训练的简化损失UNet模型 - 用于噪声预测的神经网络架构
训练配置
来源:labml_nn/diffusion/ddpm/__init__.py172-287 labml_nn/diffusion/ddpm/experiment.py80-156
该仓库实现了几种GAN变体,每种都在原始对抗训练方法的基础上有所改进。
GANs由两个相互竞争的网络组成
这些网络在对抗过程中进行训练,生成器在创建逼真样本方面的能力会随着判别器在检测方面的改进而提高。
该仓库包括以下GAN架构
每个实现都包含详细的注释,解释了其架构和训练方法。
SketchRNN是论文《A Neural Representation of Sketch Drawings》的实现。它是一个序列到序列的变分自编码器,通过预测一系列笔画来学习生成素描图。
该模型由以下部分组成
来源:labml_nn/sketch_rnn/__init__.py50-126 labml_nn/sketch_rnn/__init__.py197-238 labml_nn/sketch_rnn/__init__.py241-307
SketchRNN实现包括
StrokesDataset - 加载并预处理基于笔画的图画
EncoderRNN - 处理输入笔画
DecoderRNN - 生成新笔画
BivariateGaussianMixture - 对笔画坐标建模
损失函数
采样器 - 从模型生成新素描
来源:labml_nn/sketch_rnn/__init__.py50-126 labml_nn/sketch_rnn/__init__.py129-178 labml_nn/sketch_rnn/__init__.py197-307 labml_nn/sketch_rnn/__init__.py312-361 labml_nn/sketch_rnn/__init__.py370-455
胶囊网络代表了一种不同的特征表示方法,它使用神经元向量(胶囊)而不是标量神经元。虽然它们在传统意义上并非严格的生成模型,但它们包含一个作为正则化器使用的重建网络。
胶囊网络的实现包括
来源:labml_nn/capsule_networks/__init__.py39-71 labml_nn/capsule_networks/__init__.py73-134 labml_nn/capsule_networks/__init__.py136-186 labml_nn/capsule_networks/mnist.py29-97
挤压函数 - 在保留向量方向的同时归一化向量长度
v_j = (||s_j||² / (1 + ||s_j||²)) * (s_j / ||s_j||)
路由 - 实现胶囊间的动态路由
边际损失 - 分类损失函数
L_k = T_k max(0, m⁺ - ||v_k||)² + λ(1-T_k)max(0, ||v_k|| - m⁻)²
MNIST模型 - MNIST数字分类的完整实现
来源:labml_nn/capsule_networks/__init__.py39-71 labml_nn/capsule_networks/__init__.py73-134 labml_nn/capsule_networks/__init__.py136-186 labml_nn/capsule_networks/mnist.py29-97
每个生成模型实现都由仓库的训练基础设施支持,其中包括
配置类 - 实验的配置管理
损失函数 - 模型特定的损失函数
可视化 - 用于可视化生成样本的工具
实验跟踪 - 与labml集成以进行实验跟踪
来源:labml_nn/sketch_rnn/__init__.py458-611 labml_nn/diffusion/ddpm/experiment.py34-159 labml_nn/capsule_networks/mnist.py100-159
该仓库中的生成模型代表了数据生成的多种方法
每个实现都提供了详细的注释,解释了理论、架构和训练过程,使其成为理解和扩展生成模型的宝贵资源。