大型稀疏 CSR 二进制矩阵乘法结果中的错误

问题描述 投票:0回答:1

这让我很困惑,这是一个已知的错误还是我遗漏了什么?如果出现错误有办法规避吗?


假设我有一个相对较小的二进制 (0/1) n x q

scipy.sparse.csr_matrix
,如:

import numpy as np
from scipy import sparse

def get_dummies(vec, vec_max):
    vec_size = vec.size
    Z = sparse.csr_matrix((np.ones(vec_size), (np.arange(vec_size), vec)), shape=(vec_size, vec_max), dtype=np.uint8)
    return Z

q = 100
ns = np.round(np.random.random(q)*100).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<5171x100 sparse matrix of type '<class 'numpy.uint8'>'
    with 5171 stored elements in Compressed Sparse Row format>

这里

Z
是一个标准的虚拟变量矩阵,其中n=5171个观测值和q=100个变量:

Z[:5, :5].toarray()
array([[1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0]], dtype=uint8)

例如,如果前 5 个变量有...

ns[:5]
array([21, 22, 37, 24, 99], dtype=int16)

频率,我们也会在

Z
的列总和中看到这一点:

Z[:, :5].sum(axis=0)
matrix([[21, 22, 37, 24, 99]], dtype=uint64)

现在,正如预期的那样,如果我乘以

Z.T @ Z
,我应该得到一个 q x q 对角矩阵,对角线上是 q 个变量的频率:

print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray()
(100, 100)
[[21  0  0  0  0]
 [ 0 22  0  0  0]
 [ 0  0 37  0  0]
 [ 0  0  0 24  0]
 [ 0  0  0  0 99]]

现在来说一下 bug :如果 n 真的很大(对我来说,它已经发生在 n = 100K 左右):

q = 1000
ns = np.round(np.random.random(q)*1000).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<495509x1000 sparse matrix of type '<class 'numpy.uint8'>'
    with 495509 stored elements in Compressed Sparse Row format>

频率很大,

Z
列的总和符合预期:

print(ns[:5])
Z[:, :5].sum(axis=0)
array([485, 756, 380,  87, 454], dtype=int16)
matrix([[485, 756, 380,  87, 454]], dtype=uint64)

但是

Z.T @ Z
搞砸了!从某种意义上说,我在对角线上没有得到正确的频率:

print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray())
(1000, 1000)
[[229   0   0   0   0]
 [  0 244   0   0   0]
 [  0   0 124   0   0]
 [  0   0   0  87   0]
 [  0   0   0   0 198]]

令人惊讶的是,与真实频率有一些关系:

import matplotlib.pyplot as plt

plt.scatter(ns, (Z.T @ Z).diagonal())
plt.xlabel('real frequencies')
plt.ylabel('values on ZZ diagonal')
plt.show()

发生什么事了?

我正在使用标准 colab:

import scipy as sc

print(np.__version__)
print(sc.__version__)
1.25.2
1.11.4

PS:显然,如果我只想要

Z.T @ Z
的输出矩阵,有更简单的方法来获取它,这是一个非常简化的简化问题,谢谢。

python scipy sparse-matrix
1个回答
0
投票

您正在使用

uint8
:

print((Z.T @ Z)[:5, :5].dtype)
# uint8

所以结果是溢出的。该模式是因为结果是模 256 的真实结果。如果将 dtype 更改为

uint16
,结果就会消失。

© www.soinside.com 2019 - 2024. All rights reserved.