How do I build a multi-label dataset?


I am trying to create a TF 2.0 dataset that contains images, already in np.array format, with 7 different labels per image. My code is:

dataset = tf.data.Dataset.from_tensor_slices((data, labelZh, labelCh1, labelCh2, labelCh3, labelCh4, labelCh5, labelCh6))


Am I using tf.data.Dataset.from_tensor_slices correctly? When I feed the dataset to the network, I get this error:

TypeError: in converted code: TypeError: map_fn() takes from 1 to 3 positional arguments but 8 were given

Here is my network:

from tensorflow import keras
from tensorflow.keras import layers

def Forward():
    # 64x64 RGB image input
    input = keras.Input(shape=(64, 64, 3), name='title')
    # Four conv/pool stages with increasing filter counts
    x = layers.Conv2D(8, 3, activation="relu", padding='same', kernel_initializer="he_normal")(input)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Conv2D(16, 3, activation="relu", padding='same', kernel_initializer="he_normal")(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(32, 3, activation="relu", padding='same', kernel_initializer="he_normal")(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(64, 3, activation="relu", padding='same', kernel_initializer="he_normal")(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Dropout(0.25)(x)
    # Seven output heads, one per label
    output_Zh = layers.Dense(32)(x)
    output_1 = layers.Dense(34)(x)
    output_2 = layers.Dense(34)(x)
    output_3 = layers.Dense(34)(x)
    output_4 = layers.Dense(34)(x)
    output_5 = layers.Dense(34)(x)
    output_6 = layers.Dense(34)(x)
    model = keras.Model(inputs=input, outputs=[output_Zh, output_1, output_2, output_3, output_4, output_5, output_6])

    return model

Tags: python, tensorflow, tensorflow2.0, tensorflow-datasets
1 Answer

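I can't test the code right now, but this should solve your problem. Keras expects each dataset element to be an (inputs, targets) pair, so group the seven labels into one dictionary keyed by output name: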

dataset=tf.data.Dataset.from_tensor_slices((
    data,
    {
        'output_Zh': labelZh,
        'output_1': labelCh1,
        'output_2': labelCh2,
        'output_3': labelCh3,
        'output_4': labelCh4,
        'output_5': labelCh5,
        'output_6': labelCh6
    }
))
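
To check the structure without training anything, here is a minimal sketch that builds both variants from placeholder arrays and inspects dataset.element_spec. The zero-filled arrays, num_samples, and the labelCh list are assumptions made for illustration only; they stand in for the asker's real data and labelCh1 to labelCh6.

```python
import numpy as np
import tensorflow as tf

# Placeholder data (assumed for illustration): 8 blank 64x64 RGB images
# and integer labels standing in for labelZh and labelCh1..labelCh6.
num_samples = 8
data = np.zeros((num_samples, 64, 64, 3), dtype=np.float32)
labelZh = np.zeros(num_samples, dtype=np.int64)
labelCh = [np.zeros(num_samples, dtype=np.int64) for _ in range(6)]

# Group the seven labels into one dictionary keyed by output name.
targets = {"output_Zh": labelZh}
for i, label in enumerate(labelCh, start=1):
    targets[f"output_{i}"] = label

# Nested form: each element is an (image, {name: label}) pair.
dataset = tf.data.Dataset.from_tensor_slices((data, targets))
features_spec, targets_spec = dataset.element_spec
print(features_spec.shape)
print(sorted(targets_spec))

# Flat form from the question: each element is an 8-tuple, which is
# what triggers the "8 were given" error when passed to the model.
flat = tf.data.Dataset.from_tensor_slices((data, labelZh, *labelCh))
print(len(flat.element_spec))
```

Note that the dictionary keys are matched against the names of the model's output layers, so the Dense layers in the question would need explicit names, for example layers.Dense(32, name='output_Zh'). If you would rather not name the layers, passing the labels as a tuple in the same order as the model's outputs, (data, (labelZh, labelCh1, ...)), should work as well.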