菜单

多GPU训练

相关源文件

目的与范围

本文档介绍了如何使用 Keras 3 模型实现多 GPU 训练。它涵盖了 PyTorch 和 JAX 后端中的同步数据并行方法。本文档重点介绍单主机、多设备设置,即多 GPU(或 JAX 中的 TPU)安装在单台机器上的情况,这是研究和小型工业工作流中最常见的配置。

有关训练期间内存优化的信息,请参阅内存优化

数据并行概念

Keras 3 通过数据并行支持分布式训练,其中单个模型在多个设备上复制。每个设备独立处理不同的数据批次,并将结果合并以更新模型权重。

同步数据并行确保所有模型副本在每个批次后保持同步,与单设备训练保持相同的收敛行为。主要特点包括

  • 全局批次被分成更小的本地批次,供每个设备处理
  • 每个副本独立处理其本地批次
  • 梯度在所有副本之间同步
  • 所有副本统一更新其权重

来源:guides/distributed_training_with_torch.py10-35 guides/distributed_training_with_jax.py10-35

后端特定实现

Keras 3 的多后端架构需要根据所使用的后端采用不同的分布式训练方法。两种主要的实现是

来源:guides/distributed_training_with_torch.py45-46 guides/distributed_training_with_jax.py45-46

使用 PyTorch 后端进行多 GPU 训练

使用 PyTorch 后端时,Keras 利用 PyTorch 的 DistributedDataParallel (DDP) 功能,该功能实现了基于进程的分布式训练方法。

实现组件

关键系统组件

  1. 进程组初始化:每个 GPU 在单独的 Python 进程中运行,拥有自己的进程组。
  2. 数据分布DistributedSampler 确保每个进程接收到训练数据的不同子集。
  3. 模型封装DistributedDataParallel 封装 Keras 模型以处理设备间的梯度同步。
  4. 同步更新:梯度在每次反向传播后自动同步,确保所有副本保持一致。

来源:guides/distributed_training_with_torch.py140-268 examples/demo_torch_multi_gpu.py114-213

实施流程

要使用 PyTorch 后端实现多 GPU 训练

  1. 将 Keras 后端设置为 PyTorch

  2. 创建一个按设备启动的功能,该功能

    • 初始化分布式训练的进程组
    • 使用 DistributedSampler 准备数据
    • 创建并编译模型
    • 使用 DistributedDataParallel 封装模型
    • 运行训练循环
  3. 使用 PyTorch 的多进程工具启动多个进程

来源:guides/distributed_training_with_torch.py186-268 examples/demo_torch_multi_gpu.py152-213

使用 JAX 后端进行多 GPU 训练

使用 JAX 后端时,Keras 利用 JAX 的分片能力,这与 PyTorch 基于进程的方法不同。JAX 不创建多个进程,而是使用单个进程智能地将数据和操作分片到各个设备上。

实现组件

关键系统组件

  1. 设备网格:使用 mesh_utils.create_device_mesh 创建物理设备的逻辑排列。
  2. 分片策略:
    • 模型变量:在所有设备上复制(无分区)。
    • 输入数据:沿批次维度分片。
  3. 变量放置:使用 jax.device_put 根据分片规范放置变量。
  4. JIT 编译:训练步骤经过 JIT 编译以实现高效执行。

来源:guides/distributed_training_with_jax.py113-270

实施流程

要使用 JAX 后端实现多 GPU 或 TPU 训练

  1. 将 Keras 后端设置为 JAX

  2. 创建设备网格

  3. 定义分片策略

    • 对于变量:在所有设备上复制
    • 对于数据:沿批次维度分片
  4. 根据分片策略放置变量和数据

  5. 实现训练步骤函数并进行 JIT 编译。

  6. 在训练循环中,在将每个数据批次传递给训练步骤之前进行分片。

来源:guides/distributed_training_with_jax.py231-260

实现比较

下表总结了 PyTorch 和 JAX 分布式训练方法之间的主要区别

方面PyTorch 后端JAX 后端
并行方法基于进程(多个 Python 进程)基于数组(单个进程,带有分片数组)
关键 APIDistributedDataParalleljax.sharding
数据分布DistributedSampler使用 PartitionSpec 进行显式分片
变量处理DistributedDataParallel 中隐式处理使用 NamedSharding 进行显式复制
通信NCCL 后端基于 XLA 的跨设备操作
代码复杂度更多样板代码(进程设置)更具函数性(显式分片)

来源:guides/distributed_training_with_torch.py140-268 guides/distributed_training_with_jax.py113-270

最佳实践

  1. 后端选择:选择最符合您的基础设施和熟悉度的后端

    • PyTorch 后端:与 PyTorch 生态系统兼容
    • JAX 后端:支持 TPU 并可能实现更快的编译
  2. 批次大小:从单 GPU 训练转向多 GPU 训练时,您可以

    • 保持相同的全局批次大小并缩短训练时间
    • 增加全局批次大小以利用增加的计算能力(可能需要调整学习率)
  3. 学习率缩放:对于更大的批次大小,请考虑缩放学习率

    • 线性缩放规则:学习率与批次大小成比例缩放
    • 渐进式预热:以较小的学习率开始,并在早期 epoch 逐渐增加
  4. GPU 内存优化:有关使用大型模型管理 GPU 内存的技术,请参阅内存优化

来源:guides/distributed_training_with_torch.py10-35 guides/distributed_training_with_jax.py10-35