例如,一个二维张量:
>>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
以及列表l = [0, 1]
如果我执行t[l]
,那么它将结束打印t
的第0行和第一行。
但是如果我想使用l
作为索引怎么办?我希望使用l
在第0行和第1列中查找元素。换句话说,我希望得到与t[0, 1]
或t[0][1]
相同的结果。
而且我也想在2D以上的尺寸中使用它。使用长度为l
的n
作为索引来跟踪n
尺寸张量中的元素。
我只是写了一个递归函数来解决这个问题,也许有人有一个更优雅的解决方案?
def list_as_index(t, l):
if not l:
return t
else:
return list_as_index(t[l[0]], l[1:])