本文档解释了 Keras 的内存优化功能,特别关注通过 RematScope 类和 remat 函数实现的重计算(梯度检查点)功能。这些工具以增加计算量为代价,帮助减少模型训练期间的内存使用,从而在内存资源有限的情况下能够训练更大的模型或使用更大的批次大小。
重计算,也称为梯度检查点,是一种在神经网络训练期间以计算换取内存的技术。它不是存储反向传播所需的所有中间激活,而是在需要时重新计算部分激活,从而以额外计算时间为代价减少峰值内存使用。
来源: keras/src/backend/common/remat.py11-18
Keras 提供了 RematScope 上下文管理器,用于在模型执行期间启用和配置重计算。
RematScope 类提供了多种重计算模式,以控制哪些操作应该被重计算:
| 模式 | 描述 | 用例 |
|---|---|---|
"full"(完整) | 对所有支持的操作应用重计算 | 为整个模型节省最大内存 |
"activations"(激活) | 仅重计算包含 keras.activations 的激活 | 优化具有大量激活的模型的内存 |
"larger_than"(大于) | 重计算输出超过 output_size_threshold 的层 | 仅针对内存密集型层 |
"list_of_layers"(层列表) | 按名称重计算特定层 | 对特定层进行细粒度控制 |
None(无) | 禁用重计算 | 默认行为 |
来源: keras/src/backend/common/remat.py8-116
RematScope 在模型执行期间(而非模型创建期间)作为上下文管理器应用。以下是该系统在实践中的工作方式:
来源: keras/src/backend/common/remat.py40-75
为了更细粒度的控制,Keras 提供了 remat 函数,可以直接将重计算应用于特定的函数或操作。
来源: keras/src/backend/common/remat.py137-186
RematScope 系统通过在全局状态中维护一个活动重计算范围的堆栈来工作。当进入一个新的范围时,它会被推到堆栈上。当检查是否应该应用重计算时,系统会查看最近添加的范围(堆栈顶部)。
实际的重计算是在后端层面实现的。remat 函数委托给 backend.core.remat(),后者提供了 TensorFlow、JAX 或 PyTorch 的后端特定实现。
来源: keras/src/backend/common/remat.py98-134 keras/src/backend/common/remat.py186
重计算涉及内存使用和计算时间之间的权衡
最佳的重计算策略取决于您的具体硬件限制和模型架构
mode="full" 以最大化内存节省。mode="larger_than" 或 mode="activations" 来针对特定的内存密集型操作。mode="list_of_layers" 仅对内存消耗最大的层进行重计算。来源: keras/src/backend/common/remat.py11-18 keras/src/backend/common/remat.py140-147
并非所有后端都支持重计算。目前,numpy 和 openvino 后端不支持此功能。当使用这些后端时,重计算操作将不起作用。