如何从MNIST数据集中选择每个类别的特定数目

问题描述 投票:0回答:1

我正在使用tensorflow处理Mnist。我需要使用每个类的特定数量的数据来训练我的网络(例如,每个位数500个样本)。我找到了how to sort the DB with class labels

idx = np.argsort(y_train)
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]

但是如何选择500个数字,然后将它们与随机播放结合在一起?

python keras mnist
1个回答
0
投票

如果您在一个DataFrame中拥有全部,那么您可以groupby标签然后获得head

import pandas as pd

df = pd.DataFrame({
    'X': [1,2,3,4,5,6,7,8,9,10,11,12],
    'label': ['a','a','a','a','b','b','b','b','c','c','c','c']
})

groups = df.groupby('label')

df2 = groups.head(2)    

print(df2)

结果

    X label
0   1     a
1   2     a
4   5     b
5   6     b
8   9     c
9  10     c

然后您可以将其拆分为X_trainy_train

© www.soinside.com 2019 - 2024. All rights reserved.