[我一直在尝试利用各种灵感-特别是this one-创建带标签的图像数据集以传递给model.fit()。
我的代码似乎等同于该问题在the answer中给出的代码... _parse_function()
与该问题的OP稍有不同:
def load_image( path, label ):
file_contents = tf.io.read_file( path )
image = tf.image.decode_image( file_contents )
image = tf.image.convert_image_dtype( image, tf.float32 )
return image, label
我可以在python命令行中独立测试此功能,例如image, label = load_image( "tiger.jpg", "Tiger" )
并以"Tiger"
标签和正确对应于图像左上像素的image[0][0]
结尾:
>>> image[0][0]
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.37254903, 0.5529412 , 0.854902 ], dtype=float32)>
同样,如果我在程序中尝试print( image[ 0 ][ 0 ]
,则会得到:
tf.Tensor([0.37254903 0.5529412 0.854902 ], shape=(3,), dtype=float32)
我是python的新手,所以我希望这些只是主题上的等效变化,但是无论哪种方式,当我将所有内容传递给程序中的model.fit()
时,我都会得到:
ValueError: Cannot take the length of shape with unknown rank.
任何主题的变化都没有使我超越这一点。我从数据集中消除了所有管道操作(例如,没有.shuffle()
,没有.repeat()
,没有.batch()
),因此我只使用了.map()
函数,并得到相同的错误结果。我可以看到的唯一错误可能是在上面的load_image()
函数中,或在调用代码中:
dataset = tf.data.Dataset.from_tensor_slices( ( images, labels ) ) # tf.constant() does not change error
dataset = dataset.map( load_map )
model.fit( dataset, epochs=100 )
导致错误的原因是什么?
decode_image
存在一个已知的问题-无法正确设置形状信息(请参阅here。您可以使用更具体的调用-例如decode_jpeg
或decode_png
。
此外...您将遇到的下一个问题是,您不能直接使用“老虎”之类的标签。如果“老虎”在诸如““狮子”,“老虎”,“斑马”,“猿”,...]等类别的列表中,则您需要在此类列表中使用“老虎”的索引(即1
)或一键式表示(即[False,True,False,False,...]
)