菜单

计算机视觉模型

相关源文件

本文档全面概述了 TensorFlow Model Garden 中提供的计算机视觉模型和组件。它涵盖了目标检测架构、图像分类骨干网络、语义分割模型以及支持这些模型的共享预处理基础设施。

有关 NLP 模型的信息,请参阅 NLP 模型,有关推荐系统的信息,请参阅 推荐模型

架构概述

TensorFlow Model Garden 中的计算机视觉生态系统包含多个相互关联的组件

来源

  • [research/object_detection/meta_architectures/center_net_meta_arch.py]
  • [research/object_detection/builders/model_builder.py]
  • [official/vision/ops/augment.py]
  • [official/vision/ops/preprocess_ops.py]

目标检测模型

目标检测模型识别和定位图像中的对象,提供类别标签和边界框坐标。

模型架构比较

来源

  • [research/object_detection/meta_architectures/center_net_meta_arch.py]
  • [official/vision/configs/retinanet.py]
  • [research/object_detection/builders/model_builder.py]

CenterNet

CenterNet 是一种基于关键点的目标检测架构,在“Objects as Points”论文中有描述。它预测对象的中心点及其属性(大小、偏移量等)。

关键组件

  • 特征提取器:多种选项,包括 Hourglass、ResNet 和 MobileNet
  • 预测头:生成对象中心、对象大小和偏移量的热图
  • 后处理:将热图峰值转换为对象检测结果

CenterNet 实现了一种目标分配策略,在对象中心创建高斯峰值

支持的特征提取器

  • Hourglass(10/20/32/52/104 层变体)
  • ResNet v1 和 v2
  • MobileNet v2(带 FPN 和不带 FPN)

来源

  • [research/object_detection/meta_architectures/center_net_meta_arch.py:15-43]
  • [research/object_detection/models/center_net_resnet_feature_extractor.py]
  • [research/object_detection/models/center_net_hourglass_feature_extractor.py]
  • [research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py]
  • [research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py]

RetinaNet

RetinaNet 是一种单阶段密集检测器,它使用特征金字塔网络 (FPN) 和 focal loss 来处理类别不平衡。

关键组件

  • 骨干网络:通常基于 ResNet
  • 特征金字塔网络 (FPN):组合不同尺度的特征
  • 分类子网络:预测类别概率
  • 边界框回归子网络:预测边界框的细化
  • Focal Loss:解决前景-背景类别不平衡问题

配置示例

来源

  • [official/vision/configs/retinanet.py:165-186]
  • [official/vision/tasks/retinanet.py]
  • [official/vision/dataloaders/retinanet_input.py]

目标分配

目标分配是训练过程中将预测框与真实框匹配的关键过程

来源

  • [research/object_detection/core/target_assigner.py:67-382]
  • [research/object_detection/core/target_assigner_test.py]
  • [research/object_detection/core/region_similarity_calculator.py]

图像处理与增强

预处理操作

vision 库包含一套全面的预处理操作,用于准备输入数据

关键预处理操作

  • 标准化:使用均值/标准差将图像转换为标准化范围
  • 调整大小:将图像缩放到所需尺寸
  • 裁剪:从图像中提取感兴趣区域
  • 填充:将图像填充到固定尺寸以便批处理

来源

  • [official/vision/ops/preprocess_ops.py:40-308]
  • [official/vision/ops/preprocess_ops_test.py]

数据增强

该库提供各种增强技术以提高模型泛化能力

增强技术包括

  • 几何变换:翻转、旋转、平移
  • 颜色变换:亮度、对比度、饱和度调整
  • 高级技巧:
    • AutoAugment(学习策略)
    • RandAugment(随机策略)
    • Mixup(图像混合)
    • Cutmix(块替换)

来源

  • [official/vision/ops/augment.py:15-34]
  • [official/vision/ops/augment_test.py]
  • [official/projects/yolo/ops/mosaic.py]

模型与训练流水线的集成

计算机视觉模型通过基于任务的标准化 API 与 TensorFlow 训练循环集成

模型集成的关键组件

  • 任务定义:封装模型、数据和训练配置
  • 输入流水线:用于数据集处理的标准化读取器和解析器
  • 模型架构:特征提取器和特定任务的头
  • 损失函数:视觉任务的专用损失(例如,focal loss)
  • 评估器:用于比较模型性能的标准指标

来源

  • [official/vision/tasks/retinanet.py:39-214]
  • [official/vision/dataloaders/retinanet_input.py:35-127]
  • [research/object_detection/builders/model_builder.py:396-495]

特征提取器

vision 模型支持多种特征提取器,可以进行切换以调整速度-准确性权衡

特征提取器参数速度准确率用例
ResNet-5025M中等通用
ResNet-10144M较高高精度
MobileNet V23.5M中等移动/边缘设备
EfficientNet可变可变可扩展性能
Hourglass可变关键点检测

来源

  • [research/object_detection/builders/model_builder.py:149-178]
  • [research/object_detection/models/center_net_resnet_feature_extractor.py]
  • [research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py]
  • [research/object_detection/models/center_net_hourglass_feature_extractor.py]

模型输出和后处理

目标检测模型产生需要后处理才能生成可用预测的原始输出

后处理操作

  • 边界框解码:将网络输出转换为绝对坐标
  • 置信度阈值处理:过滤低置信度预测
  • 非极大值抑制:去除重复检测
  • 附加输出:处理掩码、关键点和其他属性

来源

  • [research/object_detection/meta_architectures/center_net_meta_arch.py:370-427]
  • [official/vision/configs/retinanet.py:138-162]

语义分割

语义分割模型为图像中的每个像素预测一个类别

  • DeepLab:使用空洞卷积进行密集像素预测
  • 全景分割:结合语义分割和实例分割

实例分割扩展

一些检测模型(如 Mask R-CNN 和 DeepMAC)包含实例分割功能

来源

  • [research/object_detection/meta_architectures/deepmac_meta_arch.py:30-76]
  • [research/object_detection/meta_architectures/deepmac_meta_arch_test.py]

用法和示例

模型可以通过 research Object Detection API 或 official TensorFlow Model Garden API 使用

完整集成需要

  1. 使用适当的预处理设置数据流水线
  2. 配置模型架构和特征提取器
  3. 定义带有损失和优化的训练循环
  4. 设置评估和导出

来源

  • [research/object_detection/builders/model_builder.py:266-456]
  • [official/vision/tasks/retinanet.py:47-69]
  • [research/object_detection/packages/tf2/setup.py]