这让我很困惑,这是一个已知的错误还是我遗漏了什么?如果出现错误有办法规避吗?
假设我有一个相对较小的二进制 (0/1) n x q
scipy.sparse.csr_matrix
,如:
import numpy as np
from scipy import sparse
def get_dummies(vec, vec_max):
vec_size = vec.size
Z = sparse.csr_matrix((np.ones(vec_size), (np.arange(vec_size), vec)), shape=(vec_size, vec_max), dtype=np.uint8)
return Z
q = 100
ns = np.round(np.random.random(q)*100).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<5171x100 sparse matrix of type '<class 'numpy.uint8'>'
with 5171 stored elements in Compressed Sparse Row format>
这里
Z
是一个标准的虚拟变量矩阵,其中n=5171个观测值和q=100个变量:
Z[:5, :5].toarray()
array([[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0]], dtype=uint8)
例如,如果前 5 个变量有...
ns[:5]
array([21, 22, 37, 24, 99], dtype=int16)
频率,我们也会在
Z
的列总和中看到这一点:
Z[:, :5].sum(axis=0)
matrix([[21, 22, 37, 24, 99]], dtype=uint64)
现在,正如预期的那样,如果我乘以
Z.T @ Z
,我应该得到一个 q x q 对角矩阵,对角线上是 q 个变量的频率:
print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray()
(100, 100)
[[21 0 0 0 0]
[ 0 22 0 0 0]
[ 0 0 37 0 0]
[ 0 0 0 24 0]
[ 0 0 0 0 99]]
现在来说一下 bug :如果 n 真的很大(对我来说,它已经发生在 n = 100K 左右):
q = 1000
ns = np.round(np.random.random(q)*1000).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<495509x1000 sparse matrix of type '<class 'numpy.uint8'>'
with 495509 stored elements in Compressed Sparse Row format>
频率很大,
Z
列的总和符合预期:
print(ns[:5])
Z[:, :5].sum(axis=0)
array([485, 756, 380, 87, 454], dtype=int16)
matrix([[485, 756, 380, 87, 454]], dtype=uint64)
但是
Z.T @ Z
搞砸了!从某种意义上说,我在对角线上没有得到正确的频率:
print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray())
(1000, 1000)
[[229 0 0 0 0]
[ 0 244 0 0 0]
[ 0 0 124 0 0]
[ 0 0 0 87 0]
[ 0 0 0 0 198]]
令人惊讶的是,与真实频率有一些关系:
import matplotlib.pyplot as plt
plt.scatter(ns, (Z.T @ Z).diagonal())
plt.xlabel('real frequencies')
plt.ylabel('values on ZZ diagonal')
plt.show()
发生什么事了?
我正在使用标准 colab:
import scipy as sc
print(np.__version__)
print(sc.__version__)
1.25.2
1.11.4
PS:显然,如果我只想要
Z.T @ Z
的输出矩阵,有更简单的方法来获取它,这是一个非常简化的简化问题,谢谢。
您正在使用
uint8
:
print((Z.T @ Z)[:5, :5].dtype)
# uint8
所以结果是溢出的。该模式是因为结果是模 256 的真实结果。如果将 dtype 更改为
uint16
,结果就会消失。