我正在尝试编写一个 pyspark 脚本来生成所有素数 <= a given number. For example, generate all the primes upto 1 billion.
我写了以下代码。对于小数量来说它表现得很好,但是一旦数量达到 1 亿,脚本就表现不佳。
from pyspark.sql import SparkSession
from math import isqrt
def is_perfect_square(num):
root = isqrt(num)
return root*root == num
def sieve_of_eratosthenes_partition(iterator):
upper_limit = max(iterator) # Upper limit for prime number generation
prime_flag = [True] * len(iterator) # Initialize boolean array for primes
result = []
cur_prime = 2
while cur_prime * cur_prime <= upper_limit:
i = 0
if ((cur_prime % 2 == 0 and cur_prime != 2) or is_perfect_square(cur_prime)):
cur_prime += 1
else:
for num in iterator:
if(num % cur_prime == 0 and num != cur_prime):
prime_flag[i] = False
i += 1
cur_prime += 1
for num, is_prime in zip(iterator, prime_flag):
if is_prime and num > 1:
result.append(num)
return result
spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
n = 10**7 # End range
numbers = spark.sparkContext.parallelize(range(1, n+1), 1000)
result_rdd = numbers.mapPartitions(sieve_of_eratosthenes_partition)
# result_rdd.map(str).saveAsTextFile("primes")
primes = result_rdd.collect()
print(primes)
print(len(primes))
我将范围 (1, 1000 万) 分为 1000 个分区,例如 (1.....10,000)、(10,0001....20,000) 等等。对于每个分区,我都应用筛函数。筛子迭代地过滤掉给定范围内的所有素数倍数。
我可以看到脚本中的限制因素。对于具有最小数字范围的分区,例如 (1...10,000),筛函数将仅迭代到 100。 但是,对于具有最大数字范围的分区,即最后一个分区 (9,990,000....10,000,000),筛函数将迭代到 sqrt(1000 万)。实际上,我的脚本的性能取决于处理具有最大数字的分区所花费的时间。
我该如何改进?还有其他方法可以对我的数据集进行分区吗? 我想到的另一个想法是制作一个达到 sqrt(给定数字)的筛子。然后将此筛子分配给节点。在每个节点上,过滤掉各个筛子的所有倍数,并将结果组合起来得到素数列表。这看起来是不是有进步?
就像您所想的那样,扩展埃拉托斯特尼筛法算法的一个好方法是首先创建一个较小的筛子,直到给定数字的平方根,然后将该筛子分布到所有节点上。
以下是实现此方法的方法:
def soe_sqrt(n):
limit = isqrt(n)
prime_flag = [True] * (limit + 1)
primes = []
for p in range(2, limit + 1):
if prime_flag[p]:
primes.append(p)
for i in range(p * p, limit + 1, p):
prime_flag[i] = False
return primes
然后你将其广播给你的火花:
spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
n = 10**9 # 1 billion
primes_upto_sqrt_n = soe_sqrt(n)
primes_broadcast = spark.sparkContext.broadcast(primes_upto_sqrt_n)
现在您只需让集群使用广播的素数列表来过滤掉它们的倍数。
def soe_partition(primes_broadcast):
def filter_primes(iterator):
primes = primes_broadcast.value
numbers = set(iterator)
for prime in primes:
multiples = {prime * i for i in range(2, (max(numbers) // prime) + 1)}
numbers -= multiples
return list(numbers)
return filter_primes
numbers = spark.sparkContext.parallelize(range(2, n + 1), 1000)
result_rdd = numbers.mapPartitions(soe_partition(primes_broadcast))
现在您只需使用简单的
primes = result_rdd.collect()
打印结果即可
print(f"found {len(primes)} primes up to {n}")