本页面详细介绍了 Keras 3 中模型的编译、训练、评估和预测的使用方法。它涵盖了将构建好的模型转化为能够进行预测的已训练模型 Thus,我们在这里将这些模型进行训练,以便可以生成预测。如果您想了解有关模型构建的信息,请参阅模型构建 API。
在 Keras 中训练模型之前,您需要通过指定优化器、损失函数和评估指标来编译它。编译通过设置训练状态和计算图来配置模型以进行训练。
来源:keras/src/trainers/compile_utils.py98-121 keras/src/trainers/compile_utils.py124-410
compile方法接受几个关键参数
optimizer:训练期间使用的优化算法loss:训练期间要最小化的损失函数metrics:训练和评估期间要监控的评估指标loss_weights:多输出模型中不同输出的可选权重weighted_metrics:在训练期间使用样本权重计算的评估指标在编译过程中,Keras
CompileLoss和CompileMetrics实例当指定了损失和评估指标时,Keras 支持多种格式,具体取决于模型的输出结构
'mse'、'categorical_crossentropy')来源:keras/src/trainers/compile_utils.py412-819 keras/src/models/functional_test.py682-696
Keras 可以根据输出和目标的形状自动选择合适的损失函数
来源:keras/src/trainers/compile_utils.py48-58 keras/src/trainers/compile_utils.py95-121
训练工作流由fit方法处理,该方法以批处理方式处理数据、计算损失、更新权重并跟踪评估指标。
来源:keras/src/models/model_test.py683-736 keras/src/models/functional_test.py682-696
Keras 模型可以处理各种输入结构
在训练过程中,输入会被标准化以匹配模型期望的输入形状和类型。这种标准化包括
来源:keras/src/models/functional.py286-322 keras/src/models/functional.py240-285
Keras 为多输出模型提供了广泛的支持,允许为每个输出使用不同的损失函数和评估指标。
损失函数和评估指标的指定方式有以下几种
model.compile(loss=['mse', 'binary_crossentropy'])model.compile(loss={'output_a': 'mse', 'output_b': 'binary_crossentropy'})当使用字典时,输出名称可以与相应的损失函数和评估指标进行匹配。
来源:keras/src/models/model_test.py407-486 keras/src/models/functional_test.py682-736
评估工作流由evaluate方法处理,该方法在验证集或测试集上计算评估指标,而不更新模型权重。
在评估过程中,Keras
来源:keras/src/models/model_test.py722-729 keras/src/trainers/compile_utils.py358-402
predict方法在不计算损失或更新评估指标的情况下,对新数据生成预测。
predict的输出格式与模型输出的结构相匹配
来源:keras/src/models/model_test.py730-736 keras/src/models/functional_test.py96-125
两个关键类处理训练和评估期间的损失和评估指标计算
CompileLoss 处理损失函数的配置和计算,支持
loss_weights的加权损失此类负责构建和应用损失函数到不同输出的过程,聚合结果,并在训练期间跟踪各个损失值。
来源:keras/src/trainers/compile_utils.py412-819 keras/src/trainers/compile_utils_test.py239-488
CompileMetrics 处理评估指标的配置和计算,支持
此类负责管理训练和评估期间的评估指标状态更新,并为所有正在监控的评估指标提供聚合结果。
来源:keras/src/trainers/compile_utils.py124-410 keras/src/trainers/compile_utils_test.py16-237
Keras 支持不同模型类型的训练,每种模型都有其特定的特性
函数式模型支持具有多个输入和输出的复杂架构,包括字典输入/输出。它们具有明确定义的层有向无环图 (DAG)。
来源:keras/src/models/functional.py27-147 keras/src/models/functional_test.py407-541
Sequential 模型代表一个线性的层堆栈。它们更易于使用,但仅限于单输入、单输出架构。
来源: keras/src/models/sequential.py20-67 keras/src/models/sequential_test.py16-67
无论模型类型如何,训练流程都遵循以下步骤:
来源: keras/src/models/model_test.py682-736
Keras 3 支持多种后端(TensorFlow、JAX、PyTorch)进行训练,并提供抽象化的训练逻辑,可在不同后端上运行。
训练系统在内部处理特定于后端的实现细节,提供与所用后端无关的统一 API。这包括: