我最近开始使用 JAX,作为 numpy 的交换,并且我没有太多考虑测试它是否真的比 numpy 更好,直到我开始探索矢量化,并找到这篇文章:https://www.programiz .com/python-programming/numpy/vectorization 给出了 numpy 的示例,并声称使用数组比循环更快。
我对这个说法没有任何问题,除了他们的例子显示了完全相反的东西!?在循环中迭代并将数字添加到列表中比在数组中执行更快?!循环时间的输出:
For loop time: 4.76837158203125e-06
和 vec 时间的输出:Vectorization time: 1.5020370483398438e-05
。那个 4.768e-06
明显小于:1.502e-05
。
这促使我尝试复制该实验。所以我写了这个简单的代码来比较 for 循环、numpy 和 jax:
import numpy as np
import jax.numpy as jnp
import time
array1 = [1, 2, 3, 4, 5]
for n in range(20):
start = time.time()
result = [1, 2, 3, 4, 5]
for i in range(len(result)):
result[i] += 10
end = time.time()
loop_time = end - start
print("For loop time:", loop_time, "array1: ", result)
array1 = np.array([1, 2, 3, 4, 5 ])
for n in range(20):
start = time.time()
result_np = array1 + 10
end = time.time()
vec_time = end - start
print("Vec time: ", vec_time, "result: ", result_np)
array1 = jnp.array([1, 2, 3, 4, 5 ])
for n in range(20):
start = time.time()
result_jax = array1 + 10
end = time.time()
jax_time = end - start
print("Jax vec time:", jax_time, "result: ", result_jax)
为了给 jax 均匀的机会,我在循环中运行了相同的操作 20 次,仅测量最后一次操作的时间。运行一次并不能真正使 JAX 和 numpy 受益,如果多次运行相同的操作,两者似乎都会优化。 (我有根据的猜测是 numpy 不需要分配内存,而 JAX 在幕后用它的编译器做了一些其他的魔法)。
我尝试了不同数量的循环、初始化等。但是我每次得到的输出:
For loop time: 1.9073486328125e-06 array1: [11, 12, 13, 14, 15]
Vec time: 2.384185791015625e-06 result: [11 12 13 14 15]
Jax vec time: 9.083747863769531e-05 result: [11 12 13 14 15]
这清楚地表明 JAX 远远落后于 numpy 和循环。
如果说 JAX 更糟糕,那就太牵强了,但这让我想知道:如果 JAX 在这样一个基本示例上没有表现得更好,那么我们如何利用它来超越 numpy 或简单循环呢? (顺便说一句,它在 CUDA 上运行)。你能推荐一些例子来尝试吗?
...总结:如果您正在对 CPU 上的各个数组操作进行微基准测试,您通常可以预期 NumPy 的性能优于 JAX,因为它的每次操作调度开销较低。如果您在 GPU 或 TPU 上运行代码,或者在 CPU 上对更复杂的 JIT 编译操作序列进行基准测试,则通常可以预期 JAX 的性能优于 NumPy。
您的基准脚本完全符合我们预期 JAX 比 NumPy 慢的情况。在实践中,这种调度开销并不那么重要,因为任何感兴趣的现实世界程序都会比仅仅添加五个整数要复杂得多。