如何用非标量值填充 pytorch 张量?
例如,假设我想用形状为
X
的一维火炬矢量(n_samples, n_classes)
填充形状为a
的火炬张量(n_classes,)
。理想情况下,我希望能够写:
X = torch.full((n_samples, n_classes), a)
其中向量
a
是fill_value
中的torch.full
。然而,torch.full
只接受一个标量作为 fill_value
(Source)。所以这段代码不起作用。
我有两个问题:
torch.full
,用X
的n_sample
副本填充a
的快速方法是什么?torch.full
只接受标量填充值? torch.full
实现不能接受张量填充值是否有充分的理由?关于问题1.,我想简单写一下:
X = torch.ones((n_samples, n_classes)) * a
但是,有没有更快/更有效的方法来做到这一点?
作为参考,我已经查看了以下堆栈溢出帖子
但这些都没有直接回答我的问题。
谢谢!