我正在尝试使用 torch.cat() 来连接火炬张量。但是,我面临错误消息 -->“tuple”对象不支持项目分配。
这是我的代码:
inputs = tokenizer.encode_plus(txt, add_special_tokens=False, return_tensors="pt")
input_id_chunks = inputs["input_ids"][0].split(510)
mask_chunks = inputs["attention_mask"][0].split(510)
print(type(input_id_chunks))
for i in range(len(input_id_chunks)):
print(type(input_id_chunks[i]))
print(input_id_chunks[i])
input_id_chunks[i] = torch.cat([
torch.Tensor([101]), input_id_chunks[i], torch.Tensor([102])
])
输出看起来不错,inputs_id_chunks[i]是torch.Tensor:
`
但是我收到以下打印和错误消息:
类型错误:“元组”对象不支持项目分配
在 torch.cat() 中
我使用了 torch.cat() 的小型测试代码,它工作正常,但我不知道我的原始代码中缺少什么。
您无法更改元组值,而是可以将其分配给列表,然后向其附加新值,然后在要实现的所有更改之后,您应该再次将其分配给元组。
请检查此链接
更清楚地说,
input_id_chunks
是一个元组:
input_id_chunks = inputs["input_ids"][0].split(510)
print(type(inputs_id_chunks)) # this is <class 'tuple'>
然后在
for
循环中分配一个元组值
input_id_chunks[i] = torch.cat([
torch.Tensor([101]), input_id_chunks[i], torch.Tensor([102])
])
这就是您收到错误的原因。
input_id_chunks=list(input_id_chunks)