提取数组每行的非零元素

问题描述 投票:0回答:2

我正在使用 Python 3.9/NumPy 1.22。

假设我有一个 3x3 矩阵:

import numpy as np
x = np.array([[10, 40, 0],
              [0, 40, 90],
              [10, 0, 90]])

所有元素都是>= 0的整数。 每行恰好有 2 个非零整数。

我想提取非零整数来生成 3x2 矩阵

y
这样

y = np.array([[10, 40],
              [40, 90],
              [10, 90]])

使用

numpy.apply_along_axis
numpy.squeeze
和/或
numpy.where
时,我感觉很接近,但我错过了一些东西。

python numpy matrix
2个回答
1
投票

由于您知道每行有相同数量的零,因此您可以安全地删除它们并

reshape

x = np.array([[10,40,0],[0,40,90],[10,0,90]])

out = x[x!=0].reshape(len(x), -1)

输出:

array([[10, 40],
       [40, 90],
       [10, 90]])

为了好玩,如果你没有相同数量的零,你可以将它们移到末尾,然后切片以保留它们的最小数量:

x = np.array([[10,40,0],[0,40,90],[0,0,90]])
# array([[10, 40,  0],
#        [ 0, 40, 90],
#        [ 0,  0, 90]])

m = x!=0
out = np.take_along_axis(x, np.argsort(~m, axis=1),
                         axis=1)[:, :m.sum(axis=1).max()]
# array([[10, 40],
#        [40, 90],
#        [90,  0]])

0
投票

这是一种解决方案,但它需要了解非零元素的确切数量(在本例中为 2)

x = np.array([[10,40,0],[0,40,90],[10,0,90]])  # exactly 2 non-zero per row
idx = np.where(x>0)
idx
Out[90]: 
(array([0, 0, 1, 1, 2, 2], dtype=int64),
 array([0, 1, 1, 2, 0, 2], dtype=int64))
y = x[idx].reshape((3,2))
y
Out[92]: 
array([[10, 40],
       [40, 90],
       [10, 90]])
© www.soinside.com 2019 - 2024. All rights reserved.