菜单

神经网络操作

相关源文件

本页面介绍了由 keras.ops.nn 模块提供的神经网络特有操作,该模块在 Keras 3 中提供了一套全面的与后端无关的神经网络操作。有关通用 NumPy 兼容操作的信息,请参阅 NumPy 兼容操作

概述

神经网络操作是构建和训练深度学习模型必不可少的专用函数。keras.ops.nn 模块为所有支持的后端(TensorFlow、JAX、PyTorch、NumPy 和 OpenVINO)的这些操作提供了统一的 API。这些操作符合 Keras 3 的多后端架构,如下所示

来源

操作类别

keras.ops.nn 模块提供了以下类别的操作

来源

激活函数

激活函数为神经网络引入非线性。Keras 提供了多种激活函数

激活函数描述定义
relu修正线性单元f(x) = max(0, x)
relu6上限为 6 的 ReLUf(x) = min(max(0, x), 6)
sigmoidSigmoid 激活函数f(x) = 1 / (1 + exp(-x))
tanh双曲正切f(x) = tanh(x)
softmaxSoftmax 归一化f(x) = exp(x) / sum(exp(x))
log_softmaxSoftmax 的对数f(x) = log(softmax(x))
silu / swishSigmoid 线性单元f(x) = x * sigmoid(x)
gelu高斯误差线性单元f(x) = x * P(X <= x), where P(X) ~ N(0,1)
elu指数线性单元f(x) = x if x > 0 else alpha * (exp(x) - 1)
selu缩放指数线性单元ELU 的缩放版本
leaky_reluLeaky ReLUf(x) = x if x > 0 else alpha * x
hard_sigmoid硬 SigmoidSigmoid 的分段线性近似
hard_silu / hard_swish硬 SiLU/SwishSiLU 的分段线性近似
softplusSoftplus 激活函数f(x) = log(exp(x) + 1)
softsignSoftsign 激活函数f(x) = x / (abs(x) + 1)

高级激活函数包括

激活函数描述
celu连续可微分 ELU
glu门控线性单元
sparsemaxSoftmax 的稀疏替代方案
log_sigmoidSigmoid 的对数
sparse_sigmoid分段线性 Sigmoid
sparse_plus具有稀疏特性的替代激活函数
squareplusReLU 的平滑近似
hard_shrink硬收缩函数
soft_shrink软收缩函数
tanh_shrinkTanh 收缩函数
hard_tanh硬 tanh 激活函数
threshold简单阈值函数

来源

操作结构

每个激活函数都遵循相似的结构,包括一个实现操作的类和一个公开操作的函数

来源

relu 函数的示例实现

来源

池化操作

池化操作通过对局部区域应用函数来减小空间维度

操作描述
max_pool最大值池化
average_pool平均值池化

这些操作支持任意维度输入(1D、2D、3D),并处理 channels_firstchannels_last 两种数据格式。

来源

卷积操作

卷积操作是神经网络处理空间数据的基本组成部分

操作描述
conv标准卷积操作
depthwise_conv深度卷积(对每个输入通道单独应用滤波器)
separable_conv可分离卷积(深度卷积 + 逐点卷积)
conv_transpose用于上采样的转置卷积

这些操作支持可变维度(1D、2D、3D)、步长、填充、扩张,以及 channels_firstchannels_last 两种数据格式。

来源

归一化操作

归一化操作有助于稳定和加速训练

操作描述
batch_normalization根据批次统计数据对激活进行归一化
rms_normalization均方根归一化
normalize沿指定轴对张量进行归一化

来源

损失函数

损失函数计算预测值和目标值之间的差异

损失函数描述
binary_crossentropy用于二元分类问题
categorical_crossentropy用于具有独热编码目标的多元分类
sparse_categorical_crossentropy用于具有整数目标的多元分类
ctc_loss用于序列问题的连接主义时间分类损失

来源

编码操作

编码操作将类别数据转换为适合神经网络的格式

操作描述
one_hot将整数转换为独热编码向量
multi_hot将多个整数转换为单个多热编码向量

这两种操作都支持可选的稀疏张量输出,以提高内存效率。

来源

其他实用操作

操作描述
moments计算张量的均值和方差
dot_product_attention实现 Transformer 模型的注意力机制
ctc_decode从基于 CTC 的模型解码输出
psnr计算图像质量的峰值信噪比
polar将复数值表示转换为极坐标

来源

操作流程和后端调度

当你调用诸如 keras.ops.nn.relu(x) 这样的神经网络操作时,会发生以下过程

来源

使用示例

基本激活

卷积操作

模型中的组合操作

神经网络操作可以直接用于自定义层和模型中,提供了一个灵活的底层 API,与 Keras 的高层 API 相辅相成。