菜单

模型架构

相关源文件

目的与范围

本文档详细介绍了 Whisper 自动语音识别 (ASR) Transformer 模型的架构。它涵盖了编码器-解码器结构、关键组件、注意力机制以及这些元素如何协同工作将语音音频转换为文本。有关模型接收音频之前的处理信息,请参阅 音频处理。有关解码输出如何处理的详细信息,请参阅 解码系统

概述

Whisper 是一个遵循编码器-解码器架构的序列到序列 Transformer 模型。该模型包含两个主要组件:

  1. 一个 音频编码器,用于处理梅尔频谱输入并生成音频特征
  2. 一个 文本解码器,用于从编码的音频特征生成文本

模型分两个阶段进行操作:首先将音频输入编码为潜在表示,然后将该表示解码为构成转录或翻译的文本标记。

来源: whisper/model.py252-307 whisper/model.py1-15

模型维度

模型的架构通过 ModelDimensions 数据类进行配置,该类为音频编码器和文本解码器组件指定了参数。

这些维度控制着模型的尺寸和能力。

参数描述
n_mels输入频谱中的梅尔频率带数量
n_audio_ctx音频编码器的最大上下文长度
n_audio_state编码器隐藏状态的维度
n_audio_head编码器中注意力头的数量
n_audio_layer编码器中 Transformer 层的数量
n_vocab文本解码器的词汇量大小
n_text_ctx文本解码器的最大上下文长度
n_text_state解码器隐藏状态的维度
n_text_head解码器中注意力头的数量
n_text_layer解码器中 Transformer 层的数量

Whisper 提供多种尺寸(tiny, base, small, medium, large),较大的模型拥有更多的层和参数。

来源: whisper/model.py25-36

音频编码器

AudioEncoder 将梅尔频谱转换为高维特征表示。它由以下部分组成:

  1. 初始卷积层(conv1conv2)用于下采样
  2. 添加的位置嵌入,以提供序列位置信息
  3. 一系列 ResidualAttentionBlock 层,用于上下文处理
  4. 最终的层归一化

编码器使用自注意力机制来处理音频特征的上下文,使其能够捕捉整个音频序列的模式。

来源: whisper/model.py174-204

文本解码器

TextDecoder 通过关注编码后的音频特征和先前生成的标记来生成文本标记。它由以下部分组成:

  1. 输入标记的标记和位置嵌入
  2. 一系列具有自注意力和交叉注意力的 ResidualAttentionBlock
  3. 最终的词汇对数(logits)投影

与编码器不同,解码器包含因果注意力掩码,以确保预测仅依赖于先前生成的标记。解码器还具有关注编码器生成的音频特征的交叉注意力层。

来源: whisper/model.py207-249

注意力机制

编码器和解码器的核心构建块是 ResidualAttentionBlock,它包含:

  1. 带层归一化和残差连接的自注意力
  2. 带层归一化和残差连接的交叉注意力(仅解码器)
  3. 带层归一化和残差连接的前馈网络

MultiHeadAttention 模块实现带有多头缩放点积注意力

  1. 输入被投影到查询、键和值表示
  2. 注意力使用缩放点积公式计算
  3. 注意力分数被掩码(用于解码器的因果注意力)
  4. 值根据注意力分数加权并投影到输出

模型支持标准注意力实现以及 PyTorch 优化的 scaled_dot_product_attention(如果可用)。

来源: whisper/model.py142-171 whisper/model.py81-139

KV 缓存

为了提高推理效率,Whisper 实现了一种键值缓存机制,该机制存储了注意力层先前计算的键值投影。这避免了逐个生成标记时的冗余计算。

缓存机制是通过对键值投影模块使用前向钩子来实现的,这些钩子会保存中间张量以供重用。

来源: whisper/model.py310-341

模型集成

Whisper 类集成了编码器和解码器组件。

Whisper 类提供了以下方法:

  • 通过 embed_audio 对音频进行编码器处理
  • 使用 logits 计算解码器对数
  • 通过 forward 进行端到端处理
  • 使用 detect_language 检测说话语言
  • 使用 transcribe 转录完整音频
  • 使用 decode 解码音频特征

这种设计在共享相同底层架构的同时,为模型用于不同任务提供了灵活性。

来源: whisper/model.py252-307

实现细节

Whisper 包含一些用于数值稳定性和性能的实现细节。

  1. 混合精度支持:自定义 LayerNormLinearConv1d 层,可尊重输入张量的数据类型。
  2. 优化注意力:支持 PyTorch 的优化 scaled_dot_product_attention(如果可用)。
  3. 多语言检测:根据词汇量大小标识模型是否为多语言的属性。
  4. 对齐头:使用特定的注意力头对时间对齐进行特殊处理。

此外,模型还包含在需要时禁用优化注意力实现的实用程序,以及根据词汇量大小确定模型是否为多语言的方法。

来源: whisper/model.py39-60 whisper/model.py15-22 whisper/model.py297-304