此页面提供了 Whisper 模型架构及其核心组件的全面概述。它解释了编码器-解码器设计、注意力机制,以及不同组件如何协同工作将语音转换为文本。有关使用这些组件的信息,请参阅 入门指南,有关实现细节,请参阅 实现细节。
Whisper 采用基于 Transformer 的编码器-解码器架构,针对语音识别和翻译任务进行了优化。该模型通过编码器处理音频以创建潜在表示,然后由解码器将其转换为文本。
该架构由 ModelDimensions 数据类参数化,该数据类定义了模型的规模和能力
AudioEncoder 通过卷积降采样后接 Transformer 块,将梅尔语谱图处理成一系列音频特征
TextDecoder 通过使用先前生成的 token 的自注意力机制和对编码器音频特征的交叉注意力机制来生成文本
Transformer 架构的核心在于其注意力机制,通过 MultiHeadAttention 类实现
注意力机制支持自注意力(在编码器和解码器中)和交叉注意力(仅在解码器中)。它还实现了优化,例如在可用时使用 PyTorch 的 scaled_dot_product_attention。
编码器和解码器均由 ResidualAttentionBlock 模块组成,这些模块将注意力机制和前馈网络与残差连接相结合
在编码器中,仅使用自注意力,而在解码器中,同时使用自注意力(带因果掩码)和对编码器输出的交叉注意力。
Whisper 有多种尺寸可供选择,具有不同的参数数量
| 大小 | 参数 | 仅限英语模型 | 多语言模型 |
|---|---|---|---|
| tiny | 39 M | ✓ | ✓ |
| base | 74 M | ✓ | ✓ |
| small | 244 M | ✓ | ✓ |
| medium | 769 M | ✓ | ✓ |
| large | 1550 M | ✓ | |
| turbo | 798 M | ✓ |
随着时间的推移,已经发布了改进版本
large-v2 (2022 年 12 月)large-v3 (2023 年 11 月)turbo (2024 年 9 月) - 针对推理速度进行了优化为了在自回归解码过程中实现高效推理,Whisper 实现了键值缓存
这是通过 PyTorch 钩子实现的,这些钩子在正向传播过程中截获并缓存键/值投影层的输出。
Whisper 包含标准 PyTorch 层的自定义实现,以更好地支持混合精度运算
LayerNorm:以 float32 进行归一化,但以输入数据类型返回Linear:处理输入和权重之间不同的数据类型Conv1d:管理卷积运算的数据类型转换整个架构可以看作是这些组件的组合
Whisper 类将这些组件绑定在一起,并提供高级方法来编码音频(embed_audio)、从 token 和音频特征生成 logits(logits),以及连接到其他模块功能的便捷方法(detect_language、transcribe 和 decode)。