Jit 一个从字典中选择函数的 JAX 函数

问题描述 投票:0回答:1

我有一个 JAX 函数,给定顺序和索引,从预定义的字典中选择一个多项式,如下所示:

poly_dict = {
    (0, 0): lambda x, y, z: 1.,
    (1, 0): lambda x, y, z: x,
    (1, 1): lambda x, y, z: y,
    (1, 2): lambda x, y, z: z,
    (2, 0): lambda x, y, z: x*x,
    (2, 1): lambda x, y, z: y*y,
    (2, 2): lambda x, y, z: z*z,
    (2, 3): lambda x, y, z: x*y,
    (2, 4): lambda x, y, z: y*z,
    (2, 5): lambda x, y, z: z*x
}

def poly_func(order: int, index: int):
    
    try:
        return poly_dict[(order, index)]
    
    except KeyError:
        print("(order, index) must be a key in poly_dict!")
        return

现在我想jit

poly_func()
,但它给出了一个错误
TypeError: unhashable type: 'DynamicJaxprTracer'
而且,如果我只是做

def poly_func(order: int, index: int):
    
    return poly_dict[(order, index)]

它仍然给出相同的错误。有办法解决这个问题吗?

python dictionary conditional-statements jit jax
1个回答
0
投票

您的方法有两个问题:首先,跟踪值不能用于索引到 Python 集合,如字典或列表。其次,JIT 编译的函数只能返回数组值,不能返回函数。

考虑到这一点,不可能使您建议的功能与 JIT 等 JAX 转换一起工作。但是您可以修改您的方法以使用列表而不是函数字典,然后使用

lax.switch
动态选择您调用的函数。它可能是这样的:

def get_index(order, index):
  return order * 5 + index

poly_list = 16 * [lambda x, y, z: 0.0]  # place-holder function

for key, val in poly_dict.items():
  poly_list[get_index(*key)] = poly_dict[key]


def eval_poly_func(order: int, index: int, args):
  ind = get_index(order, index)
  return jax.lax.switch(ind, poly_list, *args)

result = jax.jit(eval_poly_func)(2, 0, (5.0, 6.0, 7.0))
print(result)
# 25.0
© www.soinside.com 2019 - 2024. All rights reserved.