菜单

模型与训练

相关源文件

本页全面概述了 Keras 模型构建和训练功能。它涵盖了核心模型 API(函数式、序贯式、子类化)、模型编译、训练工作流以及相关主题。有关模型保存和加载的信息,请参阅保存和加载模型

Keras 中的模型架构概述

Keras 提供了灵活且强大的 API,用于构建和训练深度学习模型。其核心在于,Keras 模型被设计为可组合的,其中层是基本构建块。

来源

模型类型

Keras 提供了三种主要的模型构建方式

函数式 API

函数式 API 将模型定义为层的有向无环图 (DAG)。这种方法允许构建具有多个输入、多个输出和共享层的复杂模型架构。

主要功能

  • 在模型创建时定义明确的输入形状
  • 支持多输入和多输出
  • 支持复杂的非序贯层连接
  • 允许层共享和模型合并

使用示例

来源

序贯式 API

当模型架构是真正的线性结构时,即层堆栈中每个层只有一个输入张量和一个输出张量时,序贯式 API 提供了一种简化的模型构建方式。

主要功能

  • 适用于线性层堆栈的更简单 API
  • 可以通过 .add() 逐层构建
  • 内部转换为函数式模型
  • 支持即时执行和符号模式

使用示例

来源

子类化模型

模型子类化通过在 call 方法中实现自定义前向传播,提供了最大的灵活性。

主要功能

  • 自定义行为的最大灵活性
  • 动态计算图
  • 不易序列化(保存/加载能力有限)
  • 更命令式/类 PyTorch 的方法

使用示例

来源

模型构建

输入创建

Keras 模型的输入是使用 Input 类创建的,该类定义了输入张量的形状和数据类型。

关键输入属性

  • shape:输入张量的形状,不包括批次维度
  • batch_size:可选的固定批次大小(通常保留为 None
  • dtype:输入的数据类型
  • name:输入的可选名称

来源

多输入/多输出模型

Keras 支持具有多种配置的多输入和多输出的复杂模型架构。

模型输入和输出可以结构化为

  • 单个张量
  • 张量列表/元组
  • 将名称映射到张量的字典
  • 上述的嵌套结构

来源

模型编译

模型编译是您通过指定以下内容来定义训练过程的地方

  1. 损失函数
  2. 优化器
  3. 评估指标

损失函数

损失函数衡量模型表现的优劣。Keras 提供了多种指定损失的方式

模型输出类型损失指定示例
单输出字符串或损失对象loss='mse'
单输出含一个元素的列表loss=['mse']
多输出(列表)损失列表loss=['mse', 'binary_crossentropy']
多输出(字典)将输出名称映射到损失的字典loss={'output_a': 'mse', 'output_b': 'bce'}

CompileLoss 类处理将损失函数应用于各种输出结构(包括嵌套输出)的复杂性。它根据输出类型解析适当的损失,并根据指定的权重进行组合。

来源

评估指标

指标衡量模型在训练和评估期间的性能,但不影响优化。它们的指定方式与损失类似

模型输出类型指标指定示例
单输出指标列表metrics=['accuracy']
多输出(列表)指标列表的列表metrics=[['mae'], ['accuracy']]
多输出(字典)将输出名称映射到指标列表的字典metrics={'output_a': ['mae'], 'output_b': ['accuracy']}

CompileMetrics 类管理复杂输出结构中的指标,在训练期间保持正确的状态。

来源

损失权重

对于多输出模型,您可以为不同的损失分量分配不同的权重

这允许控制每个输出对总损失的贡献程度。

来源

模型训练

训练工作流程

训练过程遵循以下通用工作流

处理不同输入类型

Keras 模型可以接受各种格式的训练数据

数据格式描述兼容性
NumPy 数组直接以数组形式输入所有后端
TensorFlow 数据集tf.data.Dataset 对象TensorFlow 后端
PyTorch DataLoaderPyTorch 数据加载器PyTorch 后端
字典输入与输入名称匹配的关键字参数所有模型类型
列表/元组输入与模型输入匹配的位置参数所有模型类型

来源

使用字典和嵌套结构进行训练

Keras 在训练期间为复杂数据结构提供了强大的支持

tree 模块在处理这些嵌套结构方面起着关键作用,它提供了诸如 flattenmap_structureassert_same_structure 等函数来处理复杂数据类型。

来源

内存优化

Keras 提供了在训练期间进行内存优化的机制,特别是对于大型模型而言

RematScope(重新物化范围)是一种通过丢弃中间激活并在反向传播期间需要时重新计算它们,从而以计算换取内存的技术。这对于训练大型模型或使用大批次大小特别有用。

来源

多GPU训练

Keras 支持使用后端特定实现进行多 GPU 分布式训练

多 GPU 训练的方法因后端而异

  • PyTorch:使用 DistributedDataParallel 在设备间复制模型
  • JAX:使用 sharding 来分区数据和模型参数
  • TensorFlow:使用分布式策略来实现各种并行方法

每个模型副本处理不同的数据批次,梯度在设备间同步以更新全局模型。

来源

调试和模型信息

Keras 模型提供了多种用于检查和调试的方法

  1. 模型摘要model.summary() 显示层信息、输出形状和参数数量
  2. 层访问:通过 model.layersmodel.get_layer(name) 访问单个层
  3. 权重检查:使用 model.weightsmodel.trainable_weights 查看模型权重
  4. 输入/输出形状:通过 model.input_shapemodel.output_shape 访问形状

这些工具提供了模型架构的透明度,并有助于在开发过程中识别潜在问题。

来源