Python - Tensorflow:如何将函数正确映射到数据集

问题描述 投票:0回答:1

我正在学习机器学习课程,但在使用给定代码解决问题时遇到一些困难。

import tensorflow as tf
import tensorflow_datasets as tfds

(data), info = tfds.load("iris", with_info=True, split="train")
print(info.splits)

data = data.shuffle(150)
train_data = data.take(120)
test_data = data.skip(120)

def preprocess(dataset):
    
    def _preprocess_img(image, label):
        label = tf.one_hot(label, depth=3)
        return image, label
            
    dataset = dataset.map(_preprocess_img)
    return dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)

train_data = preprocess(train_data)
test_data = preprocess(test_data)

这只是一个代码片段,但它应该涵盖这里的问题区域。我收到错误消息:TypeError:outer_factory..inner_factory..tf___preprocess_img()缺少1个必需的位置参数:'label'

我无法解决这个问题,有人知道这里出了什么问题吗?我的意思是,是的,该函数需要标签参数,但在其他示例中我看到它似乎有效。但我想知道数据集的解包是否无法按预期工作?

我尝试的是更改要映射的函数,我查看了数据集的元素,但这确实无助于我获得正确的见解。我也在寻找其他示例,但我在这里看不出这个特定代码有什么问题。

python function dictionary tensorflow dataset
1个回答
0
投票

导入模块

import tensorflow as tf
import tensorflow_datasets as tfds

加载 Iris 数据集并将其分为训练集和测试集

(train_data, test_data), info = tfds.load("iris", with_info=True, split=["train[:120]", "train[120:]"], as_supervised=True)

定义预处理函数

def preprocess(image, label):
    # Perform one-hot encoding for the labels
    label = tf.one_hot(label, depth=3)
    return image, label

将预处理函数映射到数据集并对它们进行批处理

train_data = train_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
test_data = test_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

您可以选择验证分割和第一批数据

print("Training data splits:", train_data.cardinality())
print("Testing data splits:", test_data.cardinality())

迭代数据批次(用于演示)

for batch in train_data.take(1):
    images, labels = batch
    print("Batch shape:", images.shape)
© www.soinside.com 2019 - 2024. All rights reserved.