如何使用pytorch制作具有多个输出(和多个类)的神经网络?

问题描述 投票:0回答:1

我正在研究多输出(即> 1个输出目标)多类(即> 1类)(我相信这也称为多任务问题)。 例如,我的 train_features_data 的形状为 (4, 6)(即三行/示例和 6 列/特征),我的 train_target_data 的形状为 (4, 3)(即 4 行/示例和 3 列/目标)。对于每个目标,我有三个不同的类别(-1、0、1)。

我为这个问题定义了一个示例模型架构(和数据),如下所示:

import pandas as pd
from torch import nn 
from logging import log
import torch
feature_data = {
    'A': [1, 2, 3, 4],
    'B': [5, 6, 7, 8],
    'C': [9, 10, 11, 12],
    'D': [13, 14, 15, 16],
    'E': [17, 18, 19, 20],
    'F': [21, 22, 23, 24]
}

target_data = {
    'Col1': [1, -1, 0, 1],
    'Col2': [-1, 0, 1, -1],
    'Col3': [-1, 0, 1, 1]
}

# Create the DataFrame
train_feature_data = pd.DataFrame(feature_data) 
train_target_data = pd.DataFrame(target_data)
device = "cuda" if torch.cuda.is_available() else "cpu"

# create the model
class MyModel(nn.Module):
  def __init__(self, inputs=6, l1=12, outputs=3):
      super().__init__()
      self.sequence = nn.Sequential(
        nn.Linear(inputs, l1),
        nn.Linear(l1, outputs),
        nn.Softmax(dim=1)
    )
      
  def forward(self, x):
      x = self.sequence(x)
      return x
    
x_train = torch.tensor(train_feature_data.to_numpy()).type(torch.float)
model = MyModel(inputs = 6, l1 = 12, outputs = 3).to(device)
model(x_train.to(device=device))

当我将训练数据传递到模型中时(即当我调用 model(x_train.to(device=device)) 时),我会返回一个形状数组 (4, 3)。

通过遵循此资源资源,我的期望是我会得到类似 (4, 3, 3) 的形状,其中第一个轴(即 4)是我的特征和目标数据中的示例数量,第二个轴(即中间的 3)代表每个示例的 logits(或者在这种情况下,因为我有一个 softmax 函数,这将是预测概率)(这将是 3,因为我有三个类),而第三个轴(或形状中最右边的 3 个值)代表我的 train_target_data 中的输出/列数。

有人可以提供一些指导,说明我在这里做错了什么(如果我的方法是错误的)以及如何修复它。谢谢。

python deep-learning pytorch neural-network
1个回答
0
投票

您的模型将形状

(4, 6)
的输入映射到第一个线性层中的
(4, 12)
,然后映射到第二层中的
(4, 3)

如果想要输出的形状为

(4, 3, 3)
,则需要有第二层输出
(4, 3*3)
,然后reshape。

n_problems = 3
classes_per_problem = 3

model = nn.Linear(6, n_problems*classes_per_problem)

x = torch.randn(4, 6)
x1 = model(x)
bs, _ = x1.shape
x1 = x1.reshape(bs, classes_per_problem, n_problems)

y = torch.randint(high=classes_per_problem, size=(bs, n_problems))
loss_function = nn.CrossEntropyLoss()

loss = loss_function(x1, y)
© www.soinside.com 2019 - 2024. All rights reserved.