我正在处理 csv。我嵌入了一个列并转换为张量。就像,
tensor([-1.7110e-01, 1.3811e-01, -2.5881e-01, -1.8281e-01, -3.3073e-01,
-1.1071e-01])
将这些张量保存为新列并将其保存到 csv 中。现在,当我再次加载该 csv 并查看嵌入列的值时,它看起来像这样
'tensor([-1.7110e-01, 1.3811e-01, -2.5881e-01, -1.8281e-01, -3.3073e-01,\n -1.1071e-01])'
让我知道如何将其转换回来并使用?
显然,您写入 CSV 文件的内容是张量对象的
repr
,这就是您打印对象时看到的内容。要正确处理 CSV 文件中的张量,您应该将张量对象序列化为正确的数据格式以供导出,并在加载时反序列化。
以下是如何实现此目的的示例:
import csv
import ast
# Sample tensor
sample_tensor = torch.tensor([-1.7110e-01, 1.3811e-01, -2.5881e-01, -1.8281e-01, -3.3073e-01, -1.1071e-01])
# Serialize the tensor as a list and write to CSV
csv_file = 'tensor_data.csv'
with open(csv_file, 'w') as file:
writer = csv.writer(file)
writer.writerow([sample_tensor.tolist()])
# Read the CSV file and reconstruct the tensor
with open(csv_file, 'r') as file:
reader = csv.reader(file)
serialized_tensor = next(reader)[0]
reconstructed_tensor = torch.tensor(ast.literal_eval(serialized_tensor))
# Print the reconstructed tensor
print(reconstructed_tensor)