I have a large number of small tokenized texts, and I want to find the number of tokens shared by each pair of them, in the form of an upper-triangular matrix M (a symmetric matrix would also do). M[i, j] = 1 means that texts i and j have exactly one token in common.
I can't find any approach other than a double loop, which is not time-efficient. Is there one?
In the code below, subjects is a list containing, for each text, the list of its tokens.
from scipy.sparse import lil_matrix

n = len(subjects)
M = lil_matrix((n, n))
i = 0
for subj_1 in subjects:
    j = 0
    for subj_2 in subjects[i+1:]:
        inter_len = len(set(subj_1).intersection(subj_2))
        if inter_len > 0:
            M[i, j+i+1] = inter_len
        j += 1
    i += 1
I'll walk you through three options. Your basic algorithm can be improved in a couple of ways. Instead of writing into a sparse matrix, just use a preallocated array, even if you don't fill it completely. Beyond that, convert the subjects to sets only once, at the beginning, to avoid repeated work. So you get:
import numpy as np

def count_common_tokens(subjects):
    n = len(subjects)
    # Preallocate the full (dense) output; only the upper triangle gets filled
    counts = np.zeros((n, n), dtype=np.int32)
    # Convert each subject to a set once, up front
    subjects_sets = [set(subject) for subject in subjects]
    for i1, subj_1 in enumerate(subjects_sets):
        for i2 in range(i1 + 1, n):
            subj_2 = subjects_sets[i2]
            counts[i1, i2] = len(subj_1.intersection(subj_2))
    return counts
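For example (a toy input of my own, just to illustrate the output shape; it is not part of the benchmark below), three short texts where each adjacent pair shares one token yield the expected upper triangle:

toy = [['a', 'b'], ['b', 'c'], ['c']]
print(count_common_tokens(toy))
# [[0 1 0]
#  [0 0 1]
#  [0 0 0]]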
The problem is inherently quadratic in complexity, but we can try to vectorize it with NumPy:
import numpy as np

def count_common_tokens_vec(subjects):
    n = len(subjects)
    # Concatenate all subjects
    all_subjects = np.concatenate(subjects)
    # Make subject ids from subject lengths
    lens = [len(subject) for subject in subjects]
    subject_ids = np.repeat(np.arange(n), lens)
    # Find unique token ids
    all_tokens, token_ids = np.unique(all_subjects, return_inverse=True)
    # Make an array where each row marks the tokens present in one subject
    subject_token = np.zeros((n, len(all_tokens)), dtype=np.int32)
    np.add.at(subject_token, (subject_ids, token_ids), 1)
    subject_token = subject_token.astype(bool)
    # Logical AND with itself to find the number of common tokens
    counts = np.count_nonzero(subject_token[:, np.newaxis] & subject_token[np.newaxis, :], axis=-1)
    return counts
This gives you the full matrix of counts (not just the upper triangle), but it may require a lot of memory, on the order of O(num_subjects × num_subjects × num_tokens): with the benchmark sizes below (1,000 subjects, roughly 5,000 distinct tokens), the intermediate boolean array alone is about 1,000 × 1,000 × 5,000 bytes ≈ 5 GB, so it will probably not work for a large problem.
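As an aside (a sketch of my own, not one of the three options in this answer): the pairwise counts are exactly the entries of A @ A.T, where A is the 0/1 subject-token incidence matrix built above, and storing A with scipy.sparse sidesteps the huge intermediate, paying only for the n × n output. Assuming the imports and construction above:

from scipy.sparse import csr_matrix

def count_common_tokens_sparse(subjects):
    n = len(subjects)
    # Same incidence-matrix construction as in count_common_tokens_vec
    all_subjects = np.concatenate(subjects)
    lens = [len(subject) for subject in subjects]
    subject_ids = np.repeat(np.arange(n), lens)
    all_tokens, token_ids = np.unique(all_subjects, return_inverse=True)
    ones = np.ones(len(all_subjects), dtype=np.int32)
    # Duplicate (subject, token) pairs are summed during construction...
    incidence = csr_matrix((ones, (subject_ids, token_ids)),
                           shape=(n, len(all_tokens)))
    incidence.data[:] = 1  # ...so clamp the counts back to 0/1
    # Row i of incidence dotted with row j = number of shared tokens
    return (incidence @ incidence.T).toarray()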
If you really need the speed, though, we can try Numba. It forces you to do things somewhat differently, working with arrays of numbers instead of sets of strings (maybe there is a better way to do that first part), but it gets us the result we want:
import numpy as np
import numba as nb

def count_common_tokens_nb(subjects):
    n = len(subjects)
    # Output array
    counts = np.zeros((n, n), dtype=np.int32)
    # Concatenate all subjects
    all_subjects = np.concatenate(subjects)
    # Find token ids for concatenated subjects
    _, token_ids = np.unique(all_subjects, return_inverse=True)
    # Split token ids and remove duplicates
    lens = [len(subject) for subject in subjects]
    subjects_sets = [np.unique(s) for s in np.split(token_ids, np.cumsum(lens)[:-1])]
    # Do the counting
    _count_common_tokens_nb_inner(counts, subjects_sets)
    return counts
@nb.njit(parallel=True)
def _count_common_tokens_nb_inner(counts, subjects_sets):
    n = len(subjects_sets)
    for i1 in nb.prange(n):
        subj_1 = subjects_sets[i1]
        for i2 in nb.prange(i1 + 1, n):
            subj_2 = subjects_sets[i2]
            # Count matching token ids by comparing every pair
            c = 0
            for t1 in subj_1:
                for t2 in subj_2:
                    c += int(t1 == t2)
            counts[i1, i2] = c
    return counts
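One more tweak worth mentioning (my addition, not benchmarked below): np.unique returns sorted arrays, so the innermost pair of loops, which costs O(len(subj_1) * len(subj_2)) comparisons per pair, could be replaced by a linear two-pointer merge over the two sorted id arrays. A minimal sketch of such a counter:

@nb.njit
def _sorted_intersection_size(a, b):
    # a and b are sorted 1-D arrays of token ids (as produced by np.unique)
    i = j = c = 0
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            i += 1
        elif a[i] > b[j]:
            j += 1
        else:  # common token id found
            c += 1
            i += 1
            j += 1
    return c

With at most 20 tokens per subject, as in the test below, the difference is modest, but it matters as the subjects grow.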
Here is a quick correctness check and a small performance comparison:
import random
import string
import numpy as np
NUM_SUBJECTS = 1000
MAX_TOKENS_SUBJECT = 20
NUM_TOKENS = 5000
MAX_LEN_TOKEN = 10
# Make random input
random.seed(0)
letters = list(string.ascii_letters)
tokens = np.array(list(set(''.join(random.choices(letters, k=random.randint(1, MAX_LEN_TOKEN)))
                           for _ in range(NUM_TOKENS))))
subjects = [np.array(random.choices(tokens, k=random.randint(1, MAX_TOKENS_SUBJECT)))
            for _ in range(NUM_SUBJECTS)]
# Do counts
res1 = count_common_tokens(subjects)
res2 = count_common_tokens_vec(subjects)
res3 = count_common_tokens_nb(subjects)
# Check results
print(np.all(np.triu(res1, 1) == np.triu(res2, 1)))
# True
print(np.all(np.triu(res1, 1) == np.triu(res3, 1)))
# True
# Test performance
%timeit count_common_tokens(subjects)
# 196 ms ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_common_tokens_vec(subjects)
# 5.09 s ± 30.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_common_tokens_nb(subjects)
# 65.2 ms ± 886 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
So the vectorized solution does not pay off here (the huge intermediate array works against it), but Numba gives a significant speedup: about 3× over the improved loop version in this test.