将一个热编码从文件提取到数据集中

问题描述 投票:2回答:2

我有一个数据集图像和相应的标签,每个图像文件的位置都有一个.txt文件,其中包含一个热编码:

0
0
0
0
1
0

我的代码看起来像这样:

imageString = tf.read_file('image.jpg')
imageDecoded = tf.image.decode_jpeg(imageString)

labelString = tf.read_file(labelPath)
# decode csv string

但labelString看起来像这样:

tf.Tensor(b'0\n0\n0\n0\n1\n', shape=(), dtype=string)

有没有办法将其转换为张量流内的数字数组?

python tensorflow tensorflow-datasets
2个回答
1
投票

这是一个功能。

import tensorflow as tf

def read_label_file(labelPath):
    # Read file
    labelStr = tf.io.read_file(labelPath)
    # Split string (returns sparse tensor)
    labelStrSplit = tf.strings.split([labelStr])
    # Convert sparse tensor to dense
    labelStrSplitDense = tf.sparse.to_dense(labelStrSplit, default_value='')[0]
    # Convert to numbers
    labelNum = tf.strings.to_number(labelStrSplitDense)
    return labelNum

一个测试用例:

import tensorflow as tf

# Write file for test
labelPath = 'labelData.txt'
labelTxt = '0\n0\n0\n0\n1\n0'
with open(labelPath, 'w') as f:
    f.write(labelTxt)
# Test the function
with tf.Session() as sess:
    label_data = read_label_file(labelPath)
    print(sess.run(label_data))

输出:

[0. 0. 0. 0. 1. 0.]

注意这个函数,正如我所写的那样,它使用了一些新的API端点,你也可以把它写成下面的更多向后兼容性,几乎相同的意思(tf.strings.splittf.string_split之间有细微的差别):

import tensorflow as tf

def read_label_file(labelPath):
    labelStr = tf.read_file(labelPath)
    labelStrSplit = tf.string_split([labelStr], delimiter='\n')
    labelStrSplitDense = tf.sparse_to_dense(labelStrSplit.indices,
                                            labelStrSplit.dense_shape,
                                            labelStrSplit.values, default_value='')[0]
    labelNum = tf.string_to_number(labelStrSplitDense)
    return labelNum

0
投票

您可以使用基本的python命令并将其转换为张量。尝试...

with open(labelPath) as f:
    lines = f.readlines()
    lines = [int(l.strip()) for l in lines if l.strip()]
labelString = tf.convert_to_tensor(lines, dtype='int32')
© www.soinside.com 2019 - 2024. All rights reserved.