我有 4 个大小为 64x64 的矩阵,它们被堆叠起来 (Torch.Stack) 以创建大小为 [4,64,64] 的矩阵,它们将作为我的 TensorDataSet 的输入。我有 1 个 64x64 矩阵,旨在为我的 TensorDataSet 输出。当我将它们加载到 TensorDataSet(输入,输出)时,我发现大小不匹配。
如果我采用 1 个输入和 1 个输出,每个输入和输出的大小均为 64x64,TensorDataSet 将接受这一点。但是,我想传递与 1 个输出值相对应的 4 个输入值。例如,每个输入的 [0,0] 位置中的第一个值与输出的 [0,0] 位置有关系。
我尝试使用挤压方法,但没有成功。
TensorDataset
引发的错误的根源,您将看到:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
这意味着所有张量沿第一维(数据集大小)必须具有相同的大小。
在您的情况下,一个
4x64x64
张量对应于 1 个 64x64,换句话说,输入张量必须分别针对输入和输出进行整形 (1,4,64,64)
和 (1,64,64)
。因此,您需要在两者上解压缩额外的维度(使用 None
索引或 unsqueeze
):
x = torch.rand(4,64,64)
y = torch.rand(64,64)
dataset = TensorDataset(x[None], y[None])