本文档介绍了如何使用 Keras 3 模型实现多 GPU 训练。它涵盖了 PyTorch 和 JAX 后端中的同步数据并行方法。本文档重点介绍单主机、多设备设置,即多 GPU(或 JAX 中的 TPU)安装在单台机器上的情况,这是研究和小型工业工作流中最常见的配置。
有关训练期间内存优化的信息,请参阅内存优化。
Keras 3 通过数据并行支持分布式训练,其中单个模型在多个设备上复制。每个设备独立处理不同的数据批次,并将结果合并以更新模型权重。
同步数据并行确保所有模型副本在每个批次后保持同步,与单设备训练保持相同的收敛行为。主要特点包括
来源:guides/distributed_training_with_torch.py10-35 guides/distributed_training_with_jax.py10-35
Keras 3 的多后端架构需要根据所使用的后端采用不同的分布式训练方法。两种主要的实现是
来源:guides/distributed_training_with_torch.py45-46 guides/distributed_training_with_jax.py45-46
使用 PyTorch 后端时,Keras 利用 PyTorch 的 DistributedDataParallel (DDP) 功能,该功能实现了基于进程的分布式训练方法。
DistributedSampler 确保每个进程接收到训练数据的不同子集。DistributedDataParallel 封装 Keras 模型以处理设备间的梯度同步。来源:guides/distributed_training_with_torch.py140-268 examples/demo_torch_multi_gpu.py114-213
要使用 PyTorch 后端实现多 GPU 训练
将 Keras 后端设置为 PyTorch
创建一个按设备启动的功能,该功能
DistributedSampler 准备数据DistributedDataParallel 封装模型使用 PyTorch 的多进程工具启动多个进程
来源:guides/distributed_training_with_torch.py186-268 examples/demo_torch_multi_gpu.py152-213
使用 JAX 后端时,Keras 利用 JAX 的分片能力,这与 PyTorch 基于进程的方法不同。JAX 不创建多个进程,而是使用单个进程智能地将数据和操作分片到各个设备上。
mesh_utils.create_device_mesh 创建物理设备的逻辑排列。jax.device_put 根据分片规范放置变量。来源:guides/distributed_training_with_jax.py113-270
要使用 JAX 后端实现多 GPU 或 TPU 训练
将 Keras 后端设置为 JAX
创建设备网格
定义分片策略
根据分片策略放置变量和数据
实现训练步骤函数并进行 JIT 编译。
在训练循环中,在将每个数据批次传递给训练步骤之前进行分片。
来源:guides/distributed_training_with_jax.py231-260
下表总结了 PyTorch 和 JAX 分布式训练方法之间的主要区别
| 方面 | PyTorch 后端 | JAX 后端 |
|---|---|---|
| 并行方法 | 基于进程(多个 Python 进程) | 基于数组(单个进程,带有分片数组) |
| 关键 API | DistributedDataParallel | jax.sharding |
| 数据分布 | DistributedSampler | 使用 PartitionSpec 进行显式分片 |
| 变量处理 | 在 DistributedDataParallel 中隐式处理 | 使用 NamedSharding 进行显式复制 |
| 通信 | NCCL 后端 | 基于 XLA 的跨设备操作 |
| 代码复杂度 | 更多样板代码(进程设置) | 更具函数性(显式分片) |
来源:guides/distributed_training_with_torch.py140-268 guides/distributed_training_with_jax.py113-270
后端选择:选择最符合您的基础设施和熟悉度的后端
批次大小:从单 GPU 训练转向多 GPU 训练时,您可以
学习率缩放:对于更大的批次大小,请考虑缩放学习率
GPU 内存优化:有关使用大型模型管理 GPU 内存的技术,请参阅内存优化。
来源:guides/distributed_training_with_torch.py10-35 guides/distributed_training_with_jax.py10-35