最近,PyTorch 引入了嵌套张量。但是,如果我创建一个嵌套张量,例如,
import torch
a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
然后看它的类类型,它显示:
type(nt)
torch.Tensor
即,类类型只是一个常规的 PyTorch
Tensor
。因此,type(nt) == torch.Tensor
和isinstance(nt, torch.Tensor)
都会返回True
。
所以,我的问题是,有没有办法区分常规张量和嵌套张量?
我能想到的一种方法是,嵌套张量的
size
方法(当前)与常规张量的工作方式不同,因为它需要一个参数,否则会引发 RuntimeError
。所以,解决方案可能是:
def is_nested_tensor(nt):
if not isinstance(nt, torch.Tensor):
return False
try:
# try calling size without an argument
nt.size()
return False
except RuntimeError:
return True
return False
但是有没有更简单的东西不依赖于像
size
方法这样的东西在未来不会改变?