jax.jit 我的整个程序还是其中的一部分?

问题描述 投票:0回答:1

假设我有一个像这样的 JAX 程序

def f(x: jnp.array) -> jnp.array:
    ...

def g(x: jnp.array) -> jnp.array:
   # use f lots of times
   # do other stuff
   ...

其中

f
g
完全
jit
兼容,我应该做什么
jit

  1. g
    并让JAX/XLA优化
    f
    的多种用途?
  2. f
    但不是
    g
  3. f
    g

为什么?

我试图理解

jit
的作用。例如,
jit
XlaBuilder.Build
之间的区别,
jit
是否更多是Python/JAX的东西,并且在一般XLA使用中是必需的,以及
jit
模式适用于XLA的哪些其他用途。

我还问过这个 XLA 特定但非常相关的问题。

python jit jax
1个回答
0
投票

这个问题一般来说很难回答,正确的答案会根据代码的细节而有所不同。但总的来说:包含在

jit
中的任何操作序列都将转换为 HLO 并传递给 XLA 编译器。

JIT 编译有几个好处:

  • 编译器具有融合和/或删除函数中操作的逻辑,使整个计算更加高效。
  • 在运行时,对于每个 JIT 编译的函数调用,您只会产生一次 JAX 的几毫秒 python 调度开销,而不是在没有 JIT 的情况下每个操作一次。

JIT 也有缺点:

  • 编译不是免费的,一般来说,编译成本会随着 JIT 编译函数中的操作数量大约呈二次方增长。
  • 编译特定于输入数组的形状和数据类型,因此在不同大小的数组上调用已编译函数将产生重新编译的成本。

考虑到这一点,选择编译的哪些部分进行 JIT 编译就需要根据代码的详细信息来平衡这些成本和收益。

在您的情况下,如果

g
足够短,编译成本不会过于繁重,您应该将其包装在 JIT 中。如果它更长并且包含对
f
的多次调用,且输入具有相同的形状和数据类型,则可能有利于 JIT 编译
f
,但对
g
没有好处。

© www.soinside.com 2019 - 2024. All rights reserved.