如何使我的埃拉托斯特尼筛的 pyspark 代码可扩展用于大范围的数字?

问题描述 投票:0回答:1

我正在尝试编写一个 pyspark 脚本来生成所有素数 <= a given number. For example, generate all the primes upto 1 billion.

我写了以下代码。对于小数量来说它表现得很好,但是一旦数量达到 1 亿,脚本就表现不佳。

from pyspark.sql import SparkSession
from math import isqrt

def is_perfect_square(num):
  root = isqrt(num)
  return root*root == num


def sieve_of_eratosthenes_partition(iterator):
    upper_limit = max(iterator) # Upper limit for prime number generation
    prime_flag = [True] * len(iterator) # Initialize boolean array for primes
    result = []
    cur_prime = 2

    while cur_prime * cur_prime <= upper_limit:
        i = 0
        if ((cur_prime % 2 == 0 and cur_prime != 2) or is_perfect_square(cur_prime)):
          cur_prime += 1
        else:
          for num in iterator:
            if(num % cur_prime == 0 and num != cur_prime):
              prime_flag[i] = False
            i += 1  
          cur_prime += 1

    for num, is_prime in zip(iterator, prime_flag):
        if is_prime and num > 1:
            result.append(num)
    return result

spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
n = 10**7 # End range 

numbers = spark.sparkContext.parallelize(range(1, n+1), 1000)
result_rdd = numbers.mapPartitions(sieve_of_eratosthenes_partition)

# result_rdd.map(str).saveAsTextFile("primes")

primes = result_rdd.collect()
print(primes)
print(len(primes))

我将范围 (1, 1000 万) 分为 1000 个分区,例如 (1.....10,000)、(10,0001....20,000) 等等。对于每个分区,我都应用筛函数。筛子迭代地过滤掉给定范围内的所有素数倍数。

我可以看到脚本中的限制因素。对于具有最小数字范围的分区,例如 (1...10,000),筛函数将仅迭代到 100。 但是,对于具有最大数字范围的分区,即最后一个分区 (9,990,000....10,000,000),筛函数将迭代到 sqrt(1000 万)。实际上,我的脚本的性能取决于处理具有最大数字的分区所花费的时间。

我该如何改进?还有其他方法可以对我的数据集进行分区吗? 我想到的另一个想法是制作一个达到 sqrt(给定数字)的筛子。然后将此筛子分配给节点。在每个节点上,过滤掉各个筛子的所有倍数,并将结果组合起来得到素数列表。这看起来是不是有进步?

python apache-spark pyspark primes sieve-of-eratosthenes
1个回答
0
投票

就像您所想的那样,扩展埃拉托斯特尼筛法算法的一个好方法是首先创建一个较小的筛子,直到给定数字的平方根,然后将该筛子分布到所有节点上。

以下是实现此方法的方法:

    def soe_sqrt(n):
        limit = isqrt(n)
        prime_flag = [True] * (limit + 1)
        primes = []
        for p in range(2, limit + 1):
            if prime_flag[p]:
                primes.append(p)
                for i in range(p * p, limit + 1, p):
                    prime_flag[i] = False
        return primes

然后你将其广播给你的火花:

    spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
    n = 10**9  # 1 billion
    primes_upto_sqrt_n = soe_sqrt(n)
    primes_broadcast = spark.sparkContext.broadcast(primes_upto_sqrt_n)

现在您只需让集群使用广播的素数列表来过滤掉它们的倍数。

    def soe_partition(primes_broadcast):
        def filter_primes(iterator):
            primes = primes_broadcast.value
            numbers = set(iterator)
            for prime in primes:
                multiples = {prime * i for i in range(2, (max(numbers) // prime) + 1)}
                numbers -= multiples
            return list(numbers)

        return filter_primes


    numbers = spark.sparkContext.parallelize(range(2, n + 1), 1000)
    result_rdd = numbers.mapPartitions(soe_partition(primes_broadcast))

现在您只需使用简单的

primes = result_rdd.collect()
打印结果即可
print(f"found {len(primes)} primes up to {n}")

© www.soinside.com 2019 - 2024. All rights reserved.