计算机视觉模型
相关源文件
本文档全面概述了 TensorFlow Model Garden 中提供的计算机视觉模型和组件。它涵盖了目标检测架构、图像分类骨干网络、语义分割模型以及支持这些模型的共享预处理基础设施。
有关 NLP 模型的信息,请参阅 NLP 模型,有关推荐系统的信息,请参阅 推荐模型。
架构概述
TensorFlow Model Garden 中的计算机视觉生态系统包含多个相互关联的组件
来源
- [research/object_detection/meta_architectures/center_net_meta_arch.py]
- [research/object_detection/builders/model_builder.py]
- [official/vision/ops/augment.py]
- [official/vision/ops/preprocess_ops.py]
目标检测模型
目标检测模型识别和定位图像中的对象,提供类别标签和边界框坐标。
模型架构比较
来源
- [research/object_detection/meta_architectures/center_net_meta_arch.py]
- [official/vision/configs/retinanet.py]
- [research/object_detection/builders/model_builder.py]
CenterNet
CenterNet 是一种基于关键点的目标检测架构,在“Objects as Points”论文中有描述。它预测对象的中心点及其属性(大小、偏移量等)。
关键组件
- 特征提取器:多种选项,包括 Hourglass、ResNet 和 MobileNet
- 预测头:生成对象中心、对象大小和偏移量的热图
- 后处理:将热图峰值转换为对象检测结果
CenterNet 实现了一种目标分配策略,在对象中心创建高斯峰值
支持的特征提取器
- Hourglass(10/20/32/52/104 层变体)
- ResNet v1 和 v2
- MobileNet v2(带 FPN 和不带 FPN)
来源
- [research/object_detection/meta_architectures/center_net_meta_arch.py:15-43]
- [research/object_detection/models/center_net_resnet_feature_extractor.py]
- [research/object_detection/models/center_net_hourglass_feature_extractor.py]
- [research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py]
- [research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py]
RetinaNet
RetinaNet 是一种单阶段密集检测器,它使用特征金字塔网络 (FPN) 和 focal loss 来处理类别不平衡。
关键组件
- 骨干网络:通常基于 ResNet
- 特征金字塔网络 (FPN):组合不同尺度的特征
- 分类子网络:预测类别概率
- 边界框回归子网络:预测边界框的细化
- Focal Loss:解决前景-背景类别不平衡问题
配置示例
来源
- [official/vision/configs/retinanet.py:165-186]
- [official/vision/tasks/retinanet.py]
- [official/vision/dataloaders/retinanet_input.py]
目标分配
目标分配是训练过程中将预测框与真实框匹配的关键过程
来源
- [research/object_detection/core/target_assigner.py:67-382]
- [research/object_detection/core/target_assigner_test.py]
- [research/object_detection/core/region_similarity_calculator.py]
图像处理与增强
预处理操作
vision 库包含一套全面的预处理操作,用于准备输入数据
关键预处理操作
- 标准化:使用均值/标准差将图像转换为标准化范围
- 调整大小:将图像缩放到所需尺寸
- 裁剪:从图像中提取感兴趣区域
- 填充:将图像填充到固定尺寸以便批处理
来源
- [official/vision/ops/preprocess_ops.py:40-308]
- [official/vision/ops/preprocess_ops_test.py]
数据增强
该库提供各种增强技术以提高模型泛化能力
增强技术包括
- 几何变换:翻转、旋转、平移
- 颜色变换:亮度、对比度、饱和度调整
- 高级技巧:
- AutoAugment(学习策略)
- RandAugment(随机策略)
- Mixup(图像混合)
- Cutmix(块替换)
来源
- [official/vision/ops/augment.py:15-34]
- [official/vision/ops/augment_test.py]
- [official/projects/yolo/ops/mosaic.py]
模型与训练流水线的集成
计算机视觉模型通过基于任务的标准化 API 与 TensorFlow 训练循环集成
模型集成的关键组件
- 任务定义:封装模型、数据和训练配置
- 输入流水线:用于数据集处理的标准化读取器和解析器
- 模型架构:特征提取器和特定任务的头
- 损失函数:视觉任务的专用损失(例如,focal loss)
- 评估器:用于比较模型性能的标准指标
来源
- [official/vision/tasks/retinanet.py:39-214]
- [official/vision/dataloaders/retinanet_input.py:35-127]
- [research/object_detection/builders/model_builder.py:396-495]
vision 模型支持多种特征提取器,可以进行切换以调整速度-准确性权衡
| 特征提取器 | 参数 | 速度 | 准确率 | 用例 |
|---|
| ResNet-50 | 25M | 中等 | 高 | 通用 |
| ResNet-101 | 44M | 慢 | 较高 | 高精度 |
| MobileNet V2 | 3.5M | 快 | 中等 | 移动/边缘设备 |
| EfficientNet | 可变 | 可变 | 高 | 可扩展性能 |
| Hourglass | 可变 | 慢 | 高 | 关键点检测 |
来源
- [research/object_detection/builders/model_builder.py:149-178]
- [research/object_detection/models/center_net_resnet_feature_extractor.py]
- [research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py]
- [research/object_detection/models/center_net_hourglass_feature_extractor.py]
模型输出和后处理
目标检测模型产生需要后处理才能生成可用预测的原始输出
后处理操作
- 边界框解码:将网络输出转换为绝对坐标
- 置信度阈值处理:过滤低置信度预测
- 非极大值抑制:去除重复检测
- 附加输出:处理掩码、关键点和其他属性
来源
- [research/object_detection/meta_architectures/center_net_meta_arch.py:370-427]
- [official/vision/configs/retinanet.py:138-162]
语义分割
语义分割模型为图像中的每个像素预测一个类别
- DeepLab:使用空洞卷积进行密集像素预测
- 全景分割:结合语义分割和实例分割
实例分割扩展
一些检测模型(如 Mask R-CNN 和 DeepMAC)包含实例分割功能
来源
- [research/object_detection/meta_architectures/deepmac_meta_arch.py:30-76]
- [research/object_detection/meta_architectures/deepmac_meta_arch_test.py]
用法和示例
模型可以通过 research Object Detection API 或 official TensorFlow Model Garden API 使用
完整集成需要
- 使用适当的预处理设置数据流水线
- 配置模型架构和特征提取器
- 定义带有损失和优化的训练循环
- 设置评估和导出
来源
- [research/object_detection/builders/model_builder.py:266-456]
- [official/vision/tasks/retinanet.py:47-69]
- [research/object_detection/packages/tf2/setup.py]