如何使用for循环转换pytorch张量列表的张量类型

问题描述 投票:0回答:1

我正在尝试将张量的类型从 DoubleTensor 转换为 FloatTensor。但是,我的代码中似乎没有转换张量。如何使用 for 循环转换张量?

train_tensorset = [xq_train_tensor,y_train_tensor]
val_tensorset = [xq_val_tensor,y_val_tensor]
test_tensorset = [xq_test_tensor,y_test_tensor]

tensor_list = [train_tensorset,val_tensorset,test_tensorset]
flat_tensor_list = list(itertools.chain.from_iterable(tensor_list))
print(f"Num tensors:{len(flat_tensor_list)}")

for i, tensor in enumerate(flat_tensor_list):
    tensor = flat_tensor_list[i].float()
    print(f"{i}: {tensor.type()}") 

print(xq_train_tensor.type())

为了验证张量是否正在转换,我检查

xq_train_tensor
的类型,它是
flat_tensor_list
的一部分。这是上面代码的输出:

Num tensors:6
0: torch.FloatTensor
1: torch.FloatTensor
2: torch.FloatTensor
3: torch.FloatTensor
4: torch.FloatTensor
5: torch.FloatTensor
xq_train_tensor: torch.DoubleTensor

尽管

flat_tensor_list
中的所有项目都在 for 循环中进行转换,但张量似乎并未真正被转换。

python for-loop pytorch tensor
1个回答
0
投票

当您在张量上调用

.float()
时,您将返回一个新对象。

当您在代码末尾运行

print(xq_train_tensor.type())
时,您正在引用旧对象,该对象仍然是 double 类型。

如果希望变量

xq_train_tensor
更新为float类型,则需要给变量本身重新赋值。

xq_train_tensor = xq_train_tensor.float()

如果要使用列表迭代格式,可以将新的浮点张量保存到新列表中。但是,这不会更新

xq_train_tensor
变量,因为它仍然指向旧的张量。

output_tensors = []
for i, tensor in enumerate(flat_tensor_list):
    output_tensors.append(tensor.float())

for i, tensor in enumerate(output_tensors):
    print(f"{i}: {tensor.type()}")
© www.soinside.com 2019 - 2024. All rights reserved.