菜单

NumPy兼容操作

相关源文件

本wiki页面介绍了 Keras 3 中的 NumPy 兼容操作系统,它提供了一个后端无关的 API,用于在所有支持的后端(TensorFlow、JAX、PyTorch、NumPy 和 OpenVINO)上执行类 NumPy 操作。这些操作通过 keras.ops.numpy 命名空间(通常导入为 knp)提供。

有关神经网络特定操作,请参阅神经网络操作

概述

NumPy 兼容操作系统允许开发者编写在所有后端上一致运行的代码,从而实现可移植的模型定义。这些操作模仿 NumPy 的行为和 API,同时支持符号张量(模型构建期间)和具体张量(即时执行期间)。

来源

设计与实现

NumPy 兼容操作系统由三个主要组件组成

  1. 操作类:每个 NumPy 操作都定义为 Operation 的子类,实现 call()compute_output_spec() 方法。
  2. API 函数:每个操作都有一个对应的 API 函数,使用 @keras_export 装饰器进行装饰。
  3. 后端实现:每个后端都在其自己的 numpy.py 模块中实现具体功能。

来源

执行流程

当调用 NumPy 操作时,它遵循以下执行流程

  1. API 入口点:用户调用诸如 keras.ops.numpy.add(x1, x2) 的函数
  2. 符号检查:函数检查任何输入是否为符号张量(KerasTensor
  3. 分派路径:
    • 如果为符号:调用 Operation.symbolic_call(),它会调用 compute_output_spec() 来确定输出形状和数据类型
    • 如果为即时:通过 backend.numpy.add(x1, x2) 调用后端特定实现

让我们以 add 操作为例,看看它是如何工作的

来源

可用操作

NumPy 兼容操作系统提供了广泛的操作,分类如下

数组创建操作

  • array:创建数组
  • ones:创建全一数组
  • zeros:创建全零数组
  • eye:创建单位矩阵
  • arange:在给定区间内创建等间隔值
  • linspace:在指定范围内创建等间隔值
  • full:创建指定形状和值的常量数组
  • meshgrid:从坐标向量创建坐标矩阵

数组操作

  • reshape:在不改变数据的情况下为数组赋予新形状
  • transpose:对数组维度进行转置
  • squeeze:从数组形状中移除单维度条目
  • expand_dims:用新轴扩展数组的形状
  • concatenate:沿着现有轴连接数组
  • stack:沿着新轴堆叠数组
  • split:将数组拆分为多个子数组
  • tile:通过重复输入数组来构造数组
  • pad:填充数组

数学运算

  • add, subtract, multiply, divide:基本算术运算
  • matmul:矩阵乘法
  • power:将元素提升到指定幂
  • exp, log, sqrt:常见数学函数
  • sin, cos, tan:三角函数
  • mean, sum, max, min:约简操作
  • clip:将数组值剪辑到指定范围
  • abs, absolute:计算绝对值

比较操作

  • equal, not_equal:元素级相等比较
  • greater, greater_equal:元素级大于/大于等于比较
  • less, less_equal:元素级小于/小于等于比较
  • logical_and, logical_or, logical_not:逻辑运算

线性代数操作

  • dot:两个数组的点积
  • tensordot:指定轴上的张量收缩
  • einsum:爱因斯坦求和约定
  • inner:数组的内积
  • outer:计算两个向量的外积
  • trace:沿着数组对角线求和

下面是一些常用操作及其签名的表格

操作签名描述
addadd(x1, x2)逐元素加法
subtractsubtract(x1, x2)元素级减法
multiplymultiply(x1, x2)元素级乘法
dividedivide(x1, x2)元素级除法
matmulmatmul(x1, x2)矩阵乘法
meanmean(x, axis=None, keepdims=False)沿着指定轴的平均值
sumsum(x, axis=None, keepdims=False)沿着指定轴求和
maxmax(x, axis=None, keepdims=False, initial=None)沿着轴的最大值
reshapereshape(x, newshape)不改变数据的情况下重塑数组
transposetranspose(x, axes=None)置换数组维度

来源

使用 NumPy 操作

基本用法

NumPy 兼容操作可以按如下方式导入和使用

符号张量的形状推断

NumPy 兼容操作的关键特性之一是它们能够与符号张量(由 KerasTensor 表示)一起工作。每个操作的 compute_output_spec 方法根据输入张量确定输出形状和数据类型。

例如,Add 操作通过广播输入形状来确定输出形状

该系统智能地处理动态形状,即使某些维度未知(表示为 None),也能使用广播规则确定结果形状。

来源

后端特定实现细节

虽然 API 在不同后端之间保持一致,但实现细节有所不同

  1. TensorFlow 后端:

    • 专门处理稀疏张量
    • 优化某些操作,如 einsum,通过自定义实现获得更好的性能
    • 尽可能使用硬件加速(例如,用于 int8 操作)
  2. JAX 后端:

    • 使用 JAX 的函数式范式
    • 处理精度和数值稳定性的特殊情况
    • 通过 JAX 的稀疏库支持稀疏张量
  3. PyTorch 后端:

    • 调整 PyTorch 的张量操作以匹配 NumPy 语义
    • 处理设备放置(CPU/GPU)
    • 解决 PyTorch 特定的限制(例如,类型支持)
  4. NumPy 后端:

    • 直接传递给 NumPy 操作
    • 处理类型转换和标准化
  5. OpenVINO 后端:

来源

实现架构

NumPy 兼容操作系统的实现架构包括

操作的构成

每个操作遵循以下模式

  1. 定义一个 Operation 子类,包含

    • __init__ 用于存储任何参数(可选)
    • call() 用于实现操作逻辑
    • compute_output_spec() 用于确定输出形状和数据类型
  2. 定义一个 API 函数,该函数

    • 使用 any_symbolic_tensors() 检查是否有任何输入是符号张量
    • 如果为符号,则调用操作的 symbolic_call() 方法
    • 如果为即时,则直接调用后端实现

这是 Add 操作的示例

来源

处理特殊情况

广播

NumPy 兼容操作系统实现了 NumPy 的广播规则,以处理不同形状的张量。这通过诸如 broadcast_shapes() 的实用函数进行管理,该函数在组合不同维度的张量时确定输出形状。

来源

动态形状

该系统支持动态形状,其中某些维度在图构建时可能未知

来源

类型处理

操作通过使用 dtypes.result_type() 来根据输入数据类型确定输出数据类型,从而处理不同的数据类型。这确保了不同后端之间的一致行为,并遵循 NumPy 的数据类型提升规则。

来源

性能考量

特定后端优化

不同的后端实现了各种优化

  1. TensorFlow:

    • 在可能的情况下,将 tf.einsum 与硬件加速结合使用
    • 实现操作的专用版本以获得更好的性能
    • 在适合融合操作时使用 tf.nn.bias_add
  2. JAX:

    • 利用 JAX 的 XLA 编译实现高效执行
    • 实现精度和数值稳定性的特殊处理
    • 使用 JAX 的硬件加速操作
  3. PyTorch:

    • 在可用时实现操作的专用 CUDA 版本
    • 处理 CPU/GPU 执行的设备放置
    • 使用 PyTorch 的优化操作

来源

硬件加速

NumPy 操作在可能的情况下会进行硬件加速

  • 矩阵乘法:使用硬件加速的矩阵乘法操作
  • 整数操作:在支持的硬件上,int8 操作会得到加速
  • 批处理操作:高效地分派到并行执行

例如,当满足正确条件时,TensorFlow 会使用硬件加速实现

来源

测试与兼容性

NumPy 兼容操作系统经过广泛测试,以确保不同后端之间的一致行为。测试验证了

  1. 静态形状推断
  2. 动态形状推断
  3. 数据类型处理
  4. 数值正确性
  5. 广播行为

某些操作可能并非在所有后端都得到完全支持。例如,OpenVINO 有一个尚未实现的操作排除列表

NumPyOneInputOpsCorrectnessTest::test_flip
NumPyOneInputOpsCorrectnessTest::test_meshgrid
NumPyTwoInputOpsCorrectnessTest::test_cross
NumPyTwoInputOpsCorrectnessTest::test_digitize

来源

结论

Keras 3 中的 NumPy 兼容操作系统提供了一个一致的、后端无关的 API,用于执行数组操作。它使开发者能够编写在多个后端上工作且支持符号执行和即时执行模式的可移植代码。该系统处理形状推断、数据类型提升和广播,同时利用每个后端的优化实现来提高性能。

通过模仿熟悉的 NumPy API,Keras 使数据科学家和机器学习从业者更容易从 NumPy 过渡到深度学习框架,而无需学习基本数组操作的新 API。