我正在学习机器学习课程,但在使用给定代码解决问题时遇到一些困难。
import tensorflow as tf
import tensorflow_datasets as tfds
(data), info = tfds.load("iris", with_info=True, split="train")
print(info.splits)
data = data.shuffle(150)
train_data = data.take(120)
test_data = data.skip(120)
def preprocess(dataset):
def _preprocess_img(image, label):
label = tf.one_hot(label, depth=3)
return image, label
dataset = dataset.map(_preprocess_img)
return dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)
train_data = preprocess(train_data)
test_data = preprocess(test_data)
这只是一个代码片段,但它应该涵盖这里的问题区域。我收到错误消息:TypeError:outer_factory..inner_factory..tf___preprocess_img()缺少1个必需的位置参数:'label'
我无法解决这个问题,有人知道这里出了什么问题吗?我的意思是,是的,该函数需要标签参数,但在其他示例中我看到它似乎有效。但我想知道数据集的解包是否无法按预期工作?
我尝试的是更改要映射的函数,我查看了数据集的元素,但这确实无助于我获得正确的见解。我也在寻找其他示例,但我在这里看不出这个特定代码有什么问题。
import tensorflow as tf
import tensorflow_datasets as tfds
(train_data, test_data), info = tfds.load("iris", with_info=True, split=["train[:120]", "train[120:]"], as_supervised=True)
def preprocess(image, label):
# Perform one-hot encoding for the labels
label = tf.one_hot(label, depth=3)
return image, label
train_data = train_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
test_data = test_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
print("Training data splits:", train_data.cardinality())
print("Testing data splits:", test_data.cardinality())
for batch in train_data.take(1):
images, labels = batch
print("Batch shape:", images.shape)