Java Tensorflow推断与冻结的CNN模型有关

问题描述 投票:0回答:1

我遇到了一些有关使用Java Tensorflow API的问题。

基本上,我试图使用我在Python中训练的冻结模型来预测一些图像,但我想用Java中的Tensorflow进行这些推断,以便我将在稍后开发的一些应用程序,如果这样做的话。

我开始将我的Python模型导出为.pb文件,然后可以将其加载到Tensorflow中,它可以用于推理目的,我在Python中测试它并没有任何问题。

然后,我尝试修改可在GitHub(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java)上找到的Java Tensorflow示例中提供的LabelImage.java示例。我基本上修改了模型的路径和我将使用的图像。在成功纠正一些错误后,代码是可运行的,但是我收到了这个错误:

Exception in thread "main" java.lang.UnsupportedOperationException: Generic conv implementation does not support grouped convolutions for now.
 [[{{node conv2d_1/convolution}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_1_0_0, conv2d_1/kernel)]]

我一般都是Java和Tensorflow的新手,我试图找到类似的错误,例如我得到的错误,我没有找到任何有用的东西。我想知道错误是否试图告诉我当前的Tensorflow API for Java不支持卷积,如果是这种情况我会很惊讶。无论如何,我对于解决这个问题我能做些什么有点迷茫,我希望有人可以帮我解决问题。

一些细节:我在Keras上构建并训练了一个U-Net模型,并使用Gi​​tHub上某个用户的方法将训练有素的Keras模型转换为.pb文件,该文件可以在Tensorflow上重新加载并运行以进行推理(用户:https://github.com/amir-abdi/keras_to_tensorflow)。这个重新加载和推理部分在Python中完美运行(我测试它是肯定的)。

错误似乎发生在这个代码块中:

 private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
try (Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
      // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
      Tensor<Float> result =
          s.runner().feed("input_1", image).fetch("conv2d_24/Sigmoid").run().get(0).expect(Float.class)) {
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
      throw new RuntimeException(
          String.format(
              "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
              Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    return result.copyTo(new float[1][nlabels])[0];
  }
}

这段代码没有改变,正如我所说,我刚刚更改了指向我的模型的输入路径和用于测试的示例图像。我改变的确切部分可以在下面找到:

  public static void main(String[] args) throws Exception {
System.out.println("TensorFlow version: " + TensorFlow.version());

byte[] graphDef = readAllBytesOrExit(Paths.get("C:\\Users\\joao_\\Documents\\GitHub\\Tensorflow-to-PB\\java_code\\src\\main\\resources\\test.pb"));
byte[] imageBytes = readAllBytesOrExit(Paths.get("C:\\Users\\joao_\\Documents\\GitHub\\Tensorflow-to-PB\\java_code\\src\\main\\resources\\02.png"));

try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
  float[] labelProbabilities = executeInceptionGraph(graphDef, image);
  int bestLabelIdx = maxIndex(labelProbabilities);
}

我希望这些信息足以理解问题。

java python tensorflow keras deep-learning
1个回答
0
投票

好吧,最后我找到了自己问题的答案。

基本上,错误与我将图像输入到没有合适尺寸的模型这一事实有关(我的图像是512x512而我的模型只需要256x256图像)。所以,我猜问题是输入张量没有正确的尺寸。

希望通过帮助人们解决同样的问题,这篇文章仍然有用。

© www.soinside.com 2019 - 2024. All rights reserved.