本页面介绍了由 keras.ops.nn 模块提供的神经网络特有操作,该模块在 Keras 3 中提供了一套全面的与后端无关的神经网络操作。有关通用 NumPy 兼容操作的信息,请参阅 NumPy 兼容操作。
神经网络操作是构建和训练深度学习模型必不可少的专用函数。keras.ops.nn 模块为所有支持的后端(TensorFlow、JAX、PyTorch、NumPy 和 OpenVINO)的这些操作提供了统一的 API。这些操作符合 Keras 3 的多后端架构,如下所示
来源
keras.ops.nn 模块提供了以下类别的操作
来源
激活函数为神经网络引入非线性。Keras 提供了多种激活函数
| 激活函数 | 描述 | 定义 |
|---|---|---|
relu | 修正线性单元 | f(x) = max(0, x) |
relu6 | 上限为 6 的 ReLU | f(x) = min(max(0, x), 6) |
sigmoid | Sigmoid 激活函数 | f(x) = 1 / (1 + exp(-x)) |
tanh | 双曲正切 | f(x) = tanh(x) |
softmax | Softmax 归一化 | f(x) = exp(x) / sum(exp(x)) |
log_softmax | Softmax 的对数 | f(x) = log(softmax(x)) |
silu / swish | Sigmoid 线性单元 | f(x) = x * sigmoid(x) |
gelu | 高斯误差线性单元 | f(x) = x * P(X <= x), where P(X) ~ N(0,1) |
elu | 指数线性单元 | f(x) = x if x > 0 else alpha * (exp(x) - 1) |
selu | 缩放指数线性单元 | ELU 的缩放版本 |
leaky_relu | Leaky ReLU | f(x) = x if x > 0 else alpha * x |
hard_sigmoid | 硬 Sigmoid | Sigmoid 的分段线性近似 |
hard_silu / hard_swish | 硬 SiLU/Swish | SiLU 的分段线性近似 |
softplus | Softplus 激活函数 | f(x) = log(exp(x) + 1) |
softsign | Softsign 激活函数 | f(x) = x / (abs(x) + 1) |
高级激活函数包括
| 激活函数 | 描述 |
|---|---|
celu | 连续可微分 ELU |
glu | 门控线性单元 |
sparsemax | Softmax 的稀疏替代方案 |
log_sigmoid | Sigmoid 的对数 |
sparse_sigmoid | 分段线性 Sigmoid |
sparse_plus | 具有稀疏特性的替代激活函数 |
squareplus | ReLU 的平滑近似 |
hard_shrink | 硬收缩函数 |
soft_shrink | 软收缩函数 |
tanh_shrink | Tanh 收缩函数 |
hard_tanh | 硬 tanh 激活函数 |
threshold | 简单阈值函数 |
来源
每个激活函数都遵循相似的结构,包括一个实现操作的类和一个公开操作的函数
来源
relu 函数的示例实现
来源
池化操作通过对局部区域应用函数来减小空间维度
| 操作 | 描述 |
|---|---|
max_pool | 最大值池化 |
average_pool | 平均值池化 |
这些操作支持任意维度输入(1D、2D、3D),并处理 channels_first 和 channels_last 两种数据格式。
来源
卷积操作是神经网络处理空间数据的基本组成部分
| 操作 | 描述 |
|---|---|
conv | 标准卷积操作 |
depthwise_conv | 深度卷积(对每个输入通道单独应用滤波器) |
separable_conv | 可分离卷积(深度卷积 + 逐点卷积) |
conv_transpose | 用于上采样的转置卷积 |
这些操作支持可变维度(1D、2D、3D)、步长、填充、扩张,以及 channels_first 和 channels_last 两种数据格式。
来源
归一化操作有助于稳定和加速训练
| 操作 | 描述 |
|---|---|
batch_normalization | 根据批次统计数据对激活进行归一化 |
rms_normalization | 均方根归一化 |
normalize | 沿指定轴对张量进行归一化 |
来源
损失函数计算预测值和目标值之间的差异
| 损失函数 | 描述 |
|---|---|
binary_crossentropy | 用于二元分类问题 |
categorical_crossentropy | 用于具有独热编码目标的多元分类 |
sparse_categorical_crossentropy | 用于具有整数目标的多元分类 |
ctc_loss | 用于序列问题的连接主义时间分类损失 |
来源
编码操作将类别数据转换为适合神经网络的格式
| 操作 | 描述 |
|---|---|
one_hot | 将整数转换为独热编码向量 |
multi_hot | 将多个整数转换为单个多热编码向量 |
这两种操作都支持可选的稀疏张量输出,以提高内存效率。
来源
| 操作 | 描述 |
|---|---|
moments | 计算张量的均值和方差 |
dot_product_attention | 实现 Transformer 模型的注意力机制 |
ctc_decode | 从基于 CTC 的模型解码输出 |
psnr | 计算图像质量的峰值信噪比 |
polar | 将复数值表示转换为极坐标 |
来源
当你调用诸如 keras.ops.nn.relu(x) 这样的神经网络操作时,会发生以下过程
来源
神经网络操作可以直接用于自定义层和模型中,提供了一个灵活的底层 API,与 Keras 的高层 API 相辅相成。