菜单

目标检测

相关源文件

TensorFlow Models 中的目标检测模块提供了一个全面的框架,用于构建、训练和部署目标检测模型。本页面描述了目标检测子系统的架构、组件和工作流程,主要侧重于 research/object_detection 目录中的研究实现。

有关模型训练和评估的信息,请参阅相应的训练文档。有关模型部署,请参阅导出和提供指南。

概述

目标检测系统支持多种检测架构(称为“元架构”)和特征提取器,为不同的性能需求提供了灵活的框架。该系统遵循模块化设计,在数据处理、模型定义以及训练/评估循环之间实现了清晰的分离。

来源

架构组件

目标检测系统由几个协同工作的关键组件组成

1. 元架构

该系统支持不同的元架构,每种都实现了不同的检测范式

  • CenterNet:基于关键点的检测,将对象表示为点
  • Faster R-CNN:带有区域提议和分类的两阶段检测器
  • SSD:用于高效检测的单阶段检测器
  • RFCN:基于区域的全卷积网络

每个元架构都继承自 DetectionModel 基类,并实现了特定的 predict()loss()postprocess() 方法。

来源

2. 特征提取器

特征提取器将输入图像转换为用于检测元架构的特征图

  • ResNet:各种深度(50、101、152),带 FPN 选项
  • MobileNet:用于移动设备部署的轻量级网络(v1、v2)
  • Hourglass:特别是用于 CenterNet 的 hourglass 网络
  • EfficientNet:具有复合缩放的 eficient 网络

特征提取器特定于每个元架构,并实现从不同尺度提取特征的方法。

来源

3. 目标分配

目标分配在训练期间将地面真实标注映射到模型预测目标

TargetAssigner 类处理

  • 区域相似度计算(IoU、距离)
  • 锚点与地面真实之间的匹配策略
  • 计算分类和边界框回归目标

来源

4. 数据管道

输入管道负责读取、预处理和批处理训练样本

  • 支持 tf.Examples 的 TFRecord 格式
  • 数据增强选项(随机翻转、裁剪、颜色失真)
  • 用于高效训练的批处理和预取

来源

CenterNet 架构

CenterNet 是一种基于关键点的目标检测算法,它将对象表示为点,并且是该框架中的关键元架构之一。

关键组件

  1. 特征提取:使用 Hourglass、ResNet 或 MobileNet 提取特征
  2. 对象中心预测:热图预测每个位置的对象中心的可能性
  3. 边界框大小预测:每个潜在对象的宽度和高度
  4. 边界框偏移预测:中心坐标的细化,以处理离散化误差

CenterNet 可通过额外的预测头进行扩展

  • 关键点估计
  • 实例分割
  • 3D 检测
  • 跟踪

来源

边界框预测和编码

目标检测系统使用不同的边界框编码和预测机制

基于锚点的模型(SSD、Faster R-CNN)

  • 在特征图上生成锚点框
  • 预测从锚点到地面真实框的偏移量
  • 使用边界框编码器在坐标之间进行转换

基于关键点的模型(CenterNet)

  • 将对象中心预测为关键点
  • 直接预测每个中心点的宽度和高度
  • 使用偏移量细化以实现亚像素精度

来源

训练流程

训练管道整合了所有组件

  1. 数据加载:使用 inputs.py 创建数据集
  2. 模型创建:使用 model_builder.py 构建模型
  3. 训练循环:在 model_lib_v2.py 中处理梯度更新
  4. 评估:定期评估模型性能

来源

配置系统

目标检测框架使用协议缓冲区进行配置

  • 管道配置:总体训练/评估设置
  • 模型配置:元架构特定设置
  • 输入配置:数据集和预处理设置

CenterNet 配置的示例组件

message CenterNet {
  // Number of classes to predict
  optional int32 num_classes = 1;
  
  // Feature extractor config
  optional CenterNetFeatureExtractor feature_extractor = 2;
  
  // Image resizer for preprocessing
  optional ImageResizer image_resizer = 3;
  
  // Object detection task configuration
  optional ObjectDetection object_detection_task = 4;
  
  // Object center prediction parameters
  optional ObjectCenterParams object_center_params = 5;
}

来源

执行流程

在运行目标检测训练时,系统遵循以下流程

  1. pipeline.config 解析配置
  2. 创建特征提取器和元架构
  3. 使用数据集构建器设置输入管道
  4. 初始化模型参数和优化器
  5. 执行带有梯度的训练循环
  6. 定期评估和检查点模型

对于推理,流程是

  1. 加载已保存的模型
  2. 预处理输入图像
  3. 运行模型预测
  4. 应用后处理
  5. 返回检测框、得分和类别

来源

依赖项和安装

目标检测模块需要多个依赖项

依赖项目的
TensorFlow ≥ 2.5Core ML 框架
tf-models-official官方 TF 模型
Pillow图像处理
lxmlXML 解析
Cython用于 COCO API
pycocotools用于 COCO 评估
tensorflow_io附加数据格式支持

安装通常通过 pip 处理

pip install -e ./research/object_detection/packages/tf2/

来源

总结

目标检测模块提供了一个灵活的框架,用于训练和部署不同的目标检测架构。通过将数据管道、模型架构和训练循环分开,它允许研究人员和从业人员在保持系统其余部分不变的情况下尝试不同的组件。

模块化设计支持

  • 多种检测元架构
  • 各种特征提取器
  • 不同的输入数据格式
  • 自定义损失函数
  • TensorFlow 1.x 和 2.x 兼容性

这种灵活性使其成为计算机视觉研究和实际应用中的强大工具。