我正在尝试优化一段很大的代码。下面这部分代码是一个 for 循环，它以张量（tensor）的形式返回编码结果。我该如何把这些数字输出为普通的 Python 列表，而不是张量？
def _make_batches(self, lines):
    """Tokenize ``lines`` and yield them as length-sorted, size-capped batches.

    Sentences are ordered by descending token count (``np.argsort`` on the
    negated lengths, using ``self.sort_kind``) and grouped greedily. A new
    batch is started when adding the next sentence would push the running
    token count above ``self.max_tokens``, or when the batch already holds
    ``self.max_sentences`` sentences; a limit set to ``None`` is ignored.
    A batch always contains at least one sentence, even if that sentence
    alone exceeds ``self.max_tokens``.

    Args:
        lines: Iterable of raw input lines, each passed to ``self._tokenize``.

    Yields:
        ``(Batch, indices)`` pairs. The batch's ``tokens`` field is a 2-D
        tensor with one row per sentence, left-padded with ``self.pad_index``
        to the width of the batch's longest sentence. Its ``lengths`` field is
        a ``torch.LongTensor`` of the unpadded lengths. ``indices`` lists each
        row's position in ``lines``, so callers can restore the input order.
        Nothing is yielded for empty input.
    """
    tokens = [self._tokenize(line) for line in lines]
    lengths = np.array([t.numel() for t in tokens])
    # Sort longest-first so the first sentence of every batch fixes its width.
    indices = np.argsort(-lengths, kind=self.sort_kind)  # pylint: disable=invalid-unary-operand-type

    def build_batch(group_tokens, group_lengths, group_indices):
        """Stack one group of sentences into a single left-padded tensor."""
        # Groups arrive longest-first, so the first sentence is the widest.
        # Assumes each token tensor is 1-D (numel() == shape[0]) -- TODO confirm.
        width = group_tokens[0].shape[0]
        toks = group_tokens[0].new_full((len(group_tokens), width),
                                        self.pad_index)
        for row, sent in enumerate(group_tokens):
            # Right-align each sentence; the leading columns keep pad_index.
            # Slicing from ``width - len`` instead of ``-len`` keeps a
            # zero-length sentence from becoming ``[0:]`` (the whole row),
            # which would make the assignment fail on a shape mismatch.
            toks[row, width - sent.shape[0]:] = sent
        return Batch(srcs=None,
                     tokens=toks,
                     lengths=torch.LongTensor(group_lengths)), group_indices

    batch_tokens, batch_lengths, batch_indices = [], [], []
    ntokens = nsentences = 0
    for i in indices:
        # Flush the current batch before adding sentence ``i`` if adding it
        # would exceed either limit (a ``None`` limit means "unlimited").
        if nsentences > 0 and (
                (self.max_tokens is not None
                 and ntokens + lengths[i] > self.max_tokens)
                or (self.max_sentences is not None
                    and nsentences == self.max_sentences)):
            yield build_batch(batch_tokens, batch_lengths, batch_indices)
            ntokens = nsentences = 0
            batch_tokens, batch_lengths, batch_indices = [], [], []
        batch_tokens.append(tokens[i])
        batch_lengths.append(lengths[i])
        batch_indices.append(i)
        ntokens += tokens[i].shape[0]
        nsentences += 1
    # Emit the final, partially filled batch.
    if nsentences > 0:
        yield build_batch(batch_tokens, batch_lengths, batch_indices)
下面是我调用该函数的代码：
if __name__ == '__main__':
    # NOTE(review): `SentenceEncoder` and `args` are not defined in the visible
    # code; presumably they come from an import and an argparse call elsewhere.
    encoder = SentenceEncoder("data/model.pt")
    # Named `sentences` rather than `input` to avoid shadowing the builtin.
    sentences = [args.string_enc]
    # Element 1 of each yielded batch object holds the padded token-id tensor
    # (see the printed output); the per-batch index list is not needed here.
    print([batch[1] for batch, _ in encoder._make_batches(sentences)])
输出为:
[tensor([[29733, 20720, 2]])]
所需的输出是:
[29733, 20720, 2]
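你是想要这样的结果吗？先用 `squeeze(0)` 去掉大小为 1 的批次维度，再用 `.tolist()` 把张量转换为普通的 Python 整数列表（如果一个批次里有多个句子，`.tolist()` 会返回嵌套列表，每行对应一个句子）：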
# Stand-in for the question's output: a one-element list holding a (1, 3) tensor.
a = [torch.tensor([[29733, 20720, 2]])]
# The list holds exactly one tensor; unpack it instead of indexing.
(batch_tensor,) = a
# Drop the size-1 leading (batch) dimension, then convert to plain Python ints.
b = batch_tensor.squeeze(dim=0).tolist()
print(b)