RSA implementation with OAEP occasionally fails with lHash != lHash'


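I'm trying to implement the RSA algorithm in Python 3.12. I first implemented textbook RSA and got it working. (I verified this with repeated runs using various keys and messages, all of which encrypted and decrypted successfully.)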

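However, since I implemented Optimal Asymmetric Encryption Padding (OAEP) encoding and decoding, lHash is occasionally not equal to lHash' during OAEP decoding. (I'm using the terminology from the Wikipedia page on OAEP: https://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding.) A quick look at my code will show that many of the OAEP-related functions were "inspired by" or taken directly from the following GitHub gist: https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2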

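About half the time, the line

assert(lhash == lhash_prime)
raises an AssertionError. Over 100 runs of the program, the assertion failed in 46% of them. I have some examples of key values that work and some that don't.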

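I tried to find patterns in the n, e, and d values produced by RSA key generation that trigger the assertion error. I believe the n values would be especially helpful, given that OAEP uses the length of n as an important part of the process. However, with values this large, it's hard for a beginner programmer like me to make sense of them.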

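When I skip textbook RSA and use OAEP alone to encode and decode the message, the process works fine. Likewise, every test I ran with textbook RSA alone also works. The failure only appears when the two are combined.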

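Below is the code for a minimal reproducible example. Sorry, but even though I did my best to shorten it (removing the checks that the primes p and q are actually safe primes, etc.), it is still quite long if it is to reproduce the error.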

import random
from math import ceil
import hashlib
import os
from typing import Callable

def byte_len(n: int) -> int:
    return ceil(n.bit_length() / 8)

def get_n_bit_rand_num(n: int) -> int:
    return random.randrange(2**(n-1)+1,2**n-1)

def rabin_miller_composite_test(a: int, m: int, k: int, n: int) -> bool:
    if (pow(a,m,n) == 1): 
        return False 
    for i in range(k):
        if (pow(a,2**i*m,n) == n-1):
            return False
    return True

def probablistic_is_prime_test(n: int) -> bool:
    first_primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
                     31, 37, 41, 43, 47, 53, 59, 61, 67,
                     71, 73, 79, 83, 89, 97, 101, 103,
                     107, 109, 113, 127, 131, 137, 139,
                     149, 151, 157, 163, 167, 173, 179,
                     181, 191, 193, 197, 199, 211, 223,
                     227, 229, 233, 239, 241, 251, 257,
                     263, 269, 271, 277, 281, 283, 293,
                     307, 311, 313, 317, 331, 337, 347, 349]
    for divisor in first_primes_list:
        if n % divisor == 0:
            return False
    k = 0
    m = n-1
    while (m % 2 == 0):
        m >>= 1 
        k += 1

    iterations = 20 
    for _ in range(iterations):
        a = random.randrange(2,n-1) 
        if rabin_miller_composite_test(a,m,k,n):
            return False
    return True

def get_random_large_prime() -> int:
    num_is_prime = False
    num = 0
    while(num_is_prime == False):
        num = get_n_bit_rand_num(1024)
        if probablistic_is_prime_test(num):
            num_is_prime = True
    return num

def euclidean_algorithm_GCD(larger_num: int, smaller_num: int) -> int:
    if (smaller_num == 0):
        return larger_num
    else:
        return euclidean_algorithm_GCD(smaller_num,larger_num % smaller_num)

def extended_euclidean_algorithm_second_num_of_linear_combination(larger_num: int, smaller_num: int) -> int:
    s = 0
    r = smaller_num
    old_r = larger_num
    old_s = 1
    quotient = 0
    temp = 0
    while (r != 0):
        quotient = old_r // r
        temp = old_r
        old_r = r
        r = temp - quotient * r
        temp = old_s
        old_s = s
        s = temp - quotient * s
    second_num = (old_r - old_s * larger_num) // smaller_num
    return second_num

def generate_keys() -> tuple[int, int, int]:
    p = get_random_large_prime()
    q = get_random_large_prime()
    n = p*q
    phi_of_n = (p-1) * (q-1)
    e = 65537
    while euclidean_algorithm_GCD(phi_of_n,e) != 1:
        e += 1
    d = extended_euclidean_algorithm_second_num_of_linear_combination(phi_of_n,e) % phi_of_n
    return n, e, d

def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
    int_message = int.from_bytes(message, 'little')
    return pow(int_message,e,n)

def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
    int_message = pow(encrypted_message, d, n)
    return int_message.to_bytes(byte_len(int_message), 'little')

def encrypt_message_oaep(message: bytes, e: int, n: int) -> int:
    n_byte_length = byte_len(n)
    padded_message = oaep_encode(message,n_byte_length)
    return textbook_encrypt_message(padded_message,e,n)

def decrypt_message_oaep(encrypted_message: int, d: int, n: int) -> str:
    encoded_message = textbook_decrypt_message(encrypted_message, d, n)
    encoded_message_as_bytes = encoded_message
    n_byte_length = byte_len(n)
    message = oaep_decode(encoded_message_as_bytes,n_byte_length)
    return message.decode()

def bytewise_xor(data: bytes, mask: bytes) -> bytes: 
    masked = b""
    for i in range(max(len(data),len(mask))):
        if i < len(data) and i < len(mask):
            masked += (data[i] ^ mask[i]).to_bytes(1, byteorder = 'big')
        elif i < len(data):
            masked += data[i].to_bytes(1, byteorder="big")
        else:
            break
    return masked

def sha1(m: bytes) -> bytes:
    '''SHA-1 hash function'''
    hasher = hashlib.sha1()
    hasher.update(m)
    return hasher.digest()

def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes: 
    '''MGF1 mask generation function with SHA-1'''
    t = b''
    hlen = len(f_hash(b''))
    for c in range(0, ceil(mlen / hlen)):
        _c = c.to_bytes(4, byteorder="big")
        t += f_hash(seed + _c)
    return t[:mlen]

def oaep_encode(message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes: 
    lhash = hash_func(label)
    padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
    data_block = lhash + padding_string + b"\x01" + message
    seed = os.urandom(len(lhash))
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    masked_data_block = bytewise_xor(data_block,data_block_mask)
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    masked_seed = bytewise_xor(seed,seed_mask)
    return b"\x00" + masked_seed + masked_data_block

def oaep_decode(encoded_message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
    lhash = hash_func(label)
    masked_seed = encoded_message[1:1 + len(lhash)]
    masked_data_block = encoded_message[1+len(lhash):]
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    seed = bytewise_xor(masked_seed,seed_mask)
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    data_block = bytewise_xor(masked_data_block, data_block_mask)
    lhash_prime = data_block[:len(lhash)]
    assert(lhash == lhash_prime)
    i = len(lhash)
    while i < len(data_block):
        if data_block[i] == 0:
            i += 1
            continue
        elif data_block[i] == 1:
            i += 1
            break
        else:
            raise Exception('This should never happen.')
    return data_block[i:]

[n, e, d] = generate_keys()
print("n: ", n)
print("e: ", e)
print("d: ", d)
message = "Imagine that this is some secure test message"
oaep_encrypted_message = encrypt_message_oaep(message.encode(), e, n)
print(oaep_encrypted_message)
print(decrypt_message_oaep(oaep_encrypted_message, d, n))
1 Answer

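After some testing, I determined that the program fails because the OAEP implementation occasionally produces a padded message whose integer representation is larger than n. Since RSA computes m^e mod n during encryption, decryption can only recover m mod n, so the decrypted padded message differs from the one that was encrypted.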
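The wrap-around can be seen with toy numbers. The following is only an illustrative sketch: the primes, exponent, and the `roundtrip` helper are made up for the demonstration and are far too small for real use.

```python
# Toy textbook-RSA parameters (illustration only, not secure).
p, q = 61, 53
n = p * q                    # 3233
phi = (p - 1) * (q - 1)      # 3120
e = 17
d = pow(e, -1, phi)          # modular inverse of e (Python 3.8+)

def roundtrip(m: int) -> int:
    """Encrypt and then decrypt m with textbook RSA."""
    return pow(pow(m, e, n), d, n)

small = 1234                 # below n: recovered exactly
large = n + 1234             # not below n: only m mod n can be recovered

assert roundtrip(small) == small
assert roundtrip(large) == large % n   # 1234, not n + 1234
```

In the OAEP case the "lost" part is not a clean offset like this; the recovered bytes are simply a different block, so the unmasking step produces garbage and the lHash comparison fails.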

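The problem is easily fixed by adding a while loop and an if condition to oaep_encode, which now also takes n as a parameter (so the call in encrypt_message_oaep has to pass it):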

def oaep_encode(message: bytes, n: int, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes: 
    lhash = hash_func(label)
    padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
    data_block = lhash + padding_string + b"\x01" + message
    # Retry with a fresh random seed until the encoded block, read as an
    # integer the same way textbook_encrypt_message reads it, is below n.
    while True:
        seed = os.urandom(len(lhash))
        data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
        masked_data_block = bytewise_xor(data_block,data_block_mask)
        seed_mask = mgf(masked_data_block,len(lhash),hash_func)
        masked_seed = bytewise_xor(seed,seed_mask)
        encoded_message = b"\x00" + masked_seed + masked_data_block
        if int.from_bytes(encoded_message,'little') < n:
            return encoded_message

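This may not be the most elegant solution to the problem, but none of the descriptions of OAEP that I came across seem to ensure that the integer value of the padded message is not greater than n. Sources I looked at:

https://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding
https://www.cs.rit.edu/~spr/COURSES/CRYPTO/oaep.pdf
https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2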
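One possible explanation, offered as a sketch and not as part of the tested fix above: RFC 8017 (PKCS #1) converts between octet strings and integers most significant byte first, so the leading 0x00 of the encoded block is the top byte and the integer is automatically smaller than any k-byte modulus. The code in the question uses 'little', which makes that 0x00 the least significant byte. The helper names below follow the RFC's OS2IP and I2OSP primitives; the modulus is a stand-in value, not a real RSA key.

```python
import os

def os2ip(octets: bytes) -> int:
    """Octet string to integer, most significant byte first."""
    return int.from_bytes(octets, "big")

def i2osp(value: int, length: int) -> bytes:
    """Integer to exactly `length` bytes, keeping leading zero bytes."""
    return value.to_bytes(length, "big")

k = 256                          # byte length of a 2048-bit modulus
n = (1 << (8 * k - 1)) | 1       # hypothetical stand-in for n, not a real RSA modulus

# Stand-in for an OAEP-encoded block: 0x00 followed by k-1 random bytes.
em = b"\x00" + os.urandom(k - 1)

m = os2ip(em)
assert m < n                     # always holds: m < 2**(8*(k-1)) <= n
assert i2osp(m, k) == em         # fixed length restores the leading 0x00

# With little-endian the 0x00 carries no weight, so this depends on the random bytes.
little = int.from_bytes(em, "little")
print(little < n)
```

Converting back with a fixed length of k bytes also avoids a second, rarer problem: sizing the output from the decrypted integer itself silently drops zero bytes at the most significant end.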
