Removing tensors to optimize a for loop in Python

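I am trying to optimize a large codebase. The part shown below is a for loop that returns the encodings inside a tensor. How can I output these numbers as a regular Python list instead of a tensor?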

def _make_batches(self, lines):
        tokens = [self._tokenize(line) for line in lines]
        lengths = np.array([t.numel() for t in tokens])
        indices = np.argsort(-lengths, kind=self.sort_kind)  # pylint: disable=invalid-unary-operand-type

        def batch(tokens, lengths, indices):
            toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]),
                                      self.pad_index)
            for i in range(len(tokens)):
                toks[i, -tokens[i].shape[0]:] = tokens[i]
            return Batch(srcs=None,
                         tokens=toks,
                         lengths=torch.LongTensor(lengths)), indices

        batch_tokens, batch_lengths, batch_indices = [], [], []
        ntokens = nsentences = 0
        for i in indices:
            if nsentences > 0 and ((self.max_tokens is not None
                                    and ntokens + lengths[i] > self.max_tokens)
                                   or (self.max_sentences is not None
                                       and nsentences == self.max_sentences)):
                yield batch(batch_tokens, batch_lengths, batch_indices)
                ntokens = nsentences = 0
                batch_tokens, batch_lengths, batch_indices = [], [], []
            batch_tokens.append(tokens[i])
            batch_lengths.append(lengths[i])
            batch_indices.append(i)
            ntokens += tokens[i].shape[0]
            nsentences += 1
        if nsentences > 0:
            yield batch(batch_tokens, batch_lengths, batch_indices)

This is how I call the function:

if __name__ == '__main__':
    s = SentenceEncoder("data/model.pt")
    input = [args.string_enc]
    make_batches = s._make_batches
    print([batch[1] for batch, indexes in make_batches(input)])

The output is:

[tensor([[29733, 20720,     2]])]

The desired output is:

[29733, 20720,     2]
python for-loop pytorch tensor
1 Answer

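Do you mean something like this? Take the tensor out of the list, drop the leading batch dimension with squeeze(0), and convert the result with tolist():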

import torch

a = [torch.tensor([[29733, 20720, 2]])]
b = a[0].squeeze(0).tolist()  # remove the batch dimension, then convert to a plain list
print(b)  # [29733, 20720, 2]
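
If a batch ever holds more than one sentence, squeeze(0) will not flatten it, and the shorter rows will still carry padding. Below is a minimal sketch of one way to handle that case. It assumes the layout implied by the question's batch() helper: a 2-D token tensor padded on the left, plus a lengths tensor. The helper name batch_to_lists and the pad id 1 are made up for illustration and are not part of the original code.

```python
import torch


def batch_to_lists(tokens, lengths):
    """Convert a left-padded (sentences, max_len) tensor into plain Python lists."""
    rows = tokens.tolist()  # nested lists of Python ints
    # Padding sits on the left, so keep only the last n ids of each row.
    return [row[len(row) - n:] for row, n in zip(rows, lengths.tolist())]


# Hypothetical batch: two sentences, pad id 1 (an assumption for this example)
tokens = torch.tensor([[5, 6, 7],
                       [1, 8, 9]])
lengths = torch.tensor([3, 2])
print(batch_to_lists(tokens, lengths))  # [[5, 6, 7], [8, 9]]
```

With the question's generator, this could presumably be called as batch_to_lists(batch.tokens, batch.lengths) for each yielded batch, assuming Batch exposes those fields by name.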