菜单

训练循环

相关源文件

本文档解释了 TensorFlow Model Garden 中的训练循环架构,重点关注模型如何进行训练和评估。训练循环负责执行模型训练步骤、处理检查点、管理评估和写入摘要。有关为训练循环提供数据的输入管道的信息,请参阅 数据输入管道

概述

TensorFlow Model Garden 中的训练循环构建在 Orbit库构建,该库提供了一个灵活的框架,用于创建具有内置常见功能(如检查点、评估和摘要写入)的自定义训练循环。

来源

架构

训练循环架构由几个关键组件协同工作构成。

来源

控制器

Orbit 中的 Controller 类管理外部训练循环。它协调训练、评估、检查点管理和摘要写入。Controller 提供了一些关键方法:

  • train():运行指定步数的训练。
  • evaluate():运行评估。
  • train_and_evaluate():交替进行训练和评估。
  • evaluate_continuously():监视一个目录并评估新的检查点。

来源

训练器和评估器

实际的训练和评估逻辑由继承自 AbstractTrainerAbstractEvaluator 的类实现。Model Garden 在 Trainer 类中提供了标准实现。

  • AbstractTrainer:定义 train() 方法的接口。
  • StandardTrainer:提供训练循环结构的实现(开始、步骤、结束)。
  • Trainer:与任务和优化器集成的 Model Garden 实现。

来源

训练循环流程

训练循环在两个层面运行:

外层循环 (Controller)

外层循环由 Controller 管理,负责:

  1. 协调训练步骤
  2. 周期性检查点保存
  3. 调度评估
  4. 写入摘要

来源

内层循环 (Trainer)

内层循环由 Trainer 实现,并执行实际的训练步骤。

来源

训练循环中的关键方法

以下方法构成了训练循环的骨干:

Controller.train()

来源

Trainer.train_step()

来源

关键组件

ExperimentConfig

训练循环通过 ExperimentConfig 进行配置,其中包含:

  • TrainerConfig:控制训练行为(每个循环的步数、检查点间隔等)。
  • TaskConfig:定义要执行的任务。
  • RuntimeConfig:指定运行时设置(分布式策略等)。

来源

Orbit Controller

Controller 管理训练过程并连接不同的组件。

来源

训练与评估工作流程

训练的主要入口点是 run_experiment 函数,它将所有内容连接在一起。

来源

当调用 train_and_evaluate 时,Controller 会:

  1. 训练 eval_interval 步。
  2. 运行评估。
  3. 重复此过程,直到达到目标训练步数。

来源

任务集成

训练循环与任务抽象集成,任务抽象定义了模型构建、损失计算和度量。

Task 类提供了训练循环调用的钩子。

  • train_step():执行一次训练步骤。
  • validation_step():执行一次评估步骤。
  • build_metrics():创建用于跟踪性能的度量对象。
  • build_inputs():创建用于训练/验证的数据集。

来源

性能优化

训练循环包含多种优化选项:

TF Function 和 While Loop

为了获得更好的性能,尤其是在 TPU 上,训练循环可以使用 tf.functiontf.while_loop

来源

分布式训练

训练循环通过 TensorFlow 的分布式策略支持分布式训练。

来源

定制

要自定义训练循环,您可以:

  1. 创建一个自定义的 Task 子类,该子类定义了模型、损失和度量。
  2. 通过 ExperimentConfig 配置训练循环。
  3. 创建在每次训练循环后运行的自定义训练操作。

来源

总结

TensorFlow Model Garden 中的训练循环提供了灵活、高性能的框架,具有以下主要特点:

  1. 将外层循环(检查点、评估)与内层循环(实际训练)分离。
  2. 支持不同策略的分布式训练。
  3. 为 TPU 和 GPU 提供优化选项。
  4. 与基于任务的模型构建和训练抽象集成。
  5. 处理检查点、度量和摘要写入等常见任务。

这种架构使得在利用训练循环管理的通用基础设施的同时,轻松实现自定义模型变得更加容易。