Why must out_channels be at least channel_multiplier * in_channels in the pointwise_filter argument of tf.nn.separable_conv2d?


Please look at the code first:

import tensorflow as tf
import numpy as np
input_data = tf.Variable(np.random.rand(10,9,9,3), dtype=np.float32)
#depthwise_filter (filter_height, filter_width, in_channels, channel_multiplier)
depthwise_filter = tf.Variable(np.random.rand(2,2,3,5), dtype=np.float32)
#pointwise_filter (1, 1, channel_multiplier * in_channels, out_channels)
pointwise_filter = tf.Variable(np.random.rand(1,1,15,8), dtype=np.float32)
y = tf.nn.separable_conv2d(input_data, depthwise_filter, pointwise_filter, strides=[1,1,1,1], padding='SAME')
print(y.shape)  # static shape; print(tf.shape(y)) only shows a symbolic Tensor in TF 1.x graph mode
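Here is where the shapes in the comments come from. The depthwise step filters each of the 3 input channels with its own 5 kernels, which gives 3 * 5 = 15 intermediate channels. The 1x1 pointwise step then mixes those 15 channels into out_channels = 8. The following is a minimal NumPy sketch of that two-step computation. It assumes stride 1 and 'VALID' padding, whereas the question uses 'SAME', so the spatial size differs. The function names are hypothetical and are not TensorFlow APIs.

```python
import numpy as np


def depthwise_conv2d_valid(x, depthwise_filter):
    """Depthwise convolution with stride 1 and 'VALID' padding (illustrative sketch).

    x:                (batch, height, width, in_channels)
    depthwise_filter: (filter_height, filter_width, in_channels, channel_multiplier)
    returns:          (batch, out_h, out_w, in_channels * channel_multiplier)
    """
    n, h, w, c_in = x.shape
    kh, kw, c_in_f, m = depthwise_filter.shape
    if c_in != c_in_f:
        raise ValueError("depthwise_filter in_channels must match the input")
    oh, ow = h - kh + 1, w - kw + 1
    out = np.zeros((n, oh, ow, c_in, m), dtype=np.result_type(x, depthwise_filter))
    for i in range(kh):
        for j in range(kw):
            patch = x[:, i:i + oh, j:j + ow, :]               # (n, oh, ow, c_in)
            # each input channel k is filtered by its own m kernels
            out += patch[..., None] * depthwise_filter[i, j]  # -> (n, oh, ow, c_in, m)
    # flatten so that channel index = k * m + q, the ordering tf.nn.depthwise_conv2d documents
    return out.reshape(n, oh, ow, c_in * m)


def pointwise_conv2d(x, pointwise_filter):
    """1x1 convolution mixing in_channels * channel_multiplier channels into out_channels.

    pointwise_filter: (1, 1, in_channels * channel_multiplier, out_channels)
    """
    return x @ pointwise_filter[0, 0]  # (n, oh, ow, C) @ (C, out) -> (n, oh, ow, out)


# Shapes from the question (padding differs: VALID here, SAME in the question)
rng = np.random.default_rng(0)
x = rng.random((10, 9, 9, 3), dtype=np.float32)
dw = rng.random((2, 2, 3, 5), dtype=np.float32)
pw = rng.random((1, 1, 15, 8), dtype=np.float32)

mid = depthwise_conv2d_valid(x, dw)
y = pointwise_conv2d(mid, pw)
print(mid.shape, y.shape)  # (10, 8, 8, 15) (10, 8, 8, 8)
```

Nothing in this arithmetic requires 15 <= 8. Going from 15 intermediate channels down to 8 is well defined, so the refusal in the error below comes from a check in that TensorFlow version, not from the math itself.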

Error:

Traceback (most recent call last):
  File "tsfl.py", line 36, in <module>
    y = tf.nn.separable_conv2d(input_data, depthwise_filter, pointwise_filter, strides=[1,1,1,1], padding='SAME')
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\nn_impl.py", line 486, in separable_conv2d
    channel_multiplier * in_channels, out_channels))
ValueError: Refusing to perform an overparameterized separable convolution: channel_multiplier * in_channels = 5 * 3 = 15 > 8 = out_channels

It works when channel_multiplier * in_channels <= out_channels.
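The condition in the error message can be written out directly. The following is a sketch of a hypothetical helper, not a TensorFlow function. It checks the shape relationship documented in the question's comments and reports whether the TF 1.1.0 check (channel_multiplier * in_channels > out_channels) would fire.

```python
def check_separable_filter_shapes(depthwise_shape, pointwise_shape):
    """Hypothetical helper (not part of TensorFlow).

    Validates the shapes described in the question's comments and returns True
    if channel_multiplier * in_channels > out_channels, i.e. the condition that
    TF 1.1.0 reported as an "overparameterized separable convolution".
    """
    _, _, in_channels, channel_multiplier = depthwise_shape
    ph, pw, mid_channels, out_channels = pointwise_shape
    if (ph, pw) != (1, 1):
        raise ValueError("pointwise_filter must be 1x1")
    if mid_channels != channel_multiplier * in_channels:
        raise ValueError(
            f"pointwise_filter expects {channel_multiplier * in_channels} "
            f"input channels, got {mid_channels}"
        )
    return channel_multiplier * in_channels > out_channels


# The question's shapes: 5 * 3 = 15 > 8, so TF 1.1.0 refused
print(check_separable_filter_shapes((2, 2, 3, 5), (1, 1, 15, 8)))   # True
# out_channels >= 15 satisfies the old check
print(check_separable_filter_shapes((2, 2, 3, 5), (1, 1, 15, 15)))  # False
```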

Why must out_channels be greater than or equal to channel_multiplier * in_channels?
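One way to read the word "overparameterized" in the error message is to count weights. The following is an illustrative sketch only, not the documented rationale for the check. The helper names are hypothetical, and biases are ignored.

```python
def separable_param_count(kh, kw, in_channels, channel_multiplier, out_channels):
    """Weights in depthwise_filter plus pointwise_filter (no biases)."""
    depthwise = kh * kw * in_channels * channel_multiplier
    pointwise = 1 * 1 * (in_channels * channel_multiplier) * out_channels
    return depthwise + pointwise


def conv2d_param_count(kh, kw, in_channels, out_channels):
    """Weights in a regular conv2d filter of shape (kh, kw, in_channels, out_channels)."""
    return kh * kw * in_channels * out_channels


# The question's configuration: 2x2 kernel, 3 -> 15 -> 8 channels
print(separable_param_count(2, 2, 3, 5, 8))  # 180 (60 depthwise + 120 pointwise)
print(conv2d_param_count(2, 2, 3, 8))        # 96
```

For these shapes, the separable pair uses more weights than a plain convolution with the same input and output channels. However, the channel condition does not always imply this. With a 3x3 kernel, in_channels = 3, channel_multiplier = 2 and out_channels = 5 (6 > 5), the counts are 9 * 6 + 6 * 5 = 84 versus 9 * 3 * 5 = 135. So treat weight counting as an illustration of the word, not as the exact rule the old check enforced.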

tensorflow
2 Answers

This restriction has been removed in the current TensorFlow master branch. See "Tensorflow: What does tf.nn.separable_conv2d do?" for more details.



The problem comes from the TensorFlow version. I upgraded TF from 1.1.0 to 1.10.0, and now it works.

pip install tensorflow==1.10.0
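After upgrading, `tf.__version__` reports which TensorFlow is actually installed. Note that 1.10.0 is simply the version this answer upgraded to; the exact release that dropped the check is not stated here. If you compare version strings in a script, compare them numerically. The following is a minimal sketch with a hypothetical `parse_version` helper. It assumes plain numeric release strings such as "1.10.0".

```python
def parse_version(version):
    """Turn a plain release string like '1.10.0' into (1, 10, 0).

    Assumes purely numeric dot-separated parts; pre-release tags such as
    '2.0.0rc1' are not handled by this sketch.
    """
    return tuple(int(part) for part in version.split("."))


# Compare numerically, not as text: '1.10.0' sorts before '1.9.0' as a string
print("1.10.0" < "1.9.0")                                # True (misleading)
print(parse_version("1.10.0") > parse_version("1.1.0"))  # True
print(parse_version("1.10.0") > parse_version("1.9.0"))  # True

# In practice the installed version string comes from TensorFlow:
#   import tensorflow as tf; print(tf.__version__)
```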