菜单

Trainer 类

相关源文件

Trainer 类是 Hugging Face Transformers 库的核心组件,它为 PyTorch 模型提供了完整的训练和评估循环。通过处理常见的训练任务,如优化、梯度累积、分布式训练、混合精度、检查点保存和日志记录,它简化了 Transformer 模型的训练过程。

本文档介绍了 Trainer 类的架构、功能和用法。有关训练参数的信息,请参阅训练参数;有关使用回调函数自定义训练的信息,请参阅回调函数和扩展

架构概述

Trainer 类通过集成多个组件来协调整个训练过程

来源:src/transformers/trainer.py322-419 src/transformers/trainer.py459-496

核心组件

模型和配置

Trainer 需要一个 PyTorch 模型,通常是 transformers 库中的 PreTrainedModel。该模型在训练时实现前向传播并计算损失。Trainer 支持完整的模型训练、微调,甚至通过 model_init 函数创建模型。

模型可以是

  • 直接作为 PreTrainedModel 实例传入
  • 通过 model_init 函数创建
  • 封装在各种并行处理封装器中,例如 DeepSpeed 或 FSDP

来源:src/transformers/trainer.py482-526 src/transformers/trainer.py626-646

训练参数

Trainer 通过 TrainingArguments 实例进行配置,该实例控制训练过程的所有方面

类别参数示例
基础output_dirnum_train_epochsmax_steps
优化learning_rateweight_decayadam_beta1adam_beta2
批大小per_device_train_batch_sizeper_device_eval_batch_sizegradient_accumulation_steps
调度warmup_stepswarmup_ratiolr_scheduler_type
日志记录logging_dirlogging_stepslogging_strategy
评估eval_strategyeval_stepsmetric_for_best_model
保存save_strategysave_stepssave_total_limit
混合精度fp16bf16half_precision_backend
分布式训练local_rankdeepspeedfsdp

来源:src/transformers/training_args.py209-1061 src/transformers/trainer.py439-443

数据集和数据处理

Trainer 需要用于训练和评估的数据集

  • 数据集:训练器接受 PyTorch Dataset 对象、Hugging Face datasets 对象或任何提供训练/评估示例的可迭代对象
  • 数据整理器:将单个示例处理成批次,处理填充和其他预处理
  • 分词器/处理类:用于处理文本或其他输入的可选组件

来源:src/transformers/trainer.py604-618 src/transformers/trainer.py1778-1827 src/transformers/trainer.py1829-1870

回调函数

Trainer 使用回调系统,允许在不修改核心训练循环的情况下自定义训练过程

默认回调函数包括

  • DefaultFlowCallback:控制训练流程
  • ProgressCallback/PrinterCallback:显示训练进度
  • 各种日志服务集成回调函数

来源:src/transformers/trainer_callback.py35-156 src/transformers/trainer.py683-689 src/transformers/trainer_callback.py317-361

训练生命周期

初始化

Trainer 初始化时,它会

  1. 验证所有必需组件是否存在
  2. 在适当的设备上设置模型
  3. 如果未提供,则准备优化器和调度器
  4. 设置回调处理程序
  5. 初始化内存跟踪和日志记录

来源:src/transformers/trainer.py421-805

训练过程

训练过程由 train() 方法处理,该方法

训练循环处理

  • 批处理和设备放置
  • 前向和后向传播
  • 梯度累积
  • 梯度裁剪
  • 优化器步进
  • 学习率调度
  • 混合精度训练
  • 分布式训练

来源:src/transformers/trainer.py2115-2339

评估与预测

Trainer 提供了评估模型和生成预测的方法

  • evaluate():在评估数据集上评估模型并报告指标
  • predict():为数据集生成预测,不计算指标

两种方法都使用类似的评估循环,该循环:

  1. 禁用梯度计算
  2. 将模型设置为评估模式
  3. 迭代评估数据集
  4. 收集预测并可选地计算指标

来源:src/transformers/trainer.py2578-2748 src/transformers/trainer.py2750-2834

自定义选项

自定义训练循环

有几种方法可以自定义训练过程

  1. 自定义 compute_loss 方法:覆盖 compute_loss 以实现自定义损失计算
  2. 损失计算函数:在初始化时传入 compute_loss_func 参数
  3. 自定义 compute_metrics:指定一个 compute_metrics 函数来计算评估指标
  4. 自定义优化器:通过 optimizers 参数传入自定义优化器和调度器

来源:src/transformers/trainer.py3613-3693 src/transformers/trainer.py369-372

分布式训练

Trainer 支持多种分布式训练方法

方法配置
PyTorch DDP设置 local_rank 或使用 torchrun
DeepSpeed提供 deepspeed 配置文件/字典
完全分片数据并行 (FSDP)设置带选项的 fsdp 参数
XLA (TPU)在 TPU 上运行时自动检测

来源:src/transformers/trainer.py595-602 src/transformers/training_args.py485-557

混合精度训练

Trainer 支持各种精度优化技术

  • FP16:在训练参数中通过 fp16=True 启用
  • BF16:在训练参数中通过 bf16=True 启用
  • TF32:在 Ampere+ NVIDIA GPU 上可用
  • 混合精度后端:APEX、原生 AMP、CPU AMP

来源:src/transformers/trainer.py747-773 src/transformers/training_args.py391-415

常见使用模式

基本训练示例

使用自定义组件进行训练

性能考量

Trainer 实现了多种优化以提高训练性能

  1. 自动查找批大小:通过 auto_find_batch_size=True 自动找到最大的批大小
  2. 梯度累积:允许在有限的硬件上模拟更大的批大小
  3. 内存优化:诸如 torch_empty_cache_steps 等选项有助于解决 CUDA OOM 错误
  4. 梯度检查点:以计算换取内存节省
  5. 高效评估:支持通过 batch_eval_metrics=True 进行批处理指标计算

来源:src/transformers/trainer.py443-457 src/transformers/trainer.py747-773 src/transformers/training_args.py268-274

结论

Trainer 类为使用 PyTorch 训练 Transformer 模型提供了全面的解决方案。其模块化架构和广泛的自定义选项使其适用于各种训练场景,从简单的微调任务到复杂的分布式训练工作负载。

通过处理训练循环、优化器和评估的常见样板代码,Trainer 使研究人员和实践者能够专注于模型开发,而不是训练基础设施。