在tf.data输入管道中使用tf.function构建自定义地图函数

问题描述 投票:1回答:1

我正在尝试用Python为tensorflow tf.data输入管道编写带tf.function注释的地图函数。

函数应将字符串转换为单次热编码张量。输入字符串的格式为[ab12] +。(实际上,字符串中有更多的字符和数字,但是对于下面的示例来说已经足够了。)

这里是一个最小的示例:

DIM = 100
DIM_A = 1
DIM_B = 2

pos = tf.Variable(0, dtype=tf.int32)

@tf.function
def my_func(string):
  output = np.zeros(DIM * 10, dtype=np.float32)
  pos.assign(0)
  for ch in tf.strings.bytes_split(string):
    if tf.math.equal(ch, tf.constant("1")):
        pos.assign_add(1)
    elif tf.math.equal(ch, tf.constant("2")):
        pos.assign_add(2)
    elif tf.math.equal(ch, tf.constant("a")):
        output[DIM_A + DIM * pos] = 1
        pos.assign_add(1)
    elif tf.math.equal(ch, tf.constant("b")):
        output[DIM_B + DIM * pos] = 1
        pos.assign_add(1)
  return output

s = b"a1b2b"
print(my_func(s))

尝试计算在输出张量中设置1的位置的索引,我得到以下错误:

NotImplementedError:在用户代码中:

<ipython-input-14-baa9b1605ae2>:18 my_func  *
    output[DIM_A + DIM * pos] = 1
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:749 __array__
    " array.".format(self.name))

NotImplementedError: Cannot convert a symbolic Tensor (add:0) to a numpy array.

代码在渴望模式下工作,但在构建图形时会中断。

我有一个工作版本,该版本使用动态大小的TensorArray来首先构建输出张量的稀疏版本,然后将其转换为密集张量,但这确实很慢。固定大小的TensorArray而不是numpy数组也非常慢。我正在尝试使其更快。

tensorflow tensorflow2.0 tensorflow-datasets
1个回答
0
投票

1]在图形模式下不能使用numpy,因此output应该是tf.zeros而不是np.zeros

2)您无法分配给tf.zeros Tensor,因此您应该只使用tf.one_hot从头开始构建即可。

最小工作示例:

import tensorflow as tf
import numpy as np 

DIM = 100
DIM_A = 1
DIM_B = 2

pos = tf.Variable(0, dtype=tf.int32)

@tf.function
def my_func(string):
  output = tf.zeros(DIM * 10, dtype=tf.float32)
  pos.assign(0)
  for ch in tf.strings.bytes_split(string):
    if tf.math.equal(ch, tf.constant("1")):
        pos.assign_add(1)
    elif tf.math.equal(ch, tf.constant("2")):
        pos.assign_add(2)
    elif tf.math.equal(ch, tf.constant("a")):
        output = tf.one_hot(DIM_A + DIM * pos, DIM * 10, dtype=tf.float32)
        pos.assign_add(1)
    elif tf.math.equal(ch, tf.constant("b")):
        output = tf.one_hot(DIM_B + DIM * pos, DIM * 10, dtype=tf.float32)
        pos.assign_add(1)
  return output

s = b"a1b2b"
print(my_func(s).numpy())

此功能打印一个热编码矢量。我不知道索引是否正是您想要的索引,因此您必须仔细检查偏移量是否正确。

© www.soinside.com 2019 - 2024. All rights reserved.