具有非常大的HDF5文件的Tensorflow-IO数据集输入管道

问题描述 投票:1回答:1

我有非常大的训练(30Gb)文件。由于所有数据都不适合我的可用RAM,因此我想分批读取数据。我看到有Tensorflow-io软件包,其中implemented a way通过功能tfio.IODataset.from_hdf5()将HDF5以这种方式读取到Tensorflow中然后,由于tf.keras.model.fit()tf.data.Dataset作为包含样本和目标的输入,因此我需要将X和Y压缩在一起,然后使用.batch and .prefetch仅将必要的数据加载到内存中。对于测试,我尝试将此方法应用于较小的样本:训练(9Gb),验证(2.5Gb)和测试(1.2Gb),我知道它们很好,因为它们可以装入内存,并且我得到了很好的结果(70%的准确性和< 1损失)。训练文件存储在HDF5文件中,分成以下示例文件:(X)和标签(Y):

X_learn.hdf5  
X_val.hdf5  
X_test.hdf5  
Y_test.hdf5  
Y_learn.hdf5  
Y_val.hdf5

这是我的代码:

BATCH_SIZE = 2048
EPOCHS = 100

# Create an IODataset from a hdf5 file's dataset object  
x_val = tfio.IODataset.from_hdf5(path_hdf5_x_val, dataset='/X_val')
y_val = tfio.IODataset.from_hdf5(path_hdf5_y_val, dataset='/Y_val')
x_test = tfio.IODataset.from_hdf5(path_hdf5_x_test, dataset='/X_test')
y_test = tfio.IODataset.from_hdf5(path_hdf5_y_test, dataset='/Y_test')
x_train = tfio.IODataset.from_hdf5(path_hdf5_x_train, dataset='/X_learn')
y_train = tfio.IODataset.from_hdf5(path_hdf5_y_train, dataset='/Y_learn')

# Zip together samples and corresponding labels
train = tf.data.Dataset.zip((x_train,y_train)).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
test = tf.data.Dataset.zip((x_train,y_train)).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
val = tf.data.Dataset.zip((x_train,y_train)).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

# Build the model
model = build_model()

# Compile the model with custom learing rate function for Adam optimizer
model.compile(loss='categorical_crossentropy',
               optimizer=Adam(lr=lr_schedule(0)),
               metrics=['accuracy'])

# Fit model with class_weights calculated before
model.fit(train,
          epochs=EPOCHS,
          class_weight=class_weights_train,
          validation_data=val,
          shuffle=True,
          callbacks=callbacks)

此代码运行,但损失非常高(300+),准确度从一开始就下降到0(0.30-> 4 * e ^ -5)...我不明白我在做什么错,是我错过了什么?

python tensorflow machine-learning tensorflow2.0 tensorflow-datasets
1个回答
0
投票

此处提供解决方案(答案部分),即使它存在于注释部分中也是为了社区的利益。

代码没有问题,实际上是数据(没有适当地预处理),因此模型无法很好地学习,从而导致奇怪的损失和准确性。

© www.soinside.com 2019 - 2024. All rights reserved.