多类数据集不平衡

问题描述 投票:0回答:1
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf

train_path = 'Skin/Train'
test_path = 'Skin/Test'

train_gen = ImageDataGenerator(rescale=1./255)
train_generator = train_gen.flow_from_directory(train_path,target_size= 
                                         (300,300),batch_size=30,class_mode='categorical')

model = tf.keras.models.Sequential([
# Note the input shape is the desired size of the image 300x300 with 3 bytes color
# This is the first convolution
tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(600, 450, 3)),
tf.keras.layers.MaxPooling2D(2, 2),
# The second convolution
tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# The third convolution
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# The fourth convolution
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# The fifth convolution
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# Flatten the results to feed into a DNN
tf.keras.layers.Flatten(),
# 512 neuron hidden layer
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dense(9, activation='softmax')
])

from tensorflow.keras.optimizers import RMSprop

model.compile(loss='categorical_crossentropy',
          optimizer=RMSprop(lr=0.001),
          metrics=['acc'])

history = model.fit_generator(
  train_generator,
  steps_per_epoch=8,  
  epochs=15,
  verbose=2, class_weight = ? )

我在获得准确性方面存在问题,我正在训练一个9类数据集,其中1、4和5类仅具有100、96、90个图像,而其余类具有500个以上图像。因此,由于权重偏向数量更多的图像,因此我无法获得更高的精度。我希望在训练期间所有类都被认为相等,即500。如果我可以通过tensorflow或任何keras函数代码对类进行上采样,将不胜感激。而不是手动对文件夹中的图像进行升采样或降采样。

keras tensorflow2.0 tensorflow-datasets
1个回答
0
投票

您可以在fit方法中使用class_weight参数。对于上采样,这是不可避免的,需要大量的手工工作。

假设您有一个形状为(anything, 9)的输出,并且您知道每个类的总数:

totals = np.array([500,100,500,500,96,90,.......])
totalMean = totals.mean()
weights = {i: totalMean / count for i, count in enumerate(totals)]

model.fit(....., class_weight = weights)
© www.soinside.com 2019 - 2024. All rights reserved.