澄清 pytorch 张量作为参考与值

问题描述 投票:0回答:1

为什么顶部代码

a = mat[0,0]; a = torch.tensor([99])
没有改变
mat
但底部代码
row = mat[0,:]; row[0] = torch.tensor([99])
却改变了?

>>> mat = torch.ones(2,3); print(mat)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
>>> a = mat[0,0]
>>> a = torch.tensor([99]); print(mat)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
>>> row = mat[0,:]
>>> row[0] = torch.tensor([99]); print(mat)
tensor([[99.,  1.,  1.],
        [ 1.,  1.,  1.]])
pytorch pass-by-reference
1个回答
0
投票

运行

a = torch.tensor([99])
时,会将
a
变量的引用从
mat
张量更改为新的
torch.tensor([99])
。这里的赋值改变了变量
a
的含义。

当您运行

row[0] = torch.tensor([99])
时,
row
参考保持不变,但特定项目
row[0]
发生了更改。因为
row
mat
共享内存,所以
mat
也会发生变化。这里的赋值不是改变变量row
,而是改变
row
的特定元素。

您可以更直接地比较两个作业。

mat = torch.ones(2,3) row = mat[0,:] row[0] = torch.tensor([99]) # here we change element `0` of `row` print(mat) # mat is changed mat = torch.ones(2,3) row = mat[0,:] row = torch.tensor([99]) # here we change the variable `row` without changing specific elements print(mat) # mat is unchanged
    
© www.soinside.com 2019 - 2024. All rights reserved.