本页面介绍了labml_nn库中实现的高级Transformer架构,这些架构超越了标准编码器-解码器或自回归Transformer。这些专用变体引入了架构修改,以增强Transformer模型的特定功能。有关基本Transformer架构的信息,请参阅基本Transformer模型。
该仓库实现了几种专用的Transformer变体,每种都旨在解决标准Transformer的特定局限性
来源:labml_nn/transformers/switch/__init__.py1-39 labml_nn/transformers/feedback/__init__.py1-40 labml_nn/transformers/xl/readme.md1-24
Switch Transformer实现了一种条件计算形式,其中每个令牌在每个Transformer层中只路由到一个前馈网络(FFN)中的一个。
来源:labml_nn/transformers/switch/__init__.py48-165 labml_nn/transformers/switch/__init__.py168-212
SwitchFeedForward:该模块根据路由网络将令牌路由到不同的专家(FFN)。
self.switch)确定每个专家的概率SwitchTransformerLayer:通过用SwitchFeedForward替换标准FFN来扩展标准Transformer层。
SwitchTransformer:堆叠多个SwitchTransformerLayer,并带有最终的层归一化。
来源:labml_nn/transformers/switch/__init__.py48-165 labml_nn/transformers/switch/__init__.py168-212 labml_nn/transformers/switch/__init__.py215-239
Switch Transformer使用路由机制将每个令牌精确地定向到一个专家FFN
Router Network: Linear(d_model, n_experts) -> Softmax
Routing Decision: argmax(router_output)
每个专家都有一个基于批大小的容量限制
expert_capacity = capacity_factor * (tokens_per_batch / n_experts)
该实现包含通过附加损失项实现的负载均衡机制,以确保令牌在专家之间均匀分布
load_balancing_loss = n_experts * (route_frac * route_prob).sum()
其中
route_frac是路由到每个令牌的分数route_prob是每个专家的平均路由概率来源:labml_nn/transformers/switch/__init__.py113-115 labml_nn/transformers/switch/experiment.py124-140
Feedback Transformer顺序处理令牌,而非并行处理,允许每个位置关注之前步骤所有层的输出。
来源:labml_nn/transformers/feedback/__init__.py54-195 labml_nn/transformers/feedback/__init__.py198-249 labml_nn/transformers/feedback/__init__.py252-310 labml_nn/transformers/feedback/__init__.py444-529
FeedbackAttention:计算当前令牌与前一个令牌的记忆向量之间的注意力。
FeedbackTransformerLayer:通过使用FeedbackAttention并顺序处理令牌来扩展标准Transformer层。
FeedbackTransformer:原始实现,将所有层输出的加权和作为记忆。
FeedbackTransformerKV:优化实现,在层之间共享键和值的权重。
Stack实现以改进内存管理来源:labml_nn/transformers/feedback/__init__.py54-195 labml_nn/transformers/feedback/__init__.py198-249 labml_nn/transformers/feedback/__init__.py252-310 labml_nn/transformers/feedback/__init__.py444-529
Feedback Transformer使用一个记忆向量,它是所有层对先前令牌的输出的加权和
memory = Σ(softmax(weights) * layer_outputs)
注意力分数使用相对位置编码计算
Q = query + query_pos_bias
K = key
K_pos = key_pos_embeddings
Key_pos_bias = key_pos_bias
Attention = (Q·K + Q·K_pos + query_pos_bias·K + key_pos_bias)
该模型逐个处理令牌,这会减慢训练速度(5-10倍),但由于缓存,推理速度可能更快。
来源:labml_nn/transformers/feedback/__init__.py115-155 labml_nn/transformers/feedback/__init__.py303
Transformer XL通过允许每个位置关注前一个片段来扩展注意力范围,超越了标准固定长度上下文的限制。
来源:labml_nn/transformers/xl/readme.md1-24 labml_nn/transformers/relative_mha.py1-8
Transformer XL实现使用相对位置编码来区分不同片段的位置。它缓存来自先前片段的激活,从而在保持计算效率的同时实现更长的有效上下文长度。
| 变体 | 关键创新 | 优点 | 缺点 | 训练 | 推理 |
|---|---|---|---|---|---|
| Switch Transformer | 条件路由到专家FFN | 更多参数,相同计算量 | 负载均衡挑战 | 与标准模型相似 | 与标准模型相似 |
| Feedback Transformer | 带记忆的顺序处理 | 关注前一层输出 | 顺序处理较慢 | 慢5-10倍 | 缓存后更快 |
| Transformer XL | 之前片段的记忆 | 更长的注意力范围 | 更复杂的注意力 | 与标准模型相似 | 与标准模型相似 |
来源:labml_nn/transformers/switch/__init__.py9-37 labml_nn/transformers/feedback/__init__.py14-34 labml_nn/transformers/xl/readme.md7-17
所有专用Transformer变体都建立在核心Transformer模块之上,可以通过替换特定组件与标准Transformer架构集成
来源:labml_nn/transformers/switch/__init__.py43-46 labml_nn/transformers/feedback/__init__.py48-53 labml_nn/transformers/switch/experiment.py182-192 labml_nn/transformers/feedback/experiment.py76-83
创建Switch Transformer
SwitchFeedForward模块,指定专家数量SwitchTransformerLayerSwitchTransformer主要参数
n_experts:专家FFN的数量capacity_factor:控制专家容量drop_tokens:是否丢弃超出容量的令牌is_scale_prob:是否根据路由概率缩放输出来源:labml_nn/transformers/switch/experiment.py182-192
创建Feedback Transformer
FeedbackAttention模块FeedbackTransformerLayerFeedbackTransformer或FeedbackTransformerKVFeedbackTransformerKV变体应优先选择,以获得更好的性能,因为它会预计算和缓存键和值。
来源:labml_nn/transformers/feedback/experiment.py76-83 labml_nn/transformers/feedback/experiment.py95-102
Switch Transformer:在令牌在专家之间均衡分布时表现最佳。capacity_factor和drop_tokens参数可用于调整路由效率。
Feedback Transformer:由于顺序处理,训练速度比标准Transformer慢5-10倍。然而,由于缓存,推理速度可以更快。FeedbackTransformerKV变体通过共享键和值的权重显著提高了性能。
Transformer XL:增加了一些内存开销来存储以前的片段激活,但提供了更长的有效上下文长度。
来源:labml_nn/transformers/switch/__init__.py17-29 labml_nn/transformers/feedback/__init__.py14-23 labml_nn/transformers/xl/readme.md7-17
这些专用Transformer变体各自解决了标准Transformer架构的特定局限性,在计算效率、参数效率和建模能力之间提供了不同的权衡。