Dataset.shuffle()之后批处理中的部分随机元素

问题描述 投票:0回答:1

我正在TF2.1中使用tf.data.Dataset和tf.keras来训练数据集。但是我看到一个奇怪的行为,即所生成的批次没有完全符合我的预期。我的意思是,即使我的数据集有4个类,我通常也只能看到一批中只有2个类的元素。我的代码如下:

def process_train_sample(file_path):
  sp = tf.strings.regex_replace(file_path, train_data_dir, '')
  cls = tf.math.argmax(tf.cast(tf.math.equal(tf.strings.split(sp, os.path.sep)[0],['A','B','C','D']), tf.int64))

  img = tf.io.read_file(file_path)
  img = tf.image.decode_jpeg(img, channels=3)  # RGB
  img = tf.image.resize(img, (224, 224))
  img = tf.cast(img, tf.float32)
  img = img - np.array([123.68, 116.779, 103.939])
  img = img / 255.0
  cls = tf.expand_dims(cls, 0)
  return img, cls

train_data_list = glob.glob(os.path.join(train_data_dir, '**', '*.jpg'), recursive=True)
train_data_list = tf.data.Dataset.from_tensor_slices(train_data_list)
train_ds = train_data_list.map(process_train_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.shuffle(10000)
train_ds = train_ds.batch(batch_size)

for img,  cls in train_ds.take(10):
  print('img: ', img.numpy().shape,  'cls: ', cls.numpy())

model.compile(loss='categorical_crossentropy',  
      optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
      metrics=['categorical_accuracy', 'categorical_crossentropy'])
model.fit(train_ds, epochs=50)

[当我在A,B,C,D等4类数据集上进行训练时,我发现训练精度并没有稳定地增加,而是上下波动。然后,我通过像在for循环中逐批显示标签来检查数据输入管道,发现每个批仅包含2个类(而不是4个)中的元素。似乎数据集没有像我期望的那样被打乱,这可能会导致准确性不稳步增长。但是我看不出代码有什么问题。

tensorflow keras tensorflow2.0 tensorflow-datasets
1个回答
0
投票

.shuffle(10 000)中,10 000是缓冲区大小,这意味着它将从前10000张图像中进行采样。当您有约30 000张图像时,这只会产生第一批批次中第一类和第二类的图像。在继续训练时,您将开始从(1,2,3)类别开始采样,然后仅是(2,3),然后是(2,3,4),然后是(3,4),然后是(3,4, 1),然后是(4,1),然后是(4,1,2),然后是(1,2),然后是(1,2,3),依此类推。如果有内存,请尝试将随机缓冲区大小设置为30 000,否则,请首先将路径列表重新排序,然后使用较大的批处理大小。

© www.soinside.com 2019 - 2024. All rights reserved.