假设我有两个数组:
z = jnp.array([[5.55751118],
[5.18212974],
[4.35981727],
[3.4559711 ],
[3.35750248],
[2.65199945],
[2.02298999],
[1.59444971],
[0.80865185],
[0.77579791]])
z1 = jnp.array([[ 1.58559484],
[ 3.79094097],
[-0.52712522],
[-1.0178286 ],
[-3.51076985],
[ 1.30108161],
[-1.29824303],
[-0.19209007],
[ 0.37451138],
[-2.33619987]])
我想从数组 z 中的第一行开始,找到第二个矩阵中第二个值在该值的阈值内的位置。
没有@jit的示例:我想返回数组 z1 的最后一个索引。值应为 -3.51x
init = z[0]
distance = 2.6
new = init - distance
def test():
idx = z>=new
val = z1[jnp.where(idx)[0][-1]]
return val
test()
使用 JIT 时(根据较大规模模型的需要)
init = z[0]
distance = 2.6
new = init - distance
@jit
def test():
idx = z>=new
val = z1[jnp.where(idx)[0][-1]]
return val
test()
产生此错误:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function test at /var/folders/ss/pfgdfm2x7_s4cyw2v0b_t7q80000gn/T/ipykernel_85273/75296347.py:9 for jit. This value became a tracer due to JAX operations on these lines:
operation a:bool[10,1] = ge b c
from line /var/folders/ss/pfgdfm2x7_s4cyw2v0b_t7q80000gn/T/ipykernel_85273/75296347.py:11:10 (test)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
问题在于
jnp.where
返回动态大小的数组,而像 jit
这样的 JAX 转换与动态大小的数组不兼容(请参阅 JAX Sharp Bits:动态形状)。您可以将 size
参数传递给 jnp.where
以使结果静态调整大小。由于我们不知道会返回多少个元素,因此我们可以选择最大可能的返回元素数,即idx.shape[0]
。由于结果将用零填充,因此最大索引将给出您要查找的内容:
@jit
def test():
idx = z>=new
val = z1[jnp.where(idx, size=idx.shape[0])[0].max()]
return val
test()