在给定范围内具有频率k的总数

问题描述 投票:1回答:2

如何在给定数组中找到特定范围(kl)中具有频率= r的总数。总共有10 ^ 5个格式为l,r的查询,并且每个查询都是基于先前查询的答案构建的。特别是,在每次查询之后,我们通过查询结果增加l,如果l> r,则交换lr。请注意0<=a[i]<=10^9。数组中的元素总数为n=10^5

我的尝试:

n,k,q = map(int,input().split())
a = list(map(int,input().split()))
ans = 0
for _ in range(q):
    l,r = map(int,input().split())
    l+=ans
    l%=n
    r+=ans
    r%=n
    if l>r:
        l,r = r,l
    d = {}
    for i in a[l:r+1]:
        try:
            d[i]+=1
        except:
            d[i] = 1
    curr_ans = 0
    for i in d.keys():
        if d[i]==k:
            curr_ans+=1
    ans = curr_ans
    print(ans)

样本输入: 5 2 3 7 6 6 5 5 0 4 3 0 4 1

样本输出: 2 1 1

arrays algorithm data-structures range-query
2个回答
0
投票

如果数组中不同值的数量不是太大,您可以考虑存储数组,只要输入数组,每个唯一值一个,计算值的出现次数,直到每个点。然后,您只需要从起始值中减去结束值,以查找有多少频率匹配:

def range_freq_queries(seq, k, queries):
    n = len(seq)
    c = freq_counts(seq)
    result = [0] * len(queries)
    offset = 0
    for i, (l, r) in enumerate(queries):
        result[i] = range_freq_matches(c, offset, l, r, k, n)
        offset = result[i]
    return result

def freq_counts(seq):
    s = {v: i for i, v in enumerate(set(seq))}
    counts = [None] * (len(seq) + 1)
    counts[0] = [0] * len(s)
    for i, v in enumerate(seq, 1):
        counts[i] = list(counts[i - 1])
        j = s[v]
        counts[i][j] += 1
    return counts

def range_freq_matches(counts, offset, start, end, k, n):
    start, end = sorted(((start + offset) % n, (end + offset) % n))
    num = 0
    return sum(1 for cs, ce in zip(counts[start], counts[end + 1]) if ce - cs == k)

seq = [7, 6, 6, 5, 5]
k = 2
queries = [(0, 4), (3, 0), (4, 1)]
print(range_freq_queries(seq, k, queries))
# [2, 1, 1]

你也可以使用NumPy更快地完成它。由于每个结果都取决于前一个结果,所以你必须循环任何情况,但你可以使用Numba来真正加速:

import numpy as np
import numba as nb

def range_freq_queries_np(seq, k, queries):
    seq = np.asarray(seq)
    c = freq_counts_np(seq)
    return _range_freq_queries_np_nb(seq, k, queries, c)

@nb.njit  # This is not necessary but will make things faster
def _range_freq_queries_np_nb(seq, k, queries, c):
    n = len(seq)
    offset = np.int32(0)
    out = np.empty(len(queries), dtype=np.int32)
    for i, (l, r) in enumerate(queries):
        l = (l + offset) % n
        r = (r + offset) % n
        l, r = min(l, r), max(l, r)
        out[i] = np.sum(c[r + 1] - c[l] == k)
        offset = out[i]
    return out

def freq_counts_np(seq):
    uniq = np.unique(seq)
    seq_pad = np.concatenate([[uniq.max() + 1], seq])
    comp = seq_pad[:, np.newaxis] == uniq
    return np.cumsum(comp, axis=0)

seq = np.array([7, 6, 6, 5, 5])
k = 2
queries = [(0, 4), (3, 0), (4, 1)]
print(range_freq_queries_np(seq, k, queries))
# [2 1 2]

让我们将它与原始算法进行比较:

from collections import Counter

def range_freq_queries_orig(seq, k, queries):
    n = len(seq)
    ans = 0
    counter = Counter()
    out = [0] * len(queries)
    for i, (l, r) in enumerate(queries):
        l += ans
        l %= n
        r += ans
        r %= n
        if l > r:
            l, r = r, l
        counter.clear()
        counter.update(seq[l:r+1])
        ans = sum(1 for v in counter.values() if v == k)
        out[i] = ans
    return out

这是一个快速测试和时间:

import random
import numpy

# Make random input
random.seed(0)
seq = random.choices(range(1000), k=5000)
queries = [(random.choice(range(len(seq))), random.choice(range(len(seq))))
           for _ in range(20000)]
k = 20
# Input as array for NumPy version
seq_arr = np.asarray(seq)
# Check all functions return the same result
res1 = range_freq_queries_orig(seq, k, queries)
res2 = range_freq_queries(seq, k, queries)
print(all(r1 == r2 for r1, r2 in zip(res1, res2)))
# True
res3 = range_freq_queries_np(seq_arr, k, queries)
print(all(r1 == r3 for r1, r3 in zip(res1, res3)))
# True

# Timings
%timeit range_freq_queries_orig(seq, k, queries)
# 3.07 s ± 1.11 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit range_freq_queries(seq, k, queries)
# 1.1 s ± 307 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit range_freq_queries_np(seq_arr, k, queries)
# 265 ms ± 726 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

显然,这有效性取决于数据的特征。特别是,如果重复值较少,构造计数表的时间和内存成本将接近O(n2)。


0
投票

假设输入数组是A|A|=n。我将假设A中不同元素的数量远小于n。

我们可以将A划分为sqrt(n)段,每个段的大小为sqrt(n)。对于这些段中的每一个,我们可以计算从元素到计数的映射。构建这些映射需要O(n)时间。

完成预处理后,我们可以通过将完全包含在(l,r)中的所有映射相加来回答每个查询,其中最多包含sqrt(n),然后添加任何额外元素(或者将一个段添加到并减去) ,也是sqrt(n)。

如果有k个不同的元素,则需要O(sqrt(n)* k),因此在最坏的情况下O(n),如果事实上A的每个元素都是不同的。

在组合散列和额外元素时,您可以跟踪具有所需计数的元素。

© www.soinside.com 2019 - 2024. All rights reserved.