TensorFlow Model Garden 中的数据输入管道提供了一种标准化且灵活的方式,可在训练和评估期间加载、预处理和馈送数据至模型。本文档将介绍数据管道的核心组件、它们如何协同工作,以及如何为不同的用例配置它们。
有关消耗此数据的训练循环的信息,请参阅训练循环。
数据输入管道旨在高效地处理各种数据源,根据模型要求进行预处理,并为分布式训练提供优化数据馈送。该管道支持不同的数据格式,提供内置预处理函数,并提供性能优化以实现可扩展的训练。
来源: official/core/input_reader.py214-599
数据输入管道由几个关键组件组成,它们协同工作,以高效地处理用于模型训练和评估的数据。
InputReader 类是创建数据管道的主要接口。它负责从各种来源读取数据、解码、预处理和批处理。
主要职责
来源: official/core/input_reader.py214-599
DataConfig 类定义了数据输入的配置。它包含数据源、批大小、分片和其他管道设置的参数。
重要的配置选项
input_path:输入数据文件或模式的路径tfds_name:使用 TFDS 时的 TensorFlow Dataset 名称global_batch_size:所有副本的总批大小is_training:数据是否用于训练(影响打乱等)shuffle_buffer_size:训练的打乱缓冲区大小sharding:是否跨设备分片数据cache:是否将数据集缓存在内存中来源: official/core/config_definitions.py27-136
完整的 数据处理工作流程包含多个阶段,将原始数据转换为模型就绪批次。
工作流程包括
来源: official/core/input_reader.py461-516
以下是如何使用 InputReader 类创建基本输入管道
然后可以将数据集传递给模型进行训练或评估。
来源: official/core/input_reader.py221-266 official/core/train_lib.py97-121
该管道支持自定义解码器和解析器函数,以处理特定的数据格式和预处理需求
例如,对于 NLP 任务,解码器可能提取文本和标签,而解析器可能处理分词和特征提取。
来源: official/nlp/data/classifier_data_lib.py30-744 official/core/input_reader.py461-516
输入管道为分布式训练场景提供了内置支持,处理跨多个设备或工作节点的数据分片和分发。
主要功能
来源: official/core/input_reader.py77-148 official/core/input_reader.py372-401
该管道包含多项用于高性能数据处理的优化措施
缓存会将预处理后的数据集存储在内存中,以避免重复计算。预取则会将数据预处理与模型执行重叠。
并行文件读取和预处理可提高吞吐量,尤其是在处理大量小文件时。
来源: official/core/input_reader.py221-266 official/core/config_definitions.py27-136
输入管道与 TensorFlow Data Service 集成,用于分布式数据预处理
TF Data Service 将数据预处理卸载到专用工作节点,这可以提高资源受限环境下的性能。
来源: official/core/input_reader.py323-357 official/core/input_reader.py517-567
Model Garden 包含了针对不同领域的专用数据管道
NLP 数据管道支持分类、回归和序列处理任务
关键组件
来源: official/nlp/data/classifier_data_lib.py30-744 official/nlp/data/create_finetuning_data.py172-286
视觉数据管道处理图像分类、目标检测和分割任务
关键组件
来源: official/vision/dataloaders/input_reader.py29-157 official/vision/configs/common.py24-169
输入管道支持使用加权采样组合多个数据源
这对于多任务学习、迁移学习或平衡不同数据源非常有用。
来源: official/vision/dataloaders/input_reader.py29-79
对于分类任务,数据管道通常处理带标签的示例
来源: official/nlp/data/classifier_data_lib.py172-282
对于推荐任务,数据管道处理用户-物品交互数据
来源: official/core/input_reader.py214-599
数据输入管道通过 Task 类与训练循环集成,该类定义了模型在训练和评估期间如何消耗数据
Task 类实现了 build_inputs() 方法,该方法为特定任务创建适当的输入管道。然后,训练器调用适当的训练或验证步骤,这些步骤会消耗管道中的数据。
来源: official/core/base_task.py151-167 official/core/base_trainer.py373-401
TensorFlow Model Garden 中的数据输入管道提供了一种灵活有效的方式来处理各种机器学习任务的数据。它支持不同的数据源、预处理操作和分布策略,使其适用于从小规模实验到大规模分布式训练的各种应用。
通过配置适当的 DataConfig 并根据需要实现自定义解码器和解析器函数,您可以创建满足您特定任务要求的有效数据管道。