我目前正在尝试使用 pytorch 为 Alpha Zero 游戏玩家实现 CNN,但我收到了有关矩阵乘法的错误。我的输入由 3 个通道和 10x10 矩阵组成。
model = Net(10, 10**2+1)
print(summary(model,(3,10,10)))
给我以下错误:
RuntimeError Traceback (most recent call last)
<ipython-input-32-0a101a882eb1> in <cell line: 3>()
1 model = Net(10, 10**2+1)
2
----> 3 print(summary(model,(3,10,10)))
9 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
112
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
115
116 def extra_repr(self) -> str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x10 and 2x101)
这是当前的架构:
def conv3x3(in_planes, out_planes):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1)
def conv1x1(in_planes, out_planes):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0)
class Net(nn.Module):
def __init__(self, board_size, action_size, num_resBlocks=20, num_hidden=128):
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initial convolution
self.startBlock = nn.Sequential(
conv3x3(3, num_hidden),
nn.BatchNorm2d(num_hidden),
nn.ReLU()
)
# Loop of all 20 Residual Layers
self.backBone = nn.ModuleList(
[ResBlock(num_hidden) for i in range(num_resBlocks)]
)
# Outputs expected value of the state
self.valueHead = nn.Sequential(
conv1x1(num_hidden, 1),
nn.BatchNorm2d(1),
nn.ReLU(),
nn.Linear(in_features=1, out_features=num_hidden),
nn.ReLU(),
nn.Linear(in_features=num_hidden, out_features=1),
nn.Tanh()
)
# Outputs the probabilities of each possible action
self.policyHead = nn.Sequential(
conv1x1(num_hidden, 2),
nn.BatchNorm2d(2),
nn.ReLU(),
nn.Linear(2, out_features=(action_size)),
nn.Softmax(dim=1)
)
self.to(self.device)
def forward(self, x):
x = self.startBlock(x)
for resBlock in self.backBone:
x = resBlock(x)
policy = self.policyHead(x)
value = self.valueHead(x)
return policy, value
class ResBlock(nn.Module):
def __init__(self, num_hidden):
super().__init__()
self.conv1 = conv3x3(num_hidden, num_hidden)
self.bn1 = nn.BatchNorm2d(num_hidden)
self.conv2 = conv3x3(num_hidden, num_hidden)
self.bn2 = nn.BatchNorm2d(num_hidden)
self.relu = nn.ReLU()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# Skip connections
out = self.relu(out + identity)
return out
非常感谢您的帮助,我真的很感激!
backBone
中残差块的输出是一个4D张量形状的(N, 128, h, w)
。然而,PyTorch 中的线性层使用输入 (*, H_in)
和输出 (*, H_out)
。因此,在 valueHead
和 policyHead
中,第一个线性层的输入(nn.Linear(in_features=1, out_features=num_hidden)
代表值头,nn.Linear(2, out_features=(action_size))
代表策略头)需要从 (B, 128, h, w)
排列为 (B, h, w, 128)
。