tensorflow DenseHashTable查找多维键

问题描述 投票:0回答:1

我想使用DenseHashTable查找字符串张量,就像这个答案answer一样,键的类型是tf.string,值嵌入tf.float32 dtype。但是当键是多维时,就会出现错误。

keys = ["Fritz", "Franz", "Fred"]
values = [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]]
table = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string, value_dtype=tf.float32, empty_key="0", deleted_key="-1", default_value=[-1,-1,-1,-1])
table.insert(keys, values)
table.lookup(['Franz', 'Emil']) # shape=(2,) its ok
table.lookup([['Franz', 'Emil'], ['Emil', 'Fred']]) # when lookup with 2-D tensor(shape like (batch_size, 2)), throws error.

我怎样才能让它像 tf.nn.embedding_lookup 一样工作?键不是数组索引而是 tf.string。

tensorflow tensorflow2.0 recommendation-engine embedding tensorflow-estimator
1个回答
0
投票

问题是 TensorFlow 需要一个键列表,而不是嵌套的键列表。当然,

docs
中按键描述中的 Can be a tensor of any shape. 有点令人困惑。
你能做的就是展平你的列表,对其进行散列,然后重新调整它的形状:

keys = [['Franz', 'Emil'], ['Emil', 'Fred']]
keys = tf.convert_to_tensor(keys)  # to get the shape
key_shape = keys.shape  # shape: (2, 2)
x = table.lookup(tf.reshape(keys, -1))  # shape: (4, 4) after hashing

x = tf.reshape(x, key_shape+(x.shape[-1:]))  # shape: (2, 2 ,4)
© www.soinside.com 2019 - 2024. All rights reserved.