菜单

数据输入管道

相关源文件

TensorFlow Model Garden 中的数据输入管道提供了一种标准化且灵活的方式,可在训练和评估期间加载、预处理和馈送数据至模型。本文档将介绍数据管道的核心组件、它们如何协同工作,以及如何为不同的用例配置它们。

有关消耗此数据的训练循环的信息,请参阅训练循环

概述

数据输入管道旨在高效地处理各种数据源,根据模型要求进行预处理,并为分布式训练提供优化数据馈送。该管道支持不同的数据格式,提供内置预处理函数,并提供性能优化以实现可扩展的训练。

来源: official/core/input_reader.py214-599

核心组件

数据输入管道由几个关键组件组成,它们协同工作,以高效地处理用于模型训练和评估的数据。

InputReader

InputReader 类是创建数据管道的主要接口。它负责从各种来源读取数据、解码、预处理和批处理。

主要职责

  • 从 TFRecord 文件或 TensorFlow Datasets (TFDS) 读取
  • 应用解码、预处理和批处理函数
  • 处理数据分布策略
  • 支持分布式训练的分片
  • 与 TF Data Service 集成

来源: official/core/input_reader.py214-599

DataConfig

DataConfig 类定义了数据输入的配置。它包含数据源、批大小、分片和其他管道设置的参数。

重要的配置选项

  • input_path:输入数据文件或模式的路径
  • tfds_name:使用 TFDS 时的 TensorFlow Dataset 名称
  • global_batch_size:所有副本的总批大小
  • is_training:数据是否用于训练(影响打乱等)
  • shuffle_buffer_size:训练的打乱缓冲区大小
  • sharding:是否跨设备分片数据
  • cache:是否将数据集缓存在内存中

来源: official/core/config_definitions.py27-136

数据处理工作流程

完整的 数据处理工作流程包含多个阶段,将原始数据转换为模型就绪批次。

工作流程包括

  1. 读取:从文件或 TFDS 读取数据
  2. 解码:将序列化数据转换为张量
  3. 解析/预处理:将原始特征转换为模型输入
  4. 采样/过滤:可选的示例采样或过滤步骤
  5. 批处理:为模型消耗创建批次
  6. 分发:将批次分发到设备

来源: official/core/input_reader.py461-516

创建和配置输入管道

基本用法

以下是如何使用 InputReader 类创建基本输入管道

然后可以将数据集传递给模型进行训练或评估。

来源: official/core/input_reader.py221-266 official/core/train_lib.py97-121

自定义解码器和解析器

该管道支持自定义解码器和解析器函数,以处理特定的数据格式和预处理需求

  • 解码器函数:从序列化示例中提取原始特征
  • 解析器函数:将解码后的特征转换为模型输入
  • 转换和批处理函数:一起应用转换和批处理

例如,对于 NLP 任务,解码器可能提取文本和标签,而解析器可能处理分词和特征提取。

来源: official/nlp/data/classifier_data_lib.py30-744 official/core/input_reader.py461-516

分布式训练支持

输入管道为分布式训练场景提供了内置支持,处理跨多个设备或工作节点的数据分片和分发。

主要功能

  • 文件级别分片:将输入文件分配给工作节点
  • 示例级别分片:读取后分发示例
  • 根据文件数量自动选择最佳分片策略
  • 支持各种分布式策略(MirroredStrategy, TPUStrategy 等)

来源: official/core/input_reader.py77-148 official/core/input_reader.py372-401

性能优化

该管道包含多项用于高性能数据处理的优化措施

缓存和预取

缓存会将预处理后的数据集存储在内存中,以避免重复计算。预取则会将数据预处理与模型执行重叠。

并行处理

并行文件读取和预处理可提高吞吐量,尤其是在处理大量小文件时。

来源: official/core/input_reader.py221-266 official/core/config_definitions.py27-136

TF Data Service 集成

输入管道与 TensorFlow Data Service 集成,用于分布式数据预处理

TF Data Service 将数据预处理卸载到专用工作节点,这可以提高资源受限环境下的性能。

来源: official/core/input_reader.py323-357 official/core/input_reader.py517-567

特定领域管道

Model Garden 包含了针对不同领域的专用数据管道

NLP 数据管道

NLP 数据管道支持分类、回归和序列处理任务

关键组件

  • 从各种格式(CSV、TSV、JSON)解码文本
  • 分词(WordPiece 或 SentencePiece)
  • 为不同任务(分类、问答等)提取特征

来源: official/nlp/data/classifier_data_lib.py30-744 official/nlp/data/create_finetuning_data.py172-286

视觉数据管道

视觉数据管道处理图像分类、目标检测和分割任务

关键组件

  • 从各种格式解码图像
  • 调整大小和裁剪至模型输入维度
  • 数据增强(RandAugment、AutoAugment 等)
  • 归一化和格式化

来源: official/vision/dataloaders/input_reader.py29-157 official/vision/configs/common.py24-169

多源数据混合

输入管道支持使用加权采样组合多个数据源

这对于多任务学习、迁移学习或平衡不同数据源非常有用。

来源: official/vision/dataloaders/input_reader.py29-79

常见用例和示例

分类数据

对于分类任务,数据管道通常处理带标签的示例

来源: official/nlp/data/classifier_data_lib.py172-282

推荐数据

对于推荐任务,数据管道处理用户-物品交互数据

来源: official/core/input_reader.py214-599

与训练循环集成

数据输入管道通过 Task 类与训练循环集成,该类定义了模型在训练和评估期间如何消耗数据

Task 类实现了 build_inputs() 方法,该方法为特定任务创建适当的输入管道。然后,训练器调用适当的训练或验证步骤,这些步骤会消耗管道中的数据。

来源: official/core/base_task.py151-167 official/core/base_trainer.py373-401

结论

TensorFlow Model Garden 中的数据输入管道提供了一种灵活有效的方式来处理各种机器学习任务的数据。它支持不同的数据源、预处理操作和分布策略,使其适用于从小规模实验到大规模分布式训练的各种应用。

通过配置适当的 DataConfig 并根据需要实现自定义解码器和解析器函数,您可以创建满足您特定任务要求的有效数据管道。