选择有1000个标签的输出向量的大小。

问题描述 投票:0回答:1

互联网上的大多数例子都是关于 multi-label 图像分类仅基于一个 few 标签。例如,与 6 类,我们得到。

model = models.Sequential()
model.add(layer=base)
model.add(layer=layers.Flatten())
model.add(layer=layers.Dense(units=256, activation="relu"))
model.add(layer=layers.Dense(units=6, activation="sigmoid"))
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Model)                (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten_1 (Flatten)          (None, 25088)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               6422784   
_________________________________________________________________
dense_2 (Dense)              (None, 6)                 1542      
=================================================================
Total params: 21,139,014
Trainable params: 13,503,750
Non-trainable params: 7,635,264

然而,对于数据集与 significantly 更多的标签,训练的大小 parameters 爆炸,最终训练过程失败,并伴有 ResourceExhaustedError 错误。例如,对于 3047 我们得到的标签。

model = models.Sequential()
model.add(layer=base)
model.add(layer=layers.Flatten())
model.add(layer=layers.Dense(units=256, activation="relu"))
model.add(layer=layers.Dense(units=3047, activation="sigmoid"))
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Model)                (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten_1 (Flatten)          (None, 25088)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               6422784   
_________________________________________________________________
dense_2 (Dense)              (None, 3047)              783079    
=================================================================
Total params: 21,920,551
Trainable params: 14,285,287
Non-trainable params: 7,635,264
_________________________________________________________________

很明显,我的网络出了点问题 但不知道如何解决这个问题...

tensorflow keras transfer-learning
1个回答
0
投票

资源耗尽错误与内存问题有关。要么是你的系统中没有足够的内存,要么是代码的其他部分造成了内存问题。

© www.soinside.com 2019 - 2024. All rights reserved.