如何使用包含Tensorflow会话的函数映射数据集?

问题描述 投票:0回答:1

我有一个Tensorflow DatasetV1Adapter对象形式的数据集。

<DatasetV1Adapter shapes: OrderedDict([(labels, (6,)), (snippets, ())]), types: OrderedDict([(labels, tf.int32), (snippets, tf.string)])>

# Example Output
OrderedDict([('labels', <tf.Tensor: id=37, shape=(6,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0], dtype=int32)>), ('snippets', <tf.Tensor: id=38, shape=(), dtype=string, numpy=b'explanationwhy the edits made under my username hardcore metallica fan were reverted they werent vandalisms just closure on some gas after i voted at new york dolls fac and please dont remove the template from the talk page since im retired now892053827'>)])

OrderedDict([('labels', <tf.Tensor: id=41, shape=(6,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0], dtype=int32)>), ('snippets', <tf.Tensor: id=42, shape=(), dtype=string, numpy=b'daww he matches this background colour im seemingly stuck with thanks  talk 2151 january 11 2016 utc'>)])

如你所见,它包含一个OrderedDict对象,其中包含labelssnippets的键。后者基本上是重要的,因为它包含我希望使用句子嵌入转换为向量的文本字符串。

为此,我决定使用tensorflow hub的Universal Sentence Encoder(USE)。它基本上接受一个句子列表作为输入,并输出一个长度为512的向量作为输出。需要注意的一点是,如果启用了预先执行,则无法执行tensorflow hub。因此,我们必须定义一个能够与tensorflow hub一起使用USE的会话。

但是,我希望使用tensorflow提供的map。但问题是,我应该如何定义一个在其中有张量流会话的函数?要使用该功能并将其映射到数据集,我是否需要定义另一个张量流会话?

我的第一个方法是实际做到这一点。具体而言,通过定义包含张量流会话的函数。然后,启动新的tensorflow会话并尝试将该函数映射到该会话中的该数据集。

请注意,我在会话之外定义了USE语句嵌入模型。

# Sentence embedding model (USE)
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")

def to_vec(w):
    x = w['snippets']
    with tf.Session() as sess:
        vector = sess.run(embed(x))
    return vector

with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        # try_data is the DatasetV1Adapter object
        sess.run(try_data.map(to_vec))

但我最终得到了这个错误

RuntimeError: Module must be applied in the graph it was instantiated for.

或者,我尝试在tensorflow会话中定义函数,就像这样

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

    def to_vec(w):
        x = w['snippets']
        vector = sess.run(embed(x))
        return vector
    sess.run(try_data.map(to_vec))

但这没有用,我仍然得到同样的错误。在做了一些搜索之后,我偶然发现了this postthis post,我说我必须定义一个tf.Graph并在会话中传递它。

graph = tf.Graph()

with graph.as_default():
    with tf.Session(graph=graph) as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        def to_vec(w):
            x = w['snippets']
            vector = sess.run(embed(x))
            return vector

        sess.run(try_data.map(to_vec))

然而,我仍然收到同样的错误。我还尝试在会话中定义USE,但仍然会导致相同的错误。

从那里开始,我对如何做到这一点感到非常困惑。有没有人对我错过的东西有任何想法?提前致谢。

python tensorflow tensorflow-datasets tensorflow-hub
1个回答
0
投票

简短的回答:你没有。 Tensorflow将调用您在图形模式下传递给Dataset.map的函数(它只调用一次函数并使用每个示例生成的图形,因此您可能不必担心可能正在运行与集线器相关的准备工作(下载等)每个例子)。

我对tensorflow hub并不过分熟悉,但请尝试以下方法。

def map_fn(inputs):
    snippets = inputs['snippets']
    # you -may- be able to pull the line below outside of map_fn
    # it probably won't affect performance
    embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
    vector = embed(snippets)
    return vector


dataset = dataset.map(map_fn)
© www.soinside.com 2019 - 2024. All rights reserved.