如何使用 jit 编译和 vmap 自动矢量化对 JAX 函数进行矢量化

问题描述 投票:0回答:1

如何在 JAX 中使用 jit 和 vmap 来矢量化和加速以下计算:

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    results = []
    # Iterate over idxs
    for i in idxs:
        if i:
            results.append(distance(X[:, i], Y[:, i]))
    return results

#data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
#indices
idxs = ((7,8), (7,9), (), (), ())

# call the regular function
print(compute_metrics(idxs, X, Y)) # works
# call the function with vmap
print(vmap(compute_metrics, in_axes=(None, 0, 0))(idxs, X, Y)) # doesn't work

我关注了 JAX 网站和教程,但我无法找到如何进行这项工作。非 vmap 版本有效。但是,我得到了 vmap 版本(上面最后一行)的 IndexError,如下所示:

jax._src.traceback_util.UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

知道如何让它工作吗? idxs 也可能会改变并且是任意有效的索引组合,例如

idxs = ((), (3,9), (3,2), (), (5,8))

如上所述,我尝试了带有和不带有 vmap 的上述版本,但无法使后者 vmap 版本正常工作。

vectorization jit jax auto-vectorization google-jax
1个回答
1
投票

我不认为 vmap 会与标量元组一起工作。您需要的是将索引放入数组并对其进行 vmap。

我不确定这个解决方案是否让您满意,因为我们必须摆脱空索引对 ()。

idxs_pairs = jnp.array([[7,8],[7,9]]) # put the indices pairs into array

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    return distance(X[:,idxs], Y[:,idxs])

vmap(compute_metrics, in_axes=(0, None, None))(idxs_pairs, X, Y)

你也可以 jit 一切:

jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
© www.soinside.com 2019 - 2024. All rights reserved.