菜单

训练流程

相关源文件

本文档概述了 Stable Diffusion v1 的训练流水线,解释了如何从头开始训练或微调模型。它涵盖了训练数据准备、配置设置和训练运行的执行。有关模型架构的信息,请参阅模型架构

1. 训练流水线概述

Stable Diffusion 训练流水线建立在 PyTorch Lightning 之上,并采用两阶段方法

  1. 自编码器训练:首先,训练自编码器(VAE)组件,将其图像压缩到潜在空间并进行重构。
  2. 潜在扩散模型训练:然后,在压缩的潜在空间中使用文本条件进行扩散模型的训练。

训练的入口点是 main.py,它负责配置加载、模型初始化、数据准备和训练执行。

来源: main.py418-742 Stable_Diffusion_v1_Model_Card.md83-113

2. 配置系统

训练依赖 YAML 配置文件来定义模型架构、数据集、优化参数和训练设置。配置系统使用 OmegaConf 以实现灵活性和分层组织。

配置文件指定

  • 模型架构和参数
  • 数据集路径和预处理
  • 训练超参数(学习率、批量大小)
  • 日志记录和检查点设置

启动训练的示例命令

来源: main.py124-132 main.py512-520

3. 数据流水线

数据流水线负责训练数据的加载、预处理和批处理。

3.1 数据集配置

数据集在 YAML 文件中配置,并通过 DataModuleFromConfig 类实例化。该系统支持各种数据集类型,包括用于大规模训练的可迭代数据集。

来源: main.py132-237

3.2 训练数据

最初的 Stable Diffusion 模型训练使用了

  1. LAION-2B-en:LAION-5B 的英文子集
  2. LAION-high-resolution:1.7 亿个分辨率 ≥ 1024x1024 的示例
  3. LAION-aesthetics v2 5+:经过过滤的美学分数 > 5.0 的子集

每个数据集都使用特定于训练阶段的过滤参数和转换进行加载。

来源: Stable_Diffusion_v1_Model_Card.md85-89 Stable_Diffusion_v1_Model_Card.md99-106

4. 训练执行

4.1 训练初始化

训练过程通过 PyTorch Lightning 的 Trainer 类启动,该类负责分布式训练、检查点和日志记录。

系统根据批量大小、GPU 数量和梯度累积步数来调整学习率。

learning_rate = accumulate_grad_batches * num_gpus * batch_size * base_lr

来源: main.py673-694 main.py716-724

4.2 回调和监控

训练通过以下几个回调进行监控:

  1. SetupCallback:创建日志目录并保存配置。
  2. ImageLogger:在训练期间定期记录生成的图像。
  3. ModelCheckpoint:根据指定指标保存模型检查点。
  4. LearningRateMonitor:跟踪学习率变化。
  5. CUDACallback:监控 GPU 内存使用情况和训练时间。

来源: main.py240-415 main.py592-657

4.3 训练参数

原始 Stable Diffusion 训练使用了以下参数:

参数
硬件32 x 8 x A100 GPU
优化器AdamW
梯度累积2
有效批量大小2048 (32 x 8 x 2 x 4)
学习率0.0001(具有 10,000 步预热)
第一阶段训练256x256 分辨率下 237k 步 + 512x512 分辨率下 194k 步
最终阶段训练额外步数,文本条件 dropout 为 10%

来源: Stable_Diffusion_v1_Model_Card.md107-113 Stable_Diffusion_v1_Model_Card.md99-106

5. 模型检查点管理

Stable Diffusion 检查点包含所有模型组件的权重。训练流水线会定期以及在中断时保存检查点。

5.1 检查点结构

5.2 检查点配置

训练流水线提供了一些与检查点相关的配置:

  • 常规检查点:按 epoch 间隔保存。
  • 最佳检查点:根据监控的指标保存。
  • 步数检查点:每 N 步训练保存一次。
  • 最后一个检查点:始终保存,用于恢复训练。

来源: main.py567-588 main.py635-649

6. 微调工作流

要在自定义数据集上微调 Stable Diffusion,请遵循以下步骤:

6.1 准备自定义数据集

  1. 创建一个数据集类,该类返回图像-文本对。
  2. 在 YAML 配置文件中配置数据集。
  3. 确保进行适当的预处理(图像调整大小、文本标记化)。

6.2 微调配置

创建一个配置文件,其中:

  1. 加载基础 Stable Diffusion 检查点。
  2. 配置你的自定义数据集。
  3. 设置适当的超参数(微调时使用较低的学习率)。

6.3 执行微调

使用以下参数运行微调:

关键微调参数

参数推荐值
基础学习率1e-5 到 1e-6
训练步数1,000 到 10,000
批大小4-32(取决于 GPU 内存)
梯度累积2-4
验证频率每 100-500 步

来源: main.py673-694 main.py716-724

7. 分布式训练

Stable Diffusion 训练流水线支持使用 PyTorch Lightning 的分布式策略在多个 GPU 和节点上进行分布式训练。

默认情况下,训练使用分布式数据并行 (DDP) 进行多 GPU 训练,这可以通过 PyTorch Lightning Trainer 进行配置。

来源: main.py521-530 main.py675-679

8. 硬件要求

从头开始训练 Stable Diffusion 需要大量的计算资源。

训练类型最低硬件推荐硬件
完整训练8x A100 40GB GPU32x A100 80GB GPU
微调2x RTX 3090 24GB4x A100 40GB GPU
概念微调1x RTX 3090 24GB2x RTX 3090 24GB

原始模型训练使用了 32 个节点,每个节点有 8 个 A100 GPU,运行了约 150,000 小时的计算时间。

来源: Stable_Diffusion_v1_Model_Card.md124-132