为什么pytorch的nn.ModuleList不支持矢量化?下面的代码有性能问题吗?

问题描述 投票:0回答:1

我的模型中有很多子网络,对于不同的样本,我想使用不同的子网络进行一些计算。经过一番搜索,我发现了以下方法。为什么pytorch的nn.ModuleList不支持矢量化?这个方法有性能问题吗?

import torch
import torch.nn as nn


class SingleVariableNetwork(nn.Module):
    def __init__(self, init_value):
        super(SingleVariableNetwork, self).__init__()
        self.v = torch.tensor([init_value], dtype=torch.int32) 
    def forward(self, x):
        return self.v * x

class IndexedNetwork(nn.Module):
    def __init__(self, networks):
        super(IndexedNetwork, self).__init__()
        self.networks = networks

    def forward(self, x, network_indices):
        outputs = []
        for i in range(x.size(0)):
            network_input = x[i]
            # only integer tensors of a single element can be converted to an index
            network_output = self.networks[network_indices[i]](network_input)
            outputs.append(network_output)
        return torch.stack(outputs)

networks = nn.ModuleList([SingleVariableNetwork(i) for i in range(5)]) 
indexedNetwork = IndexedNetwork(networks)
input = torch.tensor([1, 1, 1, 1, 1])
indices = torch.tensor([3, 0, 2, 1, 4])
result = indexedNetwork(input, indices)
print(result)

我尝试像 IndexedNetwork 类的前向函数那样写:

return self.experts[expert_index](expert_input)
但是 pytorch 会引发错误“只有单个元素的整数张量可以转换为索引”

pytorch vectorization
1个回答
0
投票

您无法沿

nn.ModuleList
广播输入。模块列表只是一个带有参数跟踪额外功能的 Python 列表。

如果您的子网络都是相同的架构,您应该能够在张量操作级别对输入进行矢量化。例如,将

self.v
参数更改为向量(代表子模型),并将输入
x
更改为向量。

© www.soinside.com 2019 - 2024. All rights reserved.