从numpy数组复制值以平衡数据集

问题描述 投票:0回答:1

我有一个数据集,其中一个相似的外观类是不平衡的。它是一个数字数据集,其中类标签从1到10。

按标签(y)对训练集进行分组可得出以下结果:

(array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=uint8), array([13861, 10585,  8497,  7458,  6882,  5727,  5595,  5045,  4659,
    4948]))

可以看出113861数据点,而7只有5595数据点。

为了避免17之间的类不平衡,我想为7类添加一些额外的图像。

这是train集:

from scipy.io import loadmat

train = loadmat('train.mat')

extra = loadmat('extra.mat')

trainextra都是字典,每个都有2个密钥Xy

这是trainextra的形状:

train['X'] --> (32, 32, 3, 73257)
# 73257 images of 32x32x3
train['y'] --> (73257,1)
# 73257 labels of corresponding images

extra['X'] --> (32, 32, 3, 531131)
# 531131 images of 32x32x3
extra['y'] --> (531131, 1)
# 531131 labels of corresponding images

现在,我想用train的标签更新extra数据集,主要是将x%中带有标签7extra数据带入train。我怎么能这样做?

我尝试了以下方法:

arr, _ = np.where(extra['y'] == 7)
c = np.concatenate(X_train, extra['X'][arr])

但我得到一个错误说IndexError: index 32 is out of bounds for axis 0 with size 32

python numpy dataset
1个回答
1
投票

这是一个关于numpy数组的工作示例,可以轻松转换为您的案例。正如您编辑的那样,使用numpy.whereextra['y']上找到您想要的标签并保留这些索引。然后将它们与numpy.append一起用于连接(X的最后一个轴和y的第一个轴)你的原始数据集和额外的数据集。

import numpy as np

np.random.seed(100)

# First find the indices of your y_extra with label 7
x_extra = np.random.rand(32, 32, 3, 10)
y_extra = np.random.randint(0, 9, size=(10,1))
indices = np.where(y_extra==7)[0] # indices [3,4] are 7 with seed=100

# Now use this indices to concatenate them in the original datase
np.random.seed(101)
x_original = np.random.rand(32, 32, 3, 10)
y_original = np.random.randint(1, 10, size=(10,1))

print(x_original.shape, x_extra[..., indices].shape) # (32, 32, 3, 10) (32, 32, 3, 2)
print(y_original.shape, y_extra[indices].shape) # (10, 1) (2, 1)

x_final = np.append(x_original, x_extra[..., indices], axis=-1)
y_final = np.append(y_original, y_extra[indices], axis=0)

print(x_final.shape, y_final.shape) # (32, 32, 3, 12) (12, 1)
© www.soinside.com 2019 - 2024. All rights reserved.