How to find the number of common tokens between each pair of many texts without a double loop?


I have a large number of small tokenized texts, and I want to find the number of tokens shared by each pair of these texts, in the form of an upper triangular matrix M (the matrix could just as well be symmetric). M[i, j] = 1 means that texts i and j have 1 token in common.
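For example (a hypothetical toy input, just to make the desired output concrete):

subjects = [["a", "b", "c"], ["b", "c", "d"], ["x"]]
# Desired result: M[0, 1] = 2, since texts 0 and 1 share "b" and "c";
# M[0, 2] = M[1, 2] = 0, since text 2 shares no token with the others.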

I couldn't find any approach other than a double loop, which is not time efficient. Is there one?

In the code below, subjects is a list containing, for each text, the list of its tokens.

from scipy.sparse import lil_matrix

n = len(subjects)
M = lil_matrix((n, n))

for i, subj_1 in enumerate(subjects):
    # Only fill the upper triangle: compare each text with the ones after it
    for j, subj_2 in enumerate(subjects[i + 1:], start=i + 1):
        inter_len = len(set(subj_1).intersection(subj_2))
        if inter_len > 0:
            M[i, j] = inter_len
1 Answer

I'll show you three options. Your basic algorithm can be improved in several ways: instead of writing into a sparse matrix, just use a preallocated array, even if you don't fill it completely. Apart from that, you can convert the subjects to sets only once at the beginning to avoid repeated work. So you get:

import numpy as np

def count_common_tokens(subjects):
    n = len(subjects)
    counts = np.zeros((n, n), dtype=np.int32)
    subjects_sets = [set(subject) for subject in subjects]
    for i1, subj_1 in enumerate(subjects_sets):
        for i2 in range(i1 + 1, n):
            subj_2 = subjects_sets[i2]
            counts[i1, i2] = len(subj_1.intersection(subj_2))
    return counts

The problem is inherently quadratic in complexity. But we can try to vectorize it with NumPy:

import numpy as np

def count_common_tokens_vec(subjects):
    n = len(subjects)
    # Concatenate all subjects
    all_subjects = np.concatenate(subjects)
    # Make subject ids from subject lengths
    lens = [len(subject) for subject in subjects]
    subject_ids = np.repeat(np.arange(n), lens)
    # Find unique token ids
    all_tokens, token_ids = np.unique(all_subjects, return_inverse=True)
    # Make an array where each row marks which tokens are present in each subject
    subject_token = np.zeros((n, len(all_tokens)), dtype=np.int32)
    np.add.at(subject_token, (subject_ids, token_ids), 1)
    subject_token = subject_token.astype(bool)
    # Logical and with itself to find number of common tokens
    counts = np.count_nonzero(subject_token[:, np.newaxis] & subject_token[np.newaxis, :], axis=-1)
    return counts
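As a side note on the memory issue discussed below: the same pairwise counts can also be obtained from the boolean incidence matrix with a plain matrix product, which avoids the large broadcast intermediate. A minimal sketch of that variant (my own addition, not part of the measured code):

import numpy as np

# Hypothetical tiny incidence matrix: rows = subjects, columns = tokens
subject_token = np.array([[1, 1, 0],
                          [0, 1, 1]], dtype=bool)
# Casting to int and multiplying gives pairwise common-token counts with
# O(n x n + n x num_tokens) memory instead of O(n x n x num_tokens)
st = subject_token.astype(np.int32)
counts = st @ st.T
# counts[0, 1] == 1 (the single shared token)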

This gives you the full matrix of counts (not just the upper triangle), but it may require a lot of memory, on the order of O(num_subjects x num_subjects x num_tokens), so it probably won't work for a big problem. However, if you really need the speed, we can try to make it fast with Numba. It forces you to do things somewhat differently, using arrays of numeric ids instead of sets of strings (maybe there is a better way to do that first part), but we can get the result we want too:

import numpy as np
import numba as nb

def count_common_tokens_nb(subjects):
    n = len(subjects)
    # Output array
    counts = np.zeros((n, n), dtype=np.int32)
    # Concatenate all subjects
    all_subjects = np.concatenate(subjects)
    # Find token ids for concatenated subjects
    _, token_ids = np.unique(all_subjects, return_inverse=True)
    # Split token ids and remove duplicates
    lens = [len(subject) for subject in subjects]
    subjects_sets = [np.unique(s) for s in np.split(token_ids, np.cumsum(lens)[:-1])]
    # Do the counting
    _count_common_tokens_nb_inner(counts, subjects_sets)
    return counts

@nb.njit(parallel=True)
def _count_common_tokens_nb_inner(counts, subjects_sets):
    n = len(subjects_sets)
    for i1 in nb.prange(n):
        subj_1 = subjects_sets[i1]
        for i2 in nb.prange(i1 + 1, n):
            subj_2 = subjects_sets[i2]
            c = 0
            for t1 in subj_1:
                for t2 in subj_2:
                    c += int(t1 == t2)
            counts[i1, i2] = c
    return counts
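Since np.unique returns sorted arrays, the innermost pair of loops could also be replaced by a two-pointer merge, which costs O(len1 + len2) per pair instead of O(len1 * len2). A sketch of that refinement (my own addition, not part of the answer's benchmarked code; _sorted_intersection_size is a hypothetical helper):

import numba as nb

@nb.njit
def _sorted_intersection_size(a, b):
    # a and b are sorted, duplicate-free id arrays, as returned by np.unique
    i = j = c = 0
    while i < len(a) and j < len(b):
        if a[i] == b[j]:
            c += 1
            i += 1
            j += 1
        elif a[i] < b[j]:
            i += 1
        else:
            j += 1
    return c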

Here is a quick correctness check and a small performance comparison:

import random
import string
import numpy as np

NUM_SUBJECTS = 1000
MAX_TOKENS_SUBJECT = 20
NUM_TOKENS = 5000
MAX_LEN_TOKEN = 10

# Make random input
random.seed(0)
letters = list(string.ascii_letters)
tokens = np.array(list(set(''.join(random.choices(letters, k=random.randint(1, MAX_LEN_TOKEN)))
                           for _ in range(NUM_TOKENS))))
subjects = [np.array(random.choices(tokens, k=random.randint(1, MAX_TOKENS_SUBJECT)))
            for _ in range(NUM_SUBJECTS)]
# Do counts
res1 = count_common_tokens(subjects)
res2 = count_common_tokens_vec(subjects)
res3 = count_common_tokens_nb(subjects)
# Check results
print(np.all(np.triu(res1, 1) == np.triu(res2, 1)))
# True
print(np.all(np.triu(res1, 1) == np.triu(res3, 1)))
# True

# Test performance
%timeit count_common_tokens(subjects)
# 196 ms ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_common_tokens_vec(subjects)
# 5.09 s ± 30.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_common_tokens_nb(subjects)
# 65.2 ms ± 886 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So the vectorized solution does not fare well, but you do get a significant speedup with Numba.
