如何在 tflite 中使用 c++ api 获得权重?

问题描述 投票:1回答:1

我在设备上使用一个.tflite模型,最后一层是ConditionalRandomField层,我需要该层的权重来做预测。最后一层是ConditionalRandomField层,我需要该层的权重来做预测。如何用c++ api获取权重?

相关的。如何在.tflite文件中查看权重?

Netron或者flatc不能满足我的需求,对设备太重了。

似乎TfLiteNode将权重存储在void* user_data或void* builtin_data中。我如何读取它们?

更新。

结论:.tflite不存储CRF权重,而.h5则存储。(也许是因为它们不影响输出。)

我做了什么。

// obtain from model.
Interpreter *interpreter;
// get the last index of nodes.
// I'm not sure if the index sequence of nodes is the direction which tensors or layers flows.
const TfLiteNode *node = &((interpreter->node_and_registration(interpreter->nodes_size()-1))->first);

// then follow the answer of @yyoon
tensorflow-lite tf-lite
1个回答
0
投票

在TFLite节点中,权重应该存储在 inputs 数组,其中包含相应的 TfLiteTensor*.

所以,如果你已经获得了 TfLiteNode* 的最后一层,你可以这样做来读取权重值。

TfLiteContext* context; // You would usually have access to this already.
TfLiteNode* node;       // <obtain this from the graph>;

for (int i = 0; i < node->inputs->size; ++i) {
  TfLiteTensor* input_tensor = GetInput(context, node, i);

  // Determine if this is a weight tensor.
  // Usually the weights will be memory-mapped read-only tensor
  // directly baked in the TFLite model (flatbuffer).
  if (input_tensor->allocation_type == kTfLiteMmapRo) {
    // Read the values from input_tensor, based on its type.
    // For example, if you have float weights,
    const float* weights = GetTensorData<float>(input_tensor);

    // <read the weight values...>
  }
}
© www.soinside.com 2019 - 2024. All rights reserved.