更有效的方法来强制找到解决方案(x + y)^ 2 = str(x)+ str(y)?它可以被矢量化吗?

问题描述 投票:0回答:2

到目前为止我写了:

n=1000
solutions=[]
for i in range(1,n+1):
    for j in range(1,n+1):
        if str((i+j)**2)==str(i)+str(j):
            solutions.append("("+str(i)+"+"+str(j)+")^2 = "+str((i+j)**2))
for solution in solutions:
    print(solution)

这在我的电脑上需要1.03秒。有没有更快的方法来实施比较?我在vectorisation上找到了一个页面,但我不确定如何生成列表,然后我需要进行矢量化比较。

python-3.x optimization comparison string-comparison
2个回答
1
投票

通过搜索满足目标范围内给定方格的方程的(x, y)对,可以更快地完成此操作。事实上,这减少了从O(n^2)O(nlogn)时间复杂度的问题。

def split_root(n):
    div = 10
    while div < n:
        x, y = divmod(n, div)
        div *= 10
        if not y or y < div // 100: continue
        if (x + y) ** 2 == n: yield x, y

然后迭代可能的方块:

def squares(n):
    for i in range(n):
        for sr in split_root(i ** 2):
            yield "({}+{})^2 = {}".format(*sr, sum(sr)**2)

用法示例:

print("\n".join(squares(100000)))

输出:

(8+1)^2 = 81
(20+25)^2 = 2025
(30+25)^2 = 3025
(88+209)^2 = 88209
(494+209)^2 = 494209
(494+1729)^2 = 4941729
(744+1984)^2 = 7441984
(2450+2500)^2 = 24502500
(2550+2500)^2 = 25502500
(5288+1984)^2 = 52881984
(6048+1729)^2 = 60481729
(3008+14336)^2 = 300814336
(4938+17284)^2 = 493817284
(60494+17284)^2 = 6049417284
(68320+14336)^2 = 6832014336

为了比较,您的原始解决方案 -

def op_solver(n):
    solutions = []
    for i in range(1,n+1):
        for j in range(1,n+1):
            if str((i+j)**2)==str(i)+str(j):
                solutions.append("("+str(i)+"+"+str(j)+")^2 = "+str((i+j)**2))
    return solutions


>>> timeit("op_solver(1000)", setup="from __main__ import op_solver", number=5) / 5
0.8715057126013562

我的解决方案

>>> timeit("list(squares(2000))", setup="from __main__ import squares", number=100) / 100
0.006898956680088304

对于您的示例使用范围,大约需要125倍的加速,并且随着n的增长,它会渐渐快速地运行。

这也有利于比numpy解决方案更快更简单,当然不需要numpy。如果你确实需要更快的版本,我相信你甚至可以对我的代码进行矢量化以获得两全其美的效果。


1
投票

您可以通过避免字符串操作来加快计算速度。而不是连接字符串,使用i * 10**(int(math.log10(j))+1) + j以数字方式“连接”:

In [457]: i, j = 20, 25; i * 10**(int(math.log10(j))+1) + j
Out[457]: 2025

您还可以使用NumPy来矢量化计算:

import numpy as np
n = 1000

def using_numpy(n):
    i = range(1, n+1)
    j = range(1, n+1)
    I, J = np.meshgrid(i, j)

    left = (I+J)**2
    j_digits = np.log10(J).astype(int) + 1
    right = I*10**j_digits + J
    mask = left == right
    solutions = ['({i}+{j})^2 = {k}'.format(i=i, j=j, k=k)
                 for i, j, k in zip(I[mask], J[mask], left[mask])]
    return solutions

def using_str(n):
    solutions=[]
    for i in range(1,n+1):
        for j in range(1,n+1):
            if str((i+j)**2)==str(i)+str(j):
                solutions.append("("+str(i)+"+"+str(j)+")^2 = "+str((i+j)**2))
    return solutions

print('\n'.join(using_numpy(n)))
# print('\n'.join(using_str(n)))

产量

(8+1)^2 = 81
(20+25)^2 = 2025
(30+25)^2 = 3025
(88+209)^2 = 88209
(494+209)^2 = 494209

对于n = 1000using_numpyusing_str快约16倍:

In [455]: %timeit using_str(n)
500 ms ± 251 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [470]: %timeit using_numpy(n)
31.1 ms ± 98.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
最新问题
© www.soinside.com 2019 - 2024. All rights reserved.