手动将LensM从Tensorflow导入PyTorch

问题描述 投票:2回答:1

我试图从tensorflow导入一个预训练模型到PyTorch。它需要一个输入并将其映射到单个输出。当我尝试导入LSTM权重时出现混乱

我使用以下函数从文件中读取权重及其变量:

def load_tf_model_weights():        

    modelpath = 'models/model1.ckpt.meta'

    with tf.Session() as sess:        
        tf.train.import_meta_graph(modelpath) 
        init = tf.global_variables_initializer()
        sess.run(init)  
        vars = tf.trainable_variables()        
        W = sess.run(vars)

    return W,vars

W,V = load_tf_model_weights()

然后我正在检查重量的形状

In [33]:  [w.shape for w in W]
Out[33]: [(51, 200), (200,), (100, 200), (200,), (50, 1), (1,)]

此外,变量定义为

In [34]:    V
Out[34]: 
[<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(51, 200) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(200,) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(100, 200) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(200,) dtype=float32_ref>,
<tf.Variable 'weight:0' shape=(50, 1) dtype=float32_ref>,
<tf.Variable 'FCLayer/Variable:0' shape=(1,) dtype=float32_ref>]

所以我可以说W的第一个元素定义了LSTM的内核,第二个元素定义了它的偏差。根据this post,核的形状定义为[input_depth + h_depth, 4 * self._num_units],偏差定义为[4 * self._num_units]。我们已经知道input_depth1。所以我们得到,h_depth_num_units都有值50

在pytorch我的LSTMCell,我想要分配权重,看起来像这样:

In [38]: cell = nn.LSTMCell(1,50)
In [39]: [p.shape for p in cell.parameters()]
Out[39]: 
[torch.Size([200, 1]),
torch.Size([200, 50]),
torch.Size([200]),
torch.Size([200])]

前两个条目可以由W的第一个值覆盖,其形状为(51,200)。但来自Tensorflow的LSTMCell仅产生一个形状(200)的偏差,而pytorch想要其中两个

通过留下偏差,我有权重:

cell2 = nn.LSTMCell(1,50,bias=False)
[p.shape for p in cell2.parameters()]
Out[43]: [torch.Size([200, 1]), torch.Size([200, 50])]

谢谢!

python tensorflow lstm pytorch
1个回答
1
投票

pytorch使用CuDNN的LSTM底层(即使你没有CUDA,它仍然使用兼容的东西)因此它有一个额外的偏置项。

所以你可以选择两个数字,它们的总和等于1(0和1,1 / 2和1/2或其他任何东西)并将你的pytorch偏差设置为这些数字乘以TF的偏差。

pytorch_bias_1 = torch.from_numpy(alpha * tf_bias_data)
pytorch_bias_2 = torch.from_numpy((1.0-alpha) * tf_bias_data)
© www.soinside.com 2019 - 2024. All rights reserved.