tf.data.Dataset.map()用于由多个切片组成的数据集

问题描述 投票:0回答:1

从单个切片创建的数据集的tf.data.Dataset.map()看起来像dataset.map(lambda x: x/2)。如果数据集是由两个切片创建的,它将是什么样?参见,例如,以下代码。代码最后一行中的map()函数将适用于从单个切片创建的数据集,但会导致我的两切片情况出错。

import tensorflow as tf, numpy as np     # tensorflow 2.0
from tensorflow import keras as kr

dataset = tf.data.Dataset.from_tensor_slices((features_int8, labels_int8)) # features, labels are numpy arrays

model = kr.Sequential()
model.add(kr.layers.InputLayer(6)
model.add(kr.layers.Dense(     8, activation=tf.nn.tanh))
model.add(kr.layers.Dense(     3, activation=tf.nn.tanh))

model.compile(optimizer = kr.optimizers.RMSprop(), loss = kr.losses.MeanSquaredError())

model.fit(dataset.batch(64).map(lambda x: x/9), epochs = 10)
python tensorflow keras tensorflow-datasets
1个回答
0
投票

将lambda函数传递到如图所示的单独函数中

def map_fn(x, y):
  return x / 9, y

model.fit(dataset.batch(64).map(map_fn), epochs = 10)
© www.soinside.com 2019 - 2024. All rights reserved.