使用 ndarray 和 PyTorch 张量进行 Numpy 索引

问题描述 投票:0回答:1

我发现 numpy 数组索引与 ndrrray 和形状为

(1,)
的 PyTorch 张量的工作方式不同,并且想知道为什么。请看下面的案例:

import numpy as np
import torch as th

x = np.arange(10)

y = x[np.array([1])]
z = x[th.tensor([1])]
print(y, z)

y
将是
array[2]
,而
z
只是
2
。到底有什么区别?

python numpy pytorch numpy-ndarray
1个回答
0
投票

请注意,单个元素的整数张量可以转换为索引:

>>> torch.tensor([1]).__index__()
1
>>> torch.tensor([1, 2]).__index__()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: only integer tensors of a single element can be converted to an index

当传入的索引是张量时,

ndarray
无法识别它,因此它尝试调用其
__index__
方法。如果转换成功,则被视为整数:

if (PyLong_CheckExact(obj) || !PyArray_Check(obj)) {
    npy_intp ind = PyArray_PyIntAsIntp(obj);  // it calls PyNumber_Index() internally

    if (error_converting(ind)) {
        PyErr_Clear();
    }
    else {
        index_type |= HAS_INTEGER;
        indices[curr_idx].object = NULL;
        indices[curr_idx].value = ind;
        indices[curr_idx].type = HAS_INTEGER;
        used_ndim += 1;
        new_ndim += 0;
        curr_idx += 1;
        continue;
    }
}

来自 NumPy 的源代码。

© www.soinside.com 2019 - 2024. All rights reserved.