AIM
使用 U-Net 模型分割肿瘤。
该数据集包含 200 个形状为 (112,192,160,3) 的 MRI 体数据,以及 200 个对应的掩膜;按下文生成器与模型的设置(4 个类别通道),掩膜的形状应为 (112,192,160,4)。
我做了什么
我以这种方式创建了一个数据生成器:
from sklearn.model_selection import train_test_split
# Train directories: list the entries of the training-image folder and order
# them with natsort.natsorted. The name is bound directly to the listing
# instead of holding the folder path first.
train_and_val_img = natsort.natsorted(os.listdir('/content/train_img'))
def pathListIntoIds(dirList):
    """Return, for each entry of dirList, the text after its last '/'.

    Entries that contain no '/' are returned unchanged; an entry ending in
    '/' yields an empty string.
    """
    return [entry.rpartition('/')[2] for entry in dirList]
# Reduce the listing to bare file names, then split it at random:
# test_size=0.4 holds out 40% of the ids for validation.
train_and_test_ids = pathListIntoIds(train_and_val_img)
train_ids, val_ids = train_test_split(
    train_and_test_ids,
    test_size=0.4,
)
class DataGenerator(keras.utils.Sequence):
    """Batch generator for Keras that loads image (and mask) arrays from disk.

    Every sample is read with ``np.load`` from ``img_path + '/' + ID``; when
    ``to_fit`` is true the mask is read from ``mask_path + '/' + ID``, i.e. an
    image and its mask share the same file name.
    """

    def __init__(self, list_img_IDs, img_path, mask_path, to_fit=True,
                 dim=(102, 192, 160), batch_size=1, n_channels=3, shuffle=True,
                 n_classes=4):
        """Store the configuration and build the sample order for the first epoch.

        Args:
            list_img_IDs: file names of all samples served by this generator.
            img_path: directory containing the image arrays.
            mask_path: directory containing the mask arrays.
            to_fit: True to return ``(X, Y)`` batches, False to return ``X`` only.
            dim: shape of one sample without its last (channel) axis.
            batch_size: number of samples per batch.
            n_channels: size of the last axis of each image array.
            shuffle: whether to draw a new random sample order for every epoch.
            n_classes: size of the last axis of each mask array.
        """
        # Initialise the Keras base class; recent Keras releases expect
        # subclasses to do this and warn otherwise.
        super().__init__()
        # NOTE(review): the default depth of 102 differs from the 112 used by
        # the script's `params`; callers currently always pass `dim` explicitly.
        self.dim = dim
        self.batch_size = batch_size
        self.list_img_IDs = list_img_IDs
        self.to_fit = to_fit
        self.img_path = img_path
        self.mask_path = mask_path
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        Integer division drops any incomplete trailing batch, so each sample
        is seen at most once per epoch.
        """
        return len(self.list_img_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index``: ``(X, Y)`` if ``to_fit`` else ``X``."""
        start = index * self.batch_size
        batch_positions = self.indexes[start:start + self.batch_size]
        batch_ids = [self.list_img_IDs[k] for k in batch_positions]

        X = self._generate_X(batch_ids)
        if not self.to_fit:
            return X
        return X, self._generate_Y(batch_ids)

    def on_epoch_end(self):
        """Rebuild the sample order; called from ``__init__`` and after each epoch.

        With ``shuffle`` enabled the order is re-randomised so that batches
        differ between epochs; otherwise samples are visited in list order.
        """
        self.indexes = np.arange(len(self.list_img_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _generate_X(self, list_IDs_temp):
        """Load the images of one batch.

        Returns an array of shape ``(batch_size, *dim, n_channels)``. Values
        are returned exactly as stored on disk; no normalisation is applied.
        """
        X = np.zeros((self.batch_size, *self.dim, self.n_channels))
        for i, ID in enumerate(list_IDs_temp):
            X[i] = np.load(self.img_path + '/' + ID)
        return X

    def _generate_Y(self, list_IDs_temp):
        """Load the masks of one batch.

        Returns an array of shape ``(batch_size, *dim, n_classes)``.
        """
        Y = np.zeros((self.batch_size, *self.dim, self.n_classes))
        for i, ID in enumerate(list_IDs_temp):
            Y[i] = np.load(self.mask_path + '/' + ID)
        return Y
# Generator configuration. Training and validation samples live in the same
# two folders and are told apart only by the id lists handed to each generator.
train_img_path = '/content/train_img'
train_mask_path = '/content/train_mask'
params = dict(
    dim=(112, 192, 160),
    batch_size=1,
    n_channels=3,
    to_fit=True,
    shuffle=True,
)
training_generator = DataGenerator(train_ids, train_img_path, train_mask_path, **params)
valid_generator = DataGenerator(val_ids, train_img_path, train_mask_path, **params)
通过这种方式,我有 120 张图片用于训练,80 张图片用于验证。
然后,我使用的 U-Net 模型如下:
def _encoder_block(x, filters, dropout_rate):
    """Apply Conv3D -> BatchNorm -> Dropout -> Conv3D -> BatchNorm with `filters` channels."""
    x = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    x = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    return BatchNormalization()(x)


def _decoder_block(x, skip, filters, dropout_rate):
    """Upsample `x` by 2, concatenate `skip`, then Dropout -> (Conv3D -> BatchNorm) twice."""
    x = Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(x)
    x = concatenate([x, skip])
    x = Dropout(dropout_rate)(x)
    x = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    return BatchNormalization()(x)


def simple_unet_model(IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, num_classes):
    """Build a 3D U-Net with four pooling levels and print its summary.

    Args:
        IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH: spatial size of one input volume.
            The network pools by 2 four times and concatenates the upsampled
            features with the matching encoder output, so each spatial size
            should be divisible by 16 for the skip connections to line up.
        IMG_CHANNELS: number of input channels.
        num_classes: number of output channels (one per class).

    Returns:
        The uncompiled Keras ``Model``. Its output has ``num_classes``
        channels with a softmax over the channel axis.
    """
    # Inputs are fed to the network as given: no rescaling layer is applied,
    # so any intensity normalisation has to happen before the data gets here.
    inputs = Input((IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # Contraction path
    c1 = _encoder_block(inputs, 16, 0.1)
    p1 = MaxPooling3D((2, 2, 2))(c1)
    c2 = _encoder_block(p1, 32, 0.1)
    p2 = MaxPooling3D((2, 2, 2))(c2)
    c3 = _encoder_block(p2, 64, 0.2)
    p3 = MaxPooling3D((2, 2, 2))(c3)
    c4 = _encoder_block(p3, 128, 0.2)
    p4 = MaxPooling3D(pool_size=(2, 2, 2))(c4)
    c5 = _encoder_block(p4, 256, 0.3)

    # Expansive path
    c6 = _decoder_block(c5, c4, 128, 0.2)
    c7 = _decoder_block(c6, c3, 64, 0.2)
    c8 = _decoder_block(c7, c2, 32, 0.1)
    c9 = _decoder_block(c8, c1, 16, 0.1)

    # Softmax makes the per-voxel class scores a probability distribution,
    # which is what a categorical cross-entropy loss on one-hot masks expects;
    # independent sigmoids do not sum to 1 across classes.
    outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.summary()
    return model
# Smoke test: build the network for 112x192x160 volumes with 3 input channels
# and 4 output classes, then print the input and output shapes Keras reports.
model = simple_unet_model(112, 192, 160, 3, 4)
print(model.input_shape, model.output_shape, sep='\n')
使用的指标定义如下:
def dice_coef(y_true, y_pred, smooth=1.0, class_num=4):
    """Return the mean smoothed Dice coefficient over the class channels.

    For each of the first `class_num` channels of the last axis the score is
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``, computed over
    the whole batch at once; the per-channel scores are then averaged.

    Args:
        y_true: ground-truth tensor, classes on the last axis.
        y_pred: prediction tensor with the same layout as `y_true`.
        smooth: added to numerator and denominator so that an empty class
            yields 1 instead of 0/0.
        class_num: number of channels to average over (default 4).

    Raises:
        ValueError: if `class_num` is smaller than 1.
    """
    if class_num < 1:
        raise ValueError("class_num must be at least 1")

    total = None
    for channel in range(class_num):
        # `...` selects the channel regardless of how many leading axes exist.
        true_flat = K.flatten(y_true[..., channel])
        pred_flat = K.flatten(y_pred[..., channel])
        overlap = K.sum(true_flat * pred_flat)
        score = (2. * overlap + smooth) / (K.sum(true_flat) + K.sum(pred_flat) + smooth)
        total = score if total is None else total + score
    return total / class_num
def _dice_coef_channel(y_true, y_pred, channel, epsilon):
    """Dice coefficient of one channel of 5-D tensors, with squared sums in the denominator."""
    true_c = y_true[:, :, :, :, channel]
    pred_c = y_pred[:, :, :, :, channel]
    intersection = K.sum(K.abs(true_c * pred_c))
    # epsilon keeps the ratio finite when the channel is empty in both tensors.
    return (2. * intersection) / (K.sum(K.square(true_c)) + K.sum(K.square(pred_c)) + epsilon)


def dice_coef_background(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 0 (background)."""
    return _dice_coef_channel(y_true, y_pred, 0, epsilon)


def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 1 (necrotic)."""
    return _dice_coef_channel(y_true, y_pred, 1, epsilon)


def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 2 (edema)."""
    return _dice_coef_channel(y_true, y_pred, 2, epsilon)


def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 3 (enhancing)."""
    return _dice_coef_channel(y_true, y_pred, 3, epsilon)
def iou(y_true, y_pred, epsilon=1e-6):
    """Return the soft intersection-over-union of the two tensors.

    The sums run over every element of the tensors (all samples and all
    channels together), so this is a single global score rather than a
    per-class one.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor with the same shape as `y_true`.
        epsilon: added to the denominator so that two all-zero tensors give 0
            instead of NaN from 0/0.
    """
    intersect = K.sum(y_pred * y_true)
    union = K.sum(y_pred) + K.sum(y_true) - intersect
    return K.mean(intersect / (union + epsilon))
最后我编译并训练了模型:
# Adam with its default settings; the loss is categorical cross-entropy and the
# Dice / IoU functions are reported as additional metrics during training.
opt = tf.keras.optimizers.Adam()
model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        dice_coef,
        dice_coef_background,
        dice_coef_edema,
        dice_coef_enhancing,
        dice_coef_necrotic,
        iou,
    ],
)
问题
问题在于:模型在训练集上取得了良好的性能,但验证集上的指标较低,而且各个 epoch 之间的验证结果波动很大。模型出现了过拟合。
Epoch 1/40
120/120 [==============================] - 238s 2s/step - loss: 2.7209 - accuracy: 0.6425 - dice_coef: 0.2088 - dice_coef_background: 0.8156 - dice_coef_edema: 0.1983 - dice_coef_enhancing: 0.0544 - dice_coef_necrotic: 0.0264 - iou: 0.2598 - val_loss: 2.8709 - val_accuracy: 0.2091 - val_dice_coef: 0.1299 - val_dice_coef_background: 0.4797 - val_dice_coef_edema: 0.0772 - val_dice_coef_enhancing: 0.0335 - val_dice_coef_necrotic: 0.0183 - val_iou: 0.1319
Epoch 2/40
120/120 [==============================] - 216s 2s/step - loss: 2.3724 - accuracy: 0.8747 - dice_coef: 0.2464 - dice_coef_background: 0.8522 - dice_coef_edema: 0.4389 - dice_coef_enhancing: 0.1433 - dice_coef_necrotic: 0.0453 - iou: 0.3598 - val_loss: 2.9276 - val_accuracy: 0.1521 - val_dice_coef: 0.0673 - val_dice_coef_background: 0.2143 - val_dice_coef_edema: 0.0438 - val_dice_coef_enhancing: 0.0185 - val_dice_coef_necrotic: 0.0101 - val_iou: 0.0433
Epoch 3/40
120/120 [==============================] - 218s 2s/step - loss: 2.0515 - accuracy: 0.9757 - dice_coef: 0.2946 - dice_coef_background: 0.8749 - dice_coef_edema: 0.5525 - dice_coef_enhancing: 0.2858 - dice_coef_necrotic: 0.1102 - iou: 0.4814 - val_loss: 2.9217 - val_accuracy: 0.3761 - val_dice_coef: 0.1115 - val_dice_coef_background: 0.4222 - val_dice_coef_edema: 0.0468 - val_dice_coef_enhancing: 0.0198 - val_dice_coef_necrotic: 0.0117 - val_iou: 0.0905
Epoch 4/40
120/120 [==============================] - 217s 2s/step - loss: 1.7328 - accuracy: 0.9810 - dice_coef: 0.3416 - dice_coef_background: 0.8892 - dice_coef_edema: 0.5858 - dice_coef_enhancing: 0.4982 - dice_coef_necrotic: 0.1832 - iou: 0.5527 - val_loss: 2.1945 - val_accuracy: 0.9201 - val_dice_coef: 0.2789 - val_dice_coef_background: 0.7881 - val_dice_coef_edema: 0.3169 - val_dice_coef_enhancing: 0.3171 - val_dice_coef_necrotic: 0.1716 - val_iou: 0.4425
Epoch 5/40
120/120 [==============================] - 214s 2s/step - loss: 1.5224 - accuracy: 0.9839 - dice_coef: 0.3833 - dice_coef_background: 0.8911 - dice_coef_edema: 0.6061 - dice_coef_enhancing: 0.6402 - dice_coef_necrotic: 0.2313 - iou: 0.5812 - val_loss: 2.2925 - val_accuracy: 0.9038 - val_dice_coef: 0.2828 - val_dice_coef_background: 0.8243 - val_dice_coef_edema: 0.3067 - val_dice_coef_enhancing: 0.2606 - val_dice_coef_necrotic: 0.1402 - val_iou: 0.4851
Epoch 6/40
120/120 [==============================] - 217s 2s/step - loss: 1.4060 - accuracy: 0.9842 - dice_coef: 0.4106 - dice_coef_background: 0.8799 - dice_coef_edema: 0.6582 - dice_coef_enhancing: 0.6870 - dice_coef_necrotic: 0.2489 - iou: 0.5851 - val_loss: 2.3502 - val_accuracy: 0.9472 - val_dice_coef: 0.2996 - val_dice_coef_background: 0.9187 - val_dice_coef_edema: 0.2701 - val_dice_coef_enhancing: 0.2487 - val_dice_coef_necrotic: 0.1310 - val_iou: 0.6124
Epoch 7/40
120/120 [==============================] - 215s 2s/step - loss: 1.3528 - accuracy: 0.9851 - dice_coef: 0.4281 - dice_coef_background: 0.8782 - dice_coef_edema: 0.6995 - dice_coef_enhancing: 0.6842 - dice_coef_necrotic: 0.2635 - iou: 0.5895 - val_loss: 2.3285 - val_accuracy: 0.9339 - val_dice_coef: 0.3099 - val_dice_coef_background: 0.9236 - val_dice_coef_edema: 0.3018 - val_dice_coef_enhancing: 0.2018 - val_dice_coef_necrotic: 0.1679 - val_iou: 0.6461
Epoch 8/40
120/120 [==============================] - 215s 2s/step - loss: 1.3059 - accuracy: 0.9866 - dice_coef: 0.4415 - dice_coef_background: 0.8896 - dice_coef_edema: 0.7004 - dice_coef_enhancing: 0.6973 - dice_coef_necrotic: 0.2965 - iou: 0.6044 - val_loss: 2.2217 - val_accuracy: 0.9359 - val_dice_coef: 0.3149 - val_dice_coef_background: 0.9329 - val_dice_coef_edema: 0.4079 - val_dice_coef_enhancing: 0.1633 - val_dice_coef_necrotic: 0.2071 - val_iou: 0.6430
Epoch 9/40
120/120 [==============================] - 216s 2s/step - loss: 1.2152 - accuracy: 0.9861 - dice_coef: 0.4513 - dice_coef_background: 0.8948 - dice_coef_edema: 0.6920 - dice_coef_enhancing: 0.7071 - dice_coef_necrotic: 0.3856 - iou: 0.6149 - val_loss: 2.2314 - val_accuracy: 0.9252 - val_dice_coef: 0.3241 - val_dice_coef_background: 0.9221 - val_dice_coef_edema: 0.4166 - val_dice_coef_enhancing: 0.1481 - val_dice_coef_necrotic: 0.2039 - val_iou: 0.6461
Epoch 10/40
120/120 [==============================] - 215s 2s/step - loss: 1.1594 - accuracy: 0.9872 - dice_coef: 0.4734 - dice_coef_background: 0.8966 - dice_coef_edema: 0.7030 - dice_coef_enhancing: 0.7099 - dice_coef_necrotic: 0.4277 - iou: 0.6224 - val_loss: 2.4696 - val_accuracy: 0.9276 - val_dice_coef: 0.2894 - val_dice_coef_background: 0.9057 - val_dice_coef_edema: 0.2756 - val_dice_coef_enhancing: 0.1448 - val_dice_coef_necrotic: 0.1100 - val_iou: 0.6290
Epoch 11/40
120/120 [==============================] - 216s 2s/step - loss: 1.1447 - accuracy: 0.9876 - dice_coef: 0.4804 - dice_coef_background: 0.8958 - dice_coef_edema: 0.7047 - dice_coef_enhancing: 0.7125 - dice_coef_necrotic: 0.4381 - iou: 0.6242 - val_loss: 2.5344 - val_accuracy: 0.9241 - val_dice_coef: 0.2800 - val_dice_coef_background: 0.9131 - val_dice_coef_edema: 0.2171 - val_dice_coef_enhancing: 0.1230 - val_dice_coef_necrotic: 0.1255 - val_iou: 0.6576
Epoch 12/40
120/120 [==============================] - 216s 2s/step - loss: 1.1023 - accuracy: 0.9879 - dice_coef: 0.4930 - dice_coef_background: 0.8985 - dice_coef_edema: 0.7110 - dice_coef_enhancing: 0.7122 - dice_coef_necrotic: 0.4745 - iou: 0.6290 - val_loss: 2.6062 - val_accuracy: 0.9382 - val_dice_coef: 0.2735 - val_dice_coef_background: 0.9033 - val_dice_coef_edema: 0.2529 - val_dice_coef_enhancing: 0.0813 - val_dice_coef_necrotic: 0.0596 - val_iou: 0.6296
Epoch 13/40
120/120 [==============================] - 215s 2s/step - loss: 1.0679 - accuracy: 0.9884 - dice_coef: 0.5051 - dice_coef_background: 0.9002 - dice_coef_edema: 0.7180 - dice_coef_enhancing: 0.7266 - dice_coef_necrotic: 0.4875 - iou: 0.6324 - val_loss: 2.5210 - val_accuracy: 0.9272 - val_dice_coef: 0.2792 - val_dice_coef_background: 0.9120 - val_dice_coef_edema: 0.2221 - val_dice_coef_enhancing: 0.1146 - val_dice_coef_necrotic: 0.1423 - val_iou: 0.6307
Epoch 14/40
120/120 [==============================] - 213s 2s/step - loss: 1.0196 - accuracy: 0.9887 - dice_coef: 0.5154 - dice_coef_background: 0.9008 - dice_coef_edema: 0.7280 - dice_coef_enhancing: 0.7398 - dice_coef_necrotic: 0.5126 - iou: 0.6352 - val_loss: 2.5746 - val_accuracy: 0.9279 - val_dice_coef: 0.2686 - val_dice_coef_background: 0.8891 - val_dice_coef_edema: 0.2038 - val_dice_coef_enhancing: 0.0914 - val_dice_coef_necrotic: 0.1301 - val_iou: 0.6138
Epoch 15/40
120/120 [==============================] - 218s 2s/step - loss: 1.0425 - accuracy: 0.9887 - dice_coef: 0.5215 - dice_coef_background: 0.9052 - dice_coef_edema: 0.7161 - dice_coef_enhancing: 0.7349 - dice_coef_necrotic: 0.5064 - iou: 0.6398 - val_loss: 2.6565 - val_accuracy: 0.9344 - val_dice_coef: 0.2715 - val_dice_coef_background: 0.9179 - val_dice_coef_edema: 0.1515 - val_dice_coef_enhancing: 0.1303 - val_dice_coef_necrotic: 0.0618 - val_iou: 0.6619
Epoch 16/40
120/120 [==============================] - 215s 2s/step - loss: 1.0099 - accuracy: 0.9891 - dice_coef: 0.5357 - dice_coef_background: 0.9099 - dice_coef_edema: 0.7392 - dice_coef_enhancing: 0.7433 - dice_coef_necrotic: 0.5077 - iou: 0.6460 - val_loss: 2.6570 - val_accuracy: 0.9405 - val_dice_coef: 0.2600 - val_dice_coef_background: 0.8957 - val_dice_coef_edema: 0.1549 - val_dice_coef_enhancing: 0.0995 - val_dice_coef_necrotic: 0.0885 - val_iou: 0.6371
这仅仅是数据规模的问题(验证样本太少、类别不平衡:class_weights = [3710.65, 0.78, 0.78, 0.69]),还是另有其他原因?