如何将自定义动态时间扭曲器与 TensorFlow 2.5.0 集成?

问题描述 投票:0回答:1

我目前正在开发一个项目,在该项目中我用 Python 实现了自定义动态时间规整 (DTW) 算法。现在,我想将此自定义 DTW 与 TensorFlow 2.5.0 集成,特别是在 RNN 模型的自定义层中。 TensorFlow 文档没有涵盖这个特定场景,我也找不到任何讨论这个问题的资源或示例。谁能指导如何做到这一点?

这是我的自定义 DTW 算法的 Python 代码:

import numpy as np

def custom_dtw(seq1, seq2):
    # DTW algorithm implementation here...
    pass

我希望在自定义 TensorFlow 层中使用它,如下所示:

import tensorflow as tf

class CustomDTWLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(CustomDTWLayer, self).__init__()

    def call(self, inputs):
        # Use custom_dtw here...
        pass

如果有任何帮助或正确方向的指示,我将不胜感激。谢谢!

tensorflow2.0 tensorflow-lite
1个回答
0
投票

要将自定义动态时间扭曲算法与 TensorFlow 集成,您需要将 DTW 函数包装在 tf.py_function 中,这样您就可以在 TensorFlow 图中运行任意 Python 代码。以下是如何执行此操作:

import tensorflow as tf

class CustomDTWLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(CustomDTWLayer, self).__init__()

    def call(self, inputs):
        result = tf.py_function(custom_dtw, [inputs], tf.float32)
        return result

在此代码中,custom_dtw 是您的 DTW 函数,[inputs] 是函数的张量输入列表,tf.float32 是函数的输出类型。

注意:由于 tf.py_function 在 TensorFlow 图之外运行,因此它无法受益于 GPU 加速,并且不会自动计算其梯度。

参考:

有关 tf.py_function 的 TensorFlow 文档:https://www.tensorflow.org/api_docs/python/tf/py_function

我希望这有帮助!

© www.soinside.com 2019 - 2024. All rights reserved.