Trainer 类是 Hugging Face Transformers 库的核心组件,它为 PyTorch 模型提供了完整的训练和评估循环。通过处理常见的训练任务,如优化、梯度累积、分布式训练、混合精度、检查点保存和日志记录,它简化了 Transformer 模型的训练过程。
本文档介绍了 Trainer 类的架构、功能和用法。有关训练参数的信息,请参阅训练参数;有关使用回调函数自定义训练的信息,请参阅回调函数和扩展。
Trainer 类通过集成多个组件来协调整个训练过程
来源:src/transformers/trainer.py322-419 src/transformers/trainer.py459-496
Trainer 需要一个 PyTorch 模型,通常是 transformers 库中的 PreTrainedModel。该模型在训练时实现前向传播并计算损失。Trainer 支持完整的模型训练、微调,甚至通过 model_init 函数创建模型。
模型可以是
PreTrainedModel 实例传入model_init 函数创建来源:src/transformers/trainer.py482-526 src/transformers/trainer.py626-646
Trainer 通过 TrainingArguments 实例进行配置,该实例控制训练过程的所有方面
| 类别 | 参数示例 |
|---|---|
| 基础 | output_dir、num_train_epochs、max_steps |
| 优化 | learning_rate、weight_decay、adam_beta1、adam_beta2 |
| 批大小 | per_device_train_batch_size、per_device_eval_batch_size、gradient_accumulation_steps |
| 调度 | warmup_steps、warmup_ratio、lr_scheduler_type |
| 日志记录 | logging_dir、logging_steps、logging_strategy |
| 评估 | eval_strategy、eval_steps、metric_for_best_model |
| 保存 | save_strategy、save_steps、save_total_limit |
| 混合精度 | fp16、bf16、half_precision_backend |
| 分布式训练 | local_rank、deepspeed、fsdp |
来源:src/transformers/training_args.py209-1061 src/transformers/trainer.py439-443
Trainer 需要用于训练和评估的数据集
Dataset 对象、Hugging Face datasets 对象或任何提供训练/评估示例的可迭代对象来源:src/transformers/trainer.py604-618 src/transformers/trainer.py1778-1827 src/transformers/trainer.py1829-1870
Trainer 使用回调系统,允许在不修改核心训练循环的情况下自定义训练过程
默认回调函数包括
DefaultFlowCallback:控制训练流程ProgressCallback/PrinterCallback:显示训练进度来源:src/transformers/trainer_callback.py35-156 src/transformers/trainer.py683-689 src/transformers/trainer_callback.py317-361
当 Trainer 初始化时,它会
来源:src/transformers/trainer.py421-805
训练过程由 train() 方法处理,该方法
训练循环处理
来源:src/transformers/trainer.py2115-2339
Trainer 提供了评估模型和生成预测的方法
evaluate():在评估数据集上评估模型并报告指标predict():为数据集生成预测,不计算指标两种方法都使用类似的评估循环,该循环:
来源:src/transformers/trainer.py2578-2748 src/transformers/trainer.py2750-2834
有几种方法可以自定义训练过程
compute_loss 方法:覆盖 compute_loss 以实现自定义损失计算compute_loss_func 参数compute_metrics:指定一个 compute_metrics 函数来计算评估指标optimizers 参数传入自定义优化器和调度器来源:src/transformers/trainer.py3613-3693 src/transformers/trainer.py369-372
Trainer 支持多种分布式训练方法
| 方法 | 配置 |
|---|---|
| PyTorch DDP | 设置 local_rank 或使用 torchrun |
| DeepSpeed | 提供 deepspeed 配置文件/字典 |
| 完全分片数据并行 (FSDP) | 设置带选项的 fsdp 参数 |
| XLA (TPU) | 在 TPU 上运行时自动检测 |
来源:src/transformers/trainer.py595-602 src/transformers/training_args.py485-557
Trainer 支持各种精度优化技术
fp16=True 启用bf16=True 启用来源:src/transformers/trainer.py747-773 src/transformers/training_args.py391-415
Trainer 实现了多种优化以提高训练性能
auto_find_batch_size=True 自动找到最大的批大小torch_empty_cache_steps 等选项有助于解决 CUDA OOM 错误batch_eval_metrics=True 进行批处理指标计算来源:src/transformers/trainer.py443-457 src/transformers/trainer.py747-773 src/transformers/training_args.py268-274
Trainer 类为使用 PyTorch 训练 Transformer 模型提供了全面的解决方案。其模块化架构和广泛的自定义选项使其适用于各种训练场景,从简单的微调任务到复杂的分布式训练工作负载。
通过处理训练循环、优化器和评估的常见样板代码,Trainer 使研究人员和实践者能够专注于模型开发,而不是训练基础设施。