菜单

NLP 任务

相关源文件

本文档概述了 TensorFlow Model Garden 中实现的自然语言处理 (NLP) 任务。这些任务代表了特定的 NLP 功能,每个任务都有专门的架构、数据处理管道、评估指标和预测工具。有关底层 Transformer 架构的信息,请参阅 Transformer 架构,有关数据处理的详细信息,请参阅 NLP 数据处理

NLP 任务概述

TensorFlow Model Garden 实现了多个核心 NLP 任务,每个任务都封装在一个继承自 base_task.Task 的专用任务类中。这些任务为训练、评估和推理提供了统一的接口。

任务架构关系

下图显示了 NLP 任务与其组件之间的关系

来源: official/nlp/tasks/sentence_prediction.py official/nlp/tasks/question_answering.py official/nlp/tasks/tagging.py official/nlp/tasks/masked_lm.py

通用任务结构

所有 NLP 任务都遵循一致的模式

来源: official/nlp/tasks/sentence_prediction.py63-166 official/nlp/tasks/question_answering.py76-204 official/nlp/tasks/tagging.py81-205 official/nlp/tasks/masked_lm.py55-210

句子预测任务

句子预测任务处理句子级别的分类或回归,支持情感分析、自然语言推理、释义检测和语义文本相似度等应用。

架构与实现

来源: official/nlp/tasks/sentence_prediction.py63-166 official/nlp/tasks/sentence_prediction.py255-319

配置

句子预测任务使用 SentencePredictionConfig 进行配置

来源: official/nlp/tasks/sentence_prediction.py48-60

支持的指标

句子预测任务支持多个评估指标

指标类型描述用例
accuracy标准的分类准确率多类别分类
f1F1分数二元分类
matthews_corrcoefMatthews 相关系数二元分类
pearson_spearman_corrPearson 和 Spearman 相关性的平均值回归/相似性任务

来源: official/nlp/tasks/sentence_prediction.py36-37 official/nlp/tasks/sentence_prediction.py193-226

使用示例

句子预测任务可用于预测输入句子的类别标签或回归值

来源: official/nlp/tasks/sentence_prediction.py255-319 official/nlp/tasks/sentence_prediction_test.py236-267

问答任务

问答任务实现了抽取式问答,模型从文本段落中识别出回答给定问题的文本片段。它支持 SQuAD v1.1 和 v2.0 等数据集。

架构与实现

来源: official/nlp/tasks/question_answering.py76-204 official/nlp/tasks/question_answering.py329-458

配置

问答任务使用 QuestionAnsweringConfig 进行配置

来源: official/nlp/tasks/question_answering.py48-62

评估和预测

问答任务的评估使用

  1. 精确匹配:预测与真实答案完全匹配的百分比
  2. F1 分数:预测与真实答案之间的平均词语重叠度
  3. Has-Answer 指标:针对可回答问题(SQuAD v2.0)的单独指标

来源: official/nlp/tasks/question_answering.py461-502 official/nlp/tasks/question_answering_test.py156-174

标注任务

标注任务实现了诸如命名实体识别 (NER) 和词性 (POS) 标注等任务的 token 级别分类。

架构与实现

来源: official/nlp/tasks/tagging.py81-205 official/nlp/tasks/tagging.py208-265

配置

标注任务使用 TaggingConfig 进行配置

来源: official/nlp/tasks/tagging.py43-58

评估和处理 Token 标签

标注任务使用 seqeval 包通过以下指标评估性能

指标描述
f1实体级别的 F1 分数
precision实体级别的准确率
recall实体级别的召回率
accuracy实体级别的准确率

标注任务的一个特殊功能是处理 Tokenized 文本中的子词 Token

来源: official/nlp/tasks/tagging.py61-78 official/nlp/tasks/tagging.py196-204

掩码语言模型任务

掩码语言模型任务实现了 Transformer 模型(如 BERT)的预训练目标,即掩盖一部分输入 Token,模型需要对其进行预测。

架构与实现

来源: official/nlp/tasks/masked_lm.py55-210

配置

掩码语言模型任务使用 MaskedLMConfigPretrainerConfig 进行配置

来源: official/nlp/tasks/masked_lm.py31-52 official/nlp/configs/bert.py36-47

预训练过程

掩码语言模型任务支持

  1. 掩码语言模型:预测掩码标记
  2. 下一个句子预测:可选的次要目标

训练数据包含掩码标记,以及可选的下一个句子标签

来源: official/nlp/tasks/masked_lm.py113-123 official/nlp/tasks/masked_lm.py158-156

通用组件和实用程序

任务注册和工厂模式

NLP任务使用任务工厂模式进行注册

这使得任务可以根据其配置自动实例化。

来源: official/nlp/tasks/sentence_prediction.py63 official/nlp/tasks/question_answering.py76 official/nlp/tasks/tagging.py81 official/nlp/tasks/masked_lm.py55

数据加载器工厂

与任务工厂类似,数据加载器通过工厂进行注册和访问

来源: official/nlp/tasks/sentence_prediction.py137 official/nlp/tasks/question_answering.py217-218 official/nlp/tasks/tagging.py139 official/nlp/tasks/masked_lm.py131 official/nlp/data/data_loader_factory.py55-58

编码器集成

任务可以使用来自不同来源的编码器

  1. TensorFlow Hub:从TF Hub加载预训练模型
  2. 检查点:从已保存的检查点初始化
  3. 从零开始:构建新的编码器

来源: official/nlp/tasks/sentence_prediction.py81-85 official/nlp/tasks/utils.py23-46

预测工具

每个任务都提供了一个predict函数,用于对新数据进行推理

预测流程通常遵循以下步骤:

  1. 预处理数据
  2. 通过模型运行推理
  3. 后处理输出(例如,提取答案跨度,将logits转换为类别)

来源: official/nlp/tasks/sentence_prediction.py255-319 official/nlp/tasks/question_answering.py461-502 official/nlp/tasks/tagging.py208-265

NLP任务摘要

任务主要类主要模型评估指标主要应用程序
句子预测SentencePredictionTaskBertClassifierXLNetClassifier准确率、F1、Matthews、相关性情感分析、NLI、释义
问答QuestionAnsweringTaskBertSpanLabelerXLNetSpanLabeler精确匹配、F1 分数SQuAD、抽取式QA
标记TaggingTaskBertTokenClassifierF1、精确率、召回率、准确率NER、词性标记
掩码LMMaskedLMTaskBertPretrainerV2MLM准确率、NSP准确率预训练

这些任务构成了TensorFlow Model Garden的核心NLP功能,提供了标准化的实现,易于配置、训练和评估。

来源: official/nlp/tasks/sentence_prediction.py official/nlp/tasks/question_answering.py official/nlp/tasks/tagging.py official/nlp/tasks/masked_lm.py