torch.vmap 内批量创建张量

问题描述 投票:0回答:1

我想根据函数输入的形状使用

torch.zeros
创建一个张量。然后我想用
torch.vmap
对函数进行向量化。

类似这样的:

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial):
    deg = polynomial.shape[-1] - 2
    companion = torch.zeros((deg+1, deg+1))
    companion[1:,:-1] = torch.eye(deg)
    companion[:,-1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

polycompanion_vmap = torch.vmap(polycompanion)
print(polycompanion_vmap(poly_batched))

问题是批处理版本将不起作用,因为

companion
不会是
BatchedTensor
,与输入的
polynomial
不同。

有一个解决方法:

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial,companion):
    deg = companion.shape[-1] - 1
    companion[1:,:-1] = torch.eye(deg)
    companion[:,-1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

polycompanion_vmap = torch.vmap(polycompanion)

print(polycompanion_vmap(poly_batched, torch.zeros(poly_batched.shape[0],poly_batched.shape[-1]-1, poly_batched.shape[-1]-1)))

输出:

tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])

但这很难看。

这个问题有解决办法吗?以后会支持吗?

注意:如果您在函数的输入上使用

torch.zeros_like
,它会起作用并创建
BatchedTensor
,但这对我没有帮助。

预先感谢您的帮助!

python pytorch vectorization tensor torch
1个回答
0
投票

我们可以

clone()
达到我们想要进行就地操作的尺寸,然后到
concatenate
tensor
,我们必须使用
shape
正确匹配它们的
None
来索引非存在维度:

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial):
    deg = polynomial.shape[-1] - 2
    companion = torch.zeros((deg+1, deg+1))
    companion[1:,:-1] = torch.eye(deg)
    _companion = torch.concatenate([companion[:, :-1].clone(), (-1. * polynomial[:-1] / polynomial[-1])[:, None]], dim=1)
    return _companion

polycompanion_vmap = torch.vmap(polycompanion)
print(polycompanion_vmap(poly_batched))

输出:

tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])
© www.soinside.com 2019 - 2024. All rights reserved.