假设我有一个像这样的 JAX 程序
def f(x: jnp.array) -> jnp.array:
...
def g(x: jnp.array) -> jnp.array:
# use f lots of times
# do other stuff
...
其中
f
和 g
完全 jit
兼容,我应该做什么 jit
?
g
并让JAX/XLA优化f
的多种用途?f
但不是g
?f
和 g
?为什么?
我试图理解
jit
的作用。例如,jit
和XlaBuilder.Build
之间的区别,jit
是否更多是Python/JAX的东西,并且在一般XLA使用中是必需的,以及jit
模式适用于XLA的哪些其他用途。
我还问过这个 XLA 特定但非常相关的问题。
这个问题一般来说很难回答,正确的答案会根据代码的细节而有所不同。但总的来说:包含在
jit
中的任何操作序列都将转换为 HLO 并传递给 XLA 编译器。
JIT 编译有几个好处:
JIT 也有缺点:
考虑到这一点,选择编译的哪些部分进行 JIT 编译就需要根据代码的详细信息来平衡这些成本和收益。
在您的情况下,如果
g
足够短,编译成本不会过于繁重,您应该将其包装在 JIT 中。如果它更长并且包含对 f
的多次调用,且输入具有相同的形状和数据类型,则可能有利于 JIT 编译 f
,但对 g
没有好处。