菜单

内存优化

相关源文件

本文档解释了 Keras 的内存优化功能,特别关注通过 RematScope 类和 remat 函数实现的重计算(梯度检查点)功能。这些工具以增加计算量为代价,帮助减少模型训练期间的内存使用,从而在内存资源有限的情况下能够训练更大的模型或使用更大的批次大小。

重计算简介

重计算,也称为梯度检查点,是一种在神经网络训练期间以计算换取内存的技术。它不是存储反向传播所需的所有中间激活,而是在需要时重新计算部分激活,从而以额外计算时间为代价减少峰值内存使用。

来源: keras/src/backend/common/remat.py11-18

RematScope:配置重计算

Keras 提供了 RematScope 上下文管理器,用于在模型执行期间启用和配置重计算。

RematScope 类提供了多种重计算模式,以控制哪些操作应该被重计算:

模式描述用例
"full"(完整)对所有支持的操作应用重计算为整个模型节省最大内存
"activations"(激活)仅重计算包含 keras.activations 的激活优化具有大量激活的模型的内存
"larger_than"(大于)重计算输出超过 output_size_threshold 的层仅针对内存密集型层
"list_of_layers"(层列表)按名称重计算特定层对特定层进行细粒度控制
None(无)禁用重计算默认行为

来源: keras/src/backend/common/remat.py8-116

使用 RematScope

RematScope 在模型执行期间(而非模型创建期间)作为上下文管理器应用。以下是该系统在实践中的工作方式:

示例使用模式

应用于特定层

条件重计算

嵌套范围

来源: keras/src/backend/common/remat.py40-75

remat 函数

为了更细粒度的控制,Keras 提供了 remat 函数,可以直接将重计算应用于特定的函数或操作。

使用示例

来源: keras/src/backend/common/remat.py137-186

实现细节

RematScope 系统通过在全局状态中维护一个活动重计算范围的堆栈来工作。当进入一个新的范围时,它会被推到堆栈上。当检查是否应该应用重计算时,系统会查看最近添加的范围(堆栈顶部)。

实际的重计算是在后端层面实现的。remat 函数委托给 backend.core.remat(),后者提供了 TensorFlow、JAX 或 PyTorch 的后端特定实现。

来源: keras/src/backend/common/remat.py98-134 keras/src/backend/common/remat.py186

性能考量

重计算涉及内存使用和计算时间之间的权衡

  • 内存优势:减少训练期间的峰值内存使用,对于大型模型或批次大小特别有用。
  • 计算成本:由于中间值的重计算而增加训练时间。

最佳的重计算策略取决于您的具体硬件限制和模型架构

  1. 对于内存受限的环境,使用 mode="full" 以最大化内存节省。
  2. 对于更平衡的方法,使用 mode="larger_than"mode="activations" 来针对特定的内存密集型操作。
  3. 为了进行细粒度控制,使用 mode="list_of_layers" 仅对内存消耗最大的层进行重计算。

来源: keras/src/backend/common/remat.py11-18 keras/src/backend/common/remat.py140-147

后端支持

并非所有后端都支持重计算。目前,numpyopenvino 后端不支持此功能。当使用这些后端时,重计算操作将不起作用。

来源: keras/src/backend/common/remat_test.py86-89