菜单

核心框架

相关源文件

核心框架提供了在 TensorFlow Model Garden 中训练和评估机器学习模型的底层基础架构。它将常见的机器学习工作流抽象为一组标准化的组件,负责数据摄取、模型构建、训练、评估和配置管理。本页文档介绍了此框架的架构和组件。

有关在此框架之上构建的特定模型实现的更多信息,请参阅 NLP 模型计算机视觉模型推荐模型 的页面。

关键组件

核心框架由多个关键组件组成,它们协同工作,提供灵活、可配置的训练流水线。

来源

配置系统

该框架使用分层配置系统,可轻松进行参数管理和实验配置。

配置系统使用 Python 数据类构建,支持类型检查和 IDE 自动补全。配置可以以编程方式创建,也可以从 JSON/YAML 文件加载。

来源

关键配置类

  1. ExperimentConfig:顶级配置,包含任务、训练器和运行时配置。
  2. TaskConfig:定义模型、数据集和初始化设置。
  3. TrainerConfig:包含训练参数,如步数、验证间隔和优化器设置。
  4. RuntimeConfig:指定执行环境参数,如分布式策略。

配置示例(代码中)

来源

任务接口

Task 类是一个核心抽象,它定义了如何构建模型、创建数据集、计算损失和评估性能。

用户通过继承 Task 并实现所需方法来实施其特定任务。这提供了一个一致的接口,同时允许实现细节的灵活性。

来源

关键任务方法

  • build_model():创建并返回模型架构。
  • build_inputs():返回用于训练/验证的数据集或数据集函数。
  • build_losses():根据模型输出和标签计算损失。
  • build_metrics():返回评估指标列表。
  • train_step():执行前向传播,计算损失并应用梯度。
  • validation_step():执行验证前向传播。
  • create_optimizer():创建带有学习率调度器的优化器。

来源

输入流水线

输入流水线负责从各种来源读取数据、预处理数据,并创建用于训练和评估的批处理数据集。

InputReader 类提供了一个灵活的接口,用于从各种来源读取数据、应用转换和创建批处理数据集。

来源

InputReader 的主要特点

  • 支持 TFRecord 文件、TFDS 数据集和自定义数据源。
  • 分布式数据集创建,用于多设备训练。
  • 可配置的预处理、批处理和混洗。
  • 可选的 TF Data Service 集成,以提高性能。
  • 支持数据集混合和加权采样。

来源

训练框架

训练框架将所有组件整合在一起,以运行完整的训练和评估循环。其核心是 Trainer 类,它处理内部训练循环,而 Orbit 的 Controller 则管理外部循环。

来源

实验执行流程

  1. 实验配置:从标志或参数解析配置。
  2. 模型和任务创建:根据配置构建模型和任务。
  3. 训练器初始化:使用模型、任务和优化器创建训练器。
  4. 循环执行:根据模式运行训练或评估循环。
  5. 检查点和摘要写入:保存检查点并写入摘要。

该框架支持多种执行模式:

  • train:训练模型直到达到指定的步数。
  • eval:在验证数据集上评估模型。
  • train_and_eval:交错训练和评估。
  • continuous_eval:持续监视新检查点并对其进行评估。

来源

Orbit 集成

核心框架与 Orbit 库集成,用于管理训练循环并提供标准化的抽象。

Orbit 在内部训练/评估循环和外部循环管理之间提供了清晰的隔离。核心框架扩展并实现了这些抽象,以提供完整的训练和评估流程。

来源

控制器

Orbit 的 Controller 管理外部训练和评估循环,负责处理

  • 检查点保存和恢复
  • 摘要写入
  • 训练和评估交错
  • 连续评估

来源

StandardTrainer 和 StandardEvaluator

StandardTrainerStandardEvaluator 类提供了抽象的训练器和评估器接口的结构化实现

  • 将训练/评估循环拆分为“begin”、“step”和“end”方法
  • 支持 TF while 循环以提高性能
  • TPU 特定优化

来源

训练执行流程

下图说明了运行训练实验时的执行流程

来源

执行模式和示例

核心框架通过 run_experiment 函数支持多种执行模式

  1. 仅训练:为指定的步数训练模型
  2. 仅评估:在验证数据集上评估模型
  3. 训练和评估:以指定的间隔交错训练和评估
  4. 连续评估:持续监控新的检查点并评估它们

示例工作流程

来源

总结

TensorFlow Model Garden 中的核心框架为实现、训练和评估机器学习模型提供了一个全面的基础。关键组件包括:

  1. 任务接口:定义如何构建模型、数据集、损失和指标
  2. 配置系统:提供一个分层、类型检查的配置系统
  3. 输入管道:处理数据加载、预处理和批量创建
  4. 训练框架:与 Orbit 集成以提供标准化的训练循环
  5. 实验运行器:管理训练和评估的执行

这些组件共同提供了一种结构化、可配置的机器学习实验方法,同时允许在实现细节上具有灵活性。

来源