我正在使用 Python 3.9/NumPy 1.22。
假设我有一个 3x3 矩阵:
import numpy as np
x = np.array([[10, 40, 0],
[0, 40, 90],
[10, 0, 90]])
所有元素都是>= 0的整数。 每行恰好有 2 个非零整数。
我想提取非零整数来生成 3x2 矩阵
y
这样
y = np.array([[10, 40],
[40, 90],
[10, 90]])
使用
numpy.apply_along_axis
、numpy.squeeze
和/或 numpy.where
时,我感觉很接近,但我错过了一些东西。
reshape
:
x = np.array([[10,40,0],[0,40,90],[10,0,90]])
out = x[x!=0].reshape(len(x), -1)
输出:
array([[10, 40],
[40, 90],
[10, 90]])
为了好玩,如果你没有相同数量的零,你可以将它们移到末尾,然后切片以保留它们的最小数量:
x = np.array([[10,40,0],[0,40,90],[0,0,90]])
# array([[10, 40, 0],
# [ 0, 40, 90],
# [ 0, 0, 90]])
m = x!=0
out = np.take_along_axis(x, np.argsort(~m, axis=1),
axis=1)[:, :m.sum(axis=1).max()]
# array([[10, 40],
# [40, 90],
# [90, 0]])
这是一种解决方案,但它需要了解非零元素的确切数量(在本例中为 2)
x = np.array([[10,40,0],[0,40,90],[10,0,90]]) # exactly 2 non-zero per row
idx = np.where(x>0)
idx
Out[90]:
(array([0, 0, 1, 1, 2, 2], dtype=int64),
array([0, 1, 1, 2, 0, 2], dtype=int64))
y = x[idx].reshape((3,2))
y
Out[92]:
array([[10, 40],
[40, 90],
[10, 90]])