假设我有一个 PyTorch 张量,例如:
import torch x = torch.randn([3, 4, 5])
我想得到一个新的张量,具有相同的维数,包含维度 1 最终值的所有内容。我可以这样做:
x[:, -1:, :]
但是,如果
x
您可以使用
select
dim = 1 # the dimension from which to extract the final values y = x.select(dim, -1).unsqueeze(dim)
其中
unsqueeze