我正在尝试遍历数据的所有样本;这是它的外观和形状:
processed_data3
array([[[[1.81673904e-05, 5.00669221e-01],
[1.65148740e-02, 5.52741534e-01],
[1.55841024e-02, 4.39919449e-01],
...,
[3.80455403e-03, 5.00042373e-01],
[6.80686618e-01, 4.78582767e-01],
[7.49290676e-04, 5.30804954e-01]],
processed_data3.shape
(100, 64, 256, 2)
我想遍历数据的所有100个样本,并且与仅对一个样本执行的操作相同:
mask = np.zeros([64,256])
for i in range(64):
for j in range(256):
if processed_data3[0,i,j,0] >0.1:
mask[i,j] = 1
[基本上,我想将这100个样本中的每个样本都存储在其自己的遮罩数组中,但是我不确定如何做到这一点。根据这个样本有什么建议吗?预先非常感谢!
您可以使用np.where()
获得所需的结果。无需循环。试试这个:
mask = np.where(processed_data3>0.1, 1, 0)[0][0]