如何区分 PyTorch 张量和嵌套张量?

问题描述 投票:0回答:1

最近,PyTorch 引入了嵌套张量。但是,如果我创建一个嵌套张量,例如,

import torch

a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)

然后看它的类类型,它显示:

type(nt)
torch.Tensor

即,类类型只是一个常规的 PyTorch

Tensor
。因此,
type(nt) == torch.Tensor
isinstance(nt, torch.Tensor)
都会返回
True

所以,我的问题是,有没有办法区分常规张量和嵌套张量?

我能想到的一种方法是,嵌套张量的

size
方法(当前)与常规张量的工作方式不同,因为它需要一个参数,否则会引发
RuntimeError
。所以,解决方案可能是:

def is_nested_tensor(nt):
    if not isinstance(nt, torch.Tensor):
        return False

    try:
        # try calling size without an argument
        nt.size()
        return False
    except RuntimeError:
        return True

    return False

但是有没有更简单的东西不依赖于像

size
方法这样的东西在未来不会改变?

python pytorch tensor
1个回答
1
投票

所有

torch.Tensor
都有一个名为 is_nested 的属性,但遗憾的是它没有记录在案。仅在FAQ

中提到
> nt.is_nested
True
> a.is_nested
False
© www.soinside.com 2019 - 2024. All rights reserved.