我为 MNIST 上的线性分类器编写了这个简单的自定义网络。
该模型的亮点在于,通过整个网络的“全局邻接矩阵”来执行计算。该矩阵几乎全为零,只有左下角的块非零。 模型本身非常基础,只有两层,没有任何非线性。问题是在学习过程中
邻接矩阵没有得到更新,所以模型没有学习,我不知道为什么。我已经在更标准的架构上测试了我的训练循环,并且一切正常(我使用带有交叉熵损失的 SGD),所以问题一定在于我如何指定网络的类别。对我来说,通过这个全局邻接矩阵进行操作至关重要,我想了解问题出在哪里,以及如何使其发挥作用。
class Simple_Direct_Network_Adjacency_Matrix_Implementation_Dim2(nn.Module):
def __init__(self, input_dim , middle_dim, output_dim):
super().__init__()
self.input_dim = input_dim
_ = middle_dim #This is an hack: we want dim 2 now, so this input to the class gets ignored
self.output_dim = output_dim
self.total_dim = self.input_dim + self.output_dim
self.subdiagonal_block = nn.Parameter(torch.empty(self.output_dim, self.input_dim))
nn.init.normal_(self.subdiagonal_block , mean=0 , std=0.1)
self.adjacency_matrix = self.make_subdiagonal_matrix().requires_grad_(requires_grad=True)
def make_subdiagonal_matrix(self):
over_block = torch.zeros(self.input_dim, self.input_dim)
side_block = torch.zeros(self.total_dim, self.output_dim)
matrix = torch.cat((over_block , self.subdiagonal_block), 0)
matrix = torch.cat((matrix, side_block), 1)
return matrix
def forward(self, batch_of_inputs):
# Flatten the batch of input images
flat_inputs = batch_of_inputs.view(-1 , batch_of_inputs.size(0))
# Append zeros to match
flat_inputs_total = torch.cat((flat_inputs, torch.zeros(self.output_dim , flat_inputs.size(1))), dim=0)
# Perform matrix multiplication
y_total_final = torch.mm(self.adjacency_matrix , flat_inputs_total)
# Extract logits
logits = y_total_final[-self.output_dim: , :].t()
return logits
注意,我也尝试省略requires grad,但没有任何改变,我不知道是否有必要。另请注意,用
nn.Parameter()
指定的参数矩阵也不会改变。另请注意,将邻接矩阵的构造移至前向函数内似乎也无法解决问题..
adjacency_matrix
未更新,因为它不是
nn.Parameter
。您的 subdiagonal_block
未更新,因为它未在您的前向传播中使用。