我尝试在Python 3.12中实现RSA算法。我首先实现了教科书上的RSA算法,并成功完成。 (我使用各种密钥和消息进行了多次重复尝试,验证了这一点,所有这些都成功加密和解密。)
但是,由于我实现了最佳非对称加密填充(OAEP)编码和解码,在 OAEP 解码过程中,lHash 偶尔不等于 lHash'。 (我使用的是 OAEP 上维基百科页面的术语:https://en.wikipedia.org/wiki/Optimal_ametry_encryption_padding)。快速浏览一下我的代码就会发现,许多 OAEP 相关功能都是“受到”以下 github 项目的“启发”或直接取自以下 github 项目:https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2。
大约一半的时间,线路
assert(lhash == lhash_prime)
会导致断言错误。当运行程序 100 次时,46% 的时间会出现断言错误。我有一些有效和无效的关键值示例。
我尝试分析 RSA 密钥生成产生的 n、e 和 d 值中的模式,这些模式会引发断言错误。我相信,鉴于 OAEP 使用 n 的长度作为过程的重要部分,n 值尤其会有所帮助。然而,对于如此大的值,像我这样的初学者程序员很难理解它们。
当我实际上没有使用教科书 RSA 和 OAEP 来加密和解密消息,而是仅使用 OAEP 来编码和解码消息时,该过程运行良好。此外,我仅使用教科书 RSA 所做的任何测试也有效。
下面是一个最小的、可重现的示例的代码。抱歉,但即使我已尽力减少它的长度(删除了检查以确保素数 p 和 q 实际上是安全素数等),如果要重现该错误,它仍然很长.
import random
from math import ceil
import hashlib
import os
from typing import Callable
def byte_len(n: int) -> int:
return ceil(n.bit_length() / 8)
def get_n_bit_rand_num(n: int) -> int:
return random.randrange(2**(n-1)+1,2**n-1)
def rabin_miller_composite_test(a: int, m: int, k: int, n: int) -> bool:
if (pow(a,m,n) == 1):
return False
for i in range(k):
if (pow(a,2**i*m,n) == n-1):
return False
return True
def probablistic_is_prime_test(n: int) -> bool:
first_primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
31, 37, 41, 43, 47, 53, 59, 61, 67,
71, 73, 79, 83, 89, 97, 101, 103,
107, 109, 113, 127, 131, 137, 139,
149, 151, 157, 163, 167, 173, 179,
181, 191, 193, 197, 199, 211, 223,
227, 229, 233, 239, 241, 251, 257,
263, 269, 271, 277, 281, 283, 293,
307, 311, 313, 317, 331, 337, 347, 349]
for divisor in first_primes_list:
if n % divisor == 0:
return False
k = 0
m = n-1
while (m % 2 == 0):
m >>= 1
k += 1
iterations = 20
for _ in range(iterations):
a = random.randrange(2,n-1)
if rabin_miller_composite_test(a,m,k,n):
return False
return True
def get_random_large_prime() -> int:
num_is_prime = False
num = 0
while(num_is_prime == False):
num = get_n_bit_rand_num(1024)
if probablistic_is_prime_test(num):
num_is_prime = True
return num
def euclidean_algorithm_GCD(larger_num: int, smaller_num: int) -> int:
if (smaller_num == 0):
return larger_num
else:
return euclidean_algorithm_GCD(smaller_num,larger_num % smaller_num)
def extended_euclidean_algorithm_second_num_of_linear_combination(larger_num: int, smaller_num: int) -> int:
s = 0
r = smaller_num
old_r = larger_num
old_s = 1
quotient = 0
temp = 0
while (r != 0):
quotient = old_r // r
temp = old_r
old_r = r
r = temp - quotient * r
temp = old_s
old_s = s
s = temp - quotient * s
second_num = (old_r - old_s * larger_num) // smaller_num
return second_num
def generate_keys() -> tuple[int, int, int]:
p = get_random_large_prime()
q = get_random_large_prime()
n = p*q
phi_of_n = (p-1) * (q-1)
e = 65537
while euclidean_algorithm_GCD(phi_of_n,e) != 1:
e += 1
d = extended_euclidean_algorithm_second_num_of_linear_combination(phi_of_n,e) % phi_of_n
return n, e, d
def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
int_message = int.from_bytes(message, 'little')
return pow(int_message,e,n)
def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
int_message = pow(encrypted_message, d, n)
return int_message.to_bytes(byte_len(int_message), 'little')
def encrypt_message_oaep(message: bytes, e: int, n: int) -> int:
n_byte_length = byte_len(n)
padded_message = oaep_encode(message,n_byte_length)
return textbook_encrypt_message(padded_message,e,n)
def decrypt_message_oaep(encrypted_message: int, d: int, n: int) -> str:
encoded_message = textbook_decrypt_message(encrypted_message, d, n)
encoded_message_as_bytes = encoded_message
n_byte_length = byte_len(n)
message = oaep_decode(encoded_message_as_bytes,n_byte_length)
return message.decode()
def bytewise_xor(data: bytes, mask: bytes) -> bytes:
masked = b""
for i in range(max(len(data),len(mask))):
if i < len(data) and i < len(mask):
masked += (data[i] ^ mask[i]).to_bytes(1, byteorder = 'big')
elif i < len(data):
masked += data[i].to_bytes(1, byteorder="big")
else:
break
return masked
def sha1(m: bytes) -> bytes:
'''SHA-1 hash function'''
hasher = hashlib.sha1()
hasher.update(m)
return hasher.digest()
def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes:
'''MGF1 mask generation function with SHA-1'''
t = b''
hlen = len(f_hash(b''))
for c in range(0, ceil(mlen / hlen)):
_c = c.to_bytes(4, byteorder="big")
t += f_hash(seed + _c)
return t[:mlen]
def oaep_encode(message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
lhash = hash_func(label)
padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
data_block = lhash + padding_string + b"\x01" + message
seed = os.urandom(len(lhash))
data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
masked_data_block = bytewise_xor(data_block,data_block_mask)
seed_mask = mgf(masked_data_block,len(lhash),hash_func)
masked_seed = bytewise_xor(seed,seed_mask)
return b"\x00" + masked_seed + masked_data_block
def oaep_decode(encoded_message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
lhash = hash_func(label)
masked_seed = encoded_message[1:1 + len(lhash)]
masked_data_block = encoded_message[1+len(lhash):]
seed_mask = mgf(masked_data_block,len(lhash),hash_func)
seed = bytewise_xor(masked_seed,seed_mask)
data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
data_block = bytewise_xor(masked_data_block, data_block_mask)
lhash_prime = data_block[:len(lhash)]
assert(lhash == lhash_prime)
i = len(lhash)
while i < len(data_block):
if data_block[i] == 0:
i += 1
continue
elif data_block[i] == 1:
i += 1
break
else:
raise Exception('This should never happen.')
return data_block[i:]
[n, e, d] = generate_keys()
print("n: ", n)
print("e: ", e)
print("d: ", d)
message = "Imagine that this is some secure test message"
oaep_encrypted_message = encrypt_message_oaep(message.encode(), e, n)
print(oaep_encrypted_message)
print(decrypt_message_oaep(oaep_encrypted_message, d, n))
因此,经过一番测试,可以确定程序无法正常运行的原因是OAEP的实现偶尔会导致填充消息,其整数表示实际上会大于n的值。由于 RSA 在加密过程中取 m^e 模 n,这显然会导致解密的填充消息与加密的填充消息不同。
通过向 oaep_encode 函数添加 while(True) 和 if 条件,可以轻松解决该问题:
def oaep_encode(message: bytes, n: int, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
lhash = hash_func(label)
padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
data_block = lhash + padding_string + b"\x01" + message
while(True):
seed = os.urandom(len(lhash))
data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
masked_data_block = bytewise_xor(data_block,data_block_mask)
seed_mask = mgf(masked_data_block,len(lhash),hash_func)
masked_seed = bytewise_xor(seed,seed_mask)
encoded_message = b"\x00" + masked_seed + masked_data_block
if (int.from_bytes(encoded_message,'little') < n):
return encoded_message
这可能不是解决该问题的最优雅的解决方案,但是,根据我遇到的各种来源描述的 OAEP 实现,它们似乎都没有确保填充消息的整数值不大于n 的值。 (https://en.wikipedia.org/wiki/Optimal_ametry_encryption_padding,https://www.cs.rit.edu/~spr/COURSES/CRYPTO/oaep.pdf,https://gist.github .com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2)