到目前为止我写了:
n=1000
solutions=[]
for i in range(1,n+1):
for j in range(1,n+1):
if str((i+j)**2)==str(i)+str(j):
solutions.append("("+str(i)+"+"+str(j)+")^2 = "+str((i+j)**2))
for solution in solutions:
print(solution)
这在我的电脑上需要1.03秒。有没有更快的方法来实施比较?我在vectorisation上找到了一个页面,但我不确定如何生成列表,然后我需要进行矢量化比较。
通过搜索满足目标范围内给定方格的方程的(x, y)
对,可以更快地完成此操作。事实上,这减少了从O(n^2)
到O(nlogn)
时间复杂度的问题。
def split_root(n):
div = 10
while div < n:
x, y = divmod(n, div)
div *= 10
if not y or y < div // 100: continue
if (x + y) ** 2 == n: yield x, y
然后迭代可能的方块:
def squares(n):
for i in range(n):
for sr in split_root(i ** 2):
yield "({}+{})^2 = {}".format(*sr, sum(sr)**2)
用法示例:
print("\n".join(squares(100000)))
输出:
(8+1)^2 = 81
(20+25)^2 = 2025
(30+25)^2 = 3025
(88+209)^2 = 88209
(494+209)^2 = 494209
(494+1729)^2 = 4941729
(744+1984)^2 = 7441984
(2450+2500)^2 = 24502500
(2550+2500)^2 = 25502500
(5288+1984)^2 = 52881984
(6048+1729)^2 = 60481729
(3008+14336)^2 = 300814336
(4938+17284)^2 = 493817284
(60494+17284)^2 = 6049417284
(68320+14336)^2 = 6832014336
为了比较,您的原始解决方案 -
def op_solver(n):
solutions = []
for i in range(1,n+1):
for j in range(1,n+1):
if str((i+j)**2)==str(i)+str(j):
solutions.append("("+str(i)+"+"+str(j)+")^2 = "+str((i+j)**2))
return solutions
>>> timeit("op_solver(1000)", setup="from __main__ import op_solver", number=5) / 5
0.8715057126013562
我的解决方案
>>> timeit("list(squares(2000))", setup="from __main__ import squares", number=100) / 100
0.006898956680088304
对于您的示例使用范围,大约需要125倍的加速,并且随着n
的增长,它会渐渐快速地运行。
这也有利于比numpy解决方案更快更简单,当然不需要numpy。如果你确实需要更快的版本,我相信你甚至可以对我的代码进行矢量化以获得两全其美的效果。
您可以通过避免字符串操作来加快计算速度。而不是连接字符串,使用i * 10**(int(math.log10(j))+1) + j
以数字方式“连接”:
In [457]: i, j = 20, 25; i * 10**(int(math.log10(j))+1) + j
Out[457]: 2025
您还可以使用NumPy来矢量化计算:
import numpy as np
n = 1000
def using_numpy(n):
i = range(1, n+1)
j = range(1, n+1)
I, J = np.meshgrid(i, j)
left = (I+J)**2
j_digits = np.log10(J).astype(int) + 1
right = I*10**j_digits + J
mask = left == right
solutions = ['({i}+{j})^2 = {k}'.format(i=i, j=j, k=k)
for i, j, k in zip(I[mask], J[mask], left[mask])]
return solutions
def using_str(n):
solutions=[]
for i in range(1,n+1):
for j in range(1,n+1):
if str((i+j)**2)==str(i)+str(j):
solutions.append("("+str(i)+"+"+str(j)+")^2 = "+str((i+j)**2))
return solutions
print('\n'.join(using_numpy(n)))
# print('\n'.join(using_str(n)))
产量
(8+1)^2 = 81
(20+25)^2 = 2025
(30+25)^2 = 3025
(88+209)^2 = 88209
(494+209)^2 = 494209
对于n = 1000
,using_numpy
比using_str
快约16倍:
In [455]: %timeit using_str(n)
500 ms ± 251 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [470]: %timeit using_numpy(n)
31.1 ms ± 98.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)