我有一个 JAX 函数,给定顺序和索引,从预定义的字典中选择一个多项式,如下所示:
poly_dict = {
(0, 0): lambda x, y, z: 1.,
(1, 0): lambda x, y, z: x,
(1, 1): lambda x, y, z: y,
(1, 2): lambda x, y, z: z,
(2, 0): lambda x, y, z: x*x,
(2, 1): lambda x, y, z: y*y,
(2, 2): lambda x, y, z: z*z,
(2, 3): lambda x, y, z: x*y,
(2, 4): lambda x, y, z: y*z,
(2, 5): lambda x, y, z: z*x
}
def poly_func(order: int, index: int):
try:
return poly_dict[(order, index)]
except KeyError:
print("(order, index) must be a key in poly_dict!")
return
现在我想jit
poly_func()
,但它给出了一个错误TypeError: unhashable type: 'DynamicJaxprTracer'
而且,如果我只是做
def poly_func(order: int, index: int):
return poly_dict[(order, index)]
它仍然给出相同的错误。有办法解决这个问题吗?
您的方法有两个问题:首先,跟踪值不能用于索引到 Python 集合,如字典或列表。其次,JIT 编译的函数只能返回数组值,不能返回函数。
考虑到这一点,不可能使您建议的功能与 JIT 等 JAX 转换一起工作。但是您可以修改您的方法以使用列表而不是函数字典,然后使用
lax.switch
动态选择您调用的函数。它可能是这样的:
def get_index(order, index):
return order * 5 + index
poly_list = 16 * [lambda x, y, z: 0.0] # place-holder function
for key, val in poly_dict.items():
poly_list[get_index(*key)] = poly_dict[key]
def eval_poly_func(order: int, index: int, args):
ind = get_index(order, index)
return jax.lax.switch(ind, poly_list, *args)
result = jax.jit(eval_poly_func)(2, 0, (5.0, 6.0, 7.0))
print(result)
# 25.0