本文档解释了 TensorFlow Model Garden 中的训练循环架构,重点关注模型如何进行训练和评估。训练循环负责执行模型训练步骤、处理检查点、管理评估和写入摘要。有关为训练循环提供数据的输入管道的信息,请参阅 数据输入管道。
TensorFlow Model Garden 中的训练循环构建在 Orbit库构建,该库提供了一个灵活的框架,用于创建具有内置常见功能(如检查点、评估和摘要写入)的自定义训练循环。
来源
训练循环架构由几个关键组件协同工作构成。
来源
Orbit 中的 Controller 类管理外部训练循环。它协调训练、评估、检查点管理和摘要写入。Controller 提供了一些关键方法:
train():运行指定步数的训练。evaluate():运行评估。train_and_evaluate():交替进行训练和评估。evaluate_continuously():监视一个目录并评估新的检查点。来源
实际的训练和评估逻辑由继承自 AbstractTrainer 和 AbstractEvaluator 的类实现。Model Garden 在 Trainer 类中提供了标准实现。
AbstractTrainer:定义 train() 方法的接口。StandardTrainer:提供训练循环结构的实现(开始、步骤、结束)。Trainer:与任务和优化器集成的 Model Garden 实现。来源
训练循环在两个层面运行:
外层循环由 Controller 管理,负责:
来源
内层循环由 Trainer 实现,并执行实际的训练步骤。
来源
以下方法构成了训练循环的骨干:
来源
来源
训练循环通过 ExperimentConfig 进行配置,其中包含:
TrainerConfig:控制训练行为(每个循环的步数、检查点间隔等)。TaskConfig:定义要执行的任务。RuntimeConfig:指定运行时设置(分布式策略等)。来源
Controller 管理训练过程并连接不同的组件。
来源
训练的主要入口点是 run_experiment 函数,它将所有内容连接在一起。
来源
当调用 train_and_evaluate 时,Controller 会:
eval_interval 步。来源
训练循环与任务抽象集成,任务抽象定义了模型构建、损失计算和度量。
Task 类提供了训练循环调用的钩子。
train_step():执行一次训练步骤。validation_step():执行一次评估步骤。build_metrics():创建用于跟踪性能的度量对象。build_inputs():创建用于训练/验证的数据集。来源
训练循环包含多种优化选项:
为了获得更好的性能,尤其是在 TPU 上,训练循环可以使用 tf.function 和 tf.while_loop。
来源
训练循环通过 TensorFlow 的分布式策略支持分布式训练。
来源
要自定义训练循环,您可以:
Task 子类,该子类定义了模型、损失和度量。ExperimentConfig 配置训练循环。来源
TensorFlow Model Garden 中的训练循环提供了灵活、高性能的框架,具有以下主要特点:
这种架构使得在利用训练循环管理的通用基础设施的同时,轻松实现自定义模型变得更加容易。