本页全面概述了 Keras 模型构建和训练功能。它涵盖了核心模型 API(函数式、序贯式、子类化)、模型编译、训练工作流以及相关主题。有关模型保存和加载的信息,请参阅保存和加载模型。
Keras 提供了灵活且强大的 API,用于构建和训练深度学习模型。其核心在于,Keras 模型被设计为可组合的,其中层是基本构建块。
来源
Keras 提供了三种主要的模型构建方式
函数式 API 将模型定义为层的有向无环图 (DAG)。这种方法允许构建具有多个输入、多个输出和共享层的复杂模型架构。
主要功能
使用示例
来源
当模型架构是真正的线性结构时,即层堆栈中每个层只有一个输入张量和一个输出张量时,序贯式 API 提供了一种简化的模型构建方式。
主要功能
.add() 逐层构建使用示例
来源
模型子类化通过在 call 方法中实现自定义前向传播,提供了最大的灵活性。
主要功能
使用示例
来源
Keras 模型的输入是使用 Input 类创建的,该类定义了输入张量的形状和数据类型。
关键输入属性
shape:输入张量的形状,不包括批次维度batch_size:可选的固定批次大小(通常保留为 None)dtype:输入的数据类型name:输入的可选名称来源
Keras 支持具有多种配置的多输入和多输出的复杂模型架构。
模型输入和输出可以结构化为
来源
模型编译是您通过指定以下内容来定义训练过程的地方
损失函数衡量模型表现的优劣。Keras 提供了多种指定损失的方式
| 模型输出类型 | 损失指定 | 示例 |
|---|---|---|
| 单输出 | 字符串或损失对象 | loss='mse' |
| 单输出 | 含一个元素的列表 | loss=['mse'] |
| 多输出(列表) | 损失列表 | loss=['mse', 'binary_crossentropy'] |
| 多输出(字典) | 将输出名称映射到损失的字典 | loss={'output_a': 'mse', 'output_b': 'bce'} |
CompileLoss 类处理将损失函数应用于各种输出结构(包括嵌套输出)的复杂性。它根据输出类型解析适当的损失,并根据指定的权重进行组合。
来源
指标衡量模型在训练和评估期间的性能,但不影响优化。它们的指定方式与损失类似
| 模型输出类型 | 指标指定 | 示例 |
|---|---|---|
| 单输出 | 指标列表 | metrics=['accuracy'] |
| 多输出(列表) | 指标列表的列表 | metrics=[['mae'], ['accuracy']] |
| 多输出(字典) | 将输出名称映射到指标列表的字典 | metrics={'output_a': ['mae'], 'output_b': ['accuracy']} |
CompileMetrics 类管理复杂输出结构中的指标,在训练期间保持正确的状态。
来源
对于多输出模型,您可以为不同的损失分量分配不同的权重
这允许控制每个输出对总损失的贡献程度。
来源
训练过程遵循以下通用工作流
Keras 模型可以接受各种格式的训练数据
| 数据格式 | 描述 | 兼容性 |
|---|---|---|
| NumPy 数组 | 直接以数组形式输入 | 所有后端 |
| TensorFlow 数据集 | tf.data.Dataset 对象 | TensorFlow 后端 |
| PyTorch DataLoader | PyTorch 数据加载器 | PyTorch 后端 |
| 字典输入 | 与输入名称匹配的关键字参数 | 所有模型类型 |
| 列表/元组输入 | 与模型输入匹配的位置参数 | 所有模型类型 |
来源
Keras 在训练期间为复杂数据结构提供了强大的支持
tree 模块在处理这些嵌套结构方面起着关键作用,它提供了诸如 flatten、map_structure 和 assert_same_structure 等函数来处理复杂数据类型。
来源
Keras 提供了在训练期间进行内存优化的机制,特别是对于大型模型而言
RematScope(重新物化范围)是一种通过丢弃中间激活并在反向传播期间需要时重新计算它们,从而以计算换取内存的技术。这对于训练大型模型或使用大批次大小特别有用。
来源
Keras 支持使用后端特定实现进行多 GPU 分布式训练
多 GPU 训练的方法因后端而异
每个模型副本处理不同的数据批次,梯度在设备间同步以更新全局模型。
来源
Keras 模型提供了多种用于检查和调试的方法
model.summary() 显示层信息、输出形状和参数数量model.layers 或 model.get_layer(name) 访问单个层model.weights 或 model.trainable_weights 查看模型权重model.input_shape 和 model.output_shape 访问形状这些工具提供了模型架构的透明度,并有助于在开发过程中识别潜在问题。
来源