我最近开始实现后量子加密,需要实现多项式乘法,其中使用数论变换(ntt)来减少其计算时间,下面的代码是其中一种方法(不是最优的,但与朴素卷积实现),这是我在使用后量子 kyber 加密的数论变换(ntt)实现两个多项式乘法时使用的代码。我在实现以下代码时遇到此错误:
terminate called after throwing an instance of 'std::bad_alloc' what(): std::bad_alloc
#include <bits/stdc++.h>
using namespace std;
int MODULUS = 17;
int GEN = 13;
int pm(int base, int exp, int modulus) {
int result = 1;
base %= modulus;
while (exp > 0) {
if (exp % 2 == 1) {
result = (result * base) % modulus;
}
exp >>= 1;
base = (base * base) % modulus;
}
return result;
}
std::vector<int> cooley_tukey_ntt(const std::vector<int>& a, int gen = GEN, int modulus = MODULUS) {
if (a.size() == 1) {
return a;
}
std::vector<int> omegas(a.size());
omegas[0] = 1;
for (int i = 1; i < a.size(); ++i) {
omegas[i] = (omegas[i - 1] * gen) % modulus;
}
std::vector<int> even, odd;
for (int i = 0; i < a.size(); i += 2) {
even.push_back(a[i]);
odd.push_back(a[i + 1]);
}
std::vector<int> even_ntt = cooley_tukey_ntt(even, pm(gen, 2, modulus), modulus);
std::vector<int> odd_ntt = cooley_tukey_ntt(odd, pm(gen, 2, modulus), modulus);
std::vector<int> out(a.size());
for (int k = 0; k < a.size() / 2; ++k) {
int p = even_ntt[k];
int q = (omegas[k] * odd_ntt[k]) % modulus;
out[k] = (p + q) % modulus;
out[k + a.size() / 2] = (p - q + modulus) % modulus;
}
return out;
}
std::vector<int> cooley_tukey_intt(const std::vector<int>& a, int gen = GEN, int modulus = MODULUS) {
if (a.size() == 1) {
return a;
}
std::vector<int> omegas(a.size());
omegas[0] = 1;
for (int i = 1; i < a.size(); ++i) {
omegas[i] = (omegas[i - 1] * pm(gen, modulus - 2, modulus)) % modulus;
}
std::vector<int> even, odd;
for (int i = 0; i < a.size(); i += 2) {
even.push_back(a[i]);
odd.push_back(a[i + 1]);
}
std::vector<int> even_ntt = cooley_tukey_intt(even, pm(gen, 2, modulus), modulus);
std::vector<int> odd_ntt = cooley_tukey_intt(odd, pm(gen, 2, modulus), modulus);
std::vector<int> out(a.size());
int scaler = pm(a.size(), modulus - 2, modulus);
for (int k = 0; k < a.size() / 2; ++k) {
int p = even_ntt[k];
int q = (omegas[k] * odd_ntt[k]) % modulus;
out[k] = ((p + q)*scaler) % modulus;
out[k + a.size() / 2] = (((p - q + modulus))*scaler) % modulus;
}
return out;
}
vector<int> ntt_mul_nwc_attempt(vector<int>p, vector<int>q, int gen=GEN, int modulus=MODULUS){
int deg_d = p.size();
vector<int>pp=p;
vector<int>qq=q;
for(int i=0;i<deg_d;i++){pp.push_back(0);qq.push_back(0);}
vector<int>pp_ntt=cooley_tukey_ntt(pp);
vector<int>qq_ntt=cooley_tukey_ntt(qq);
vector<int>rr_ntt;
for(int i=0;i<pp.size();i++)
{
rr_ntt[i]=((pp_ntt[i]*qq_ntt[i])%modulus);
}
vector<int>rr=cooley_tukey_intt(rr_ntt);
for(int i=deg_d;i<rr.size();i++)
{
rr[i - deg_d] = (rr[i - deg_d] - rr[i]) % modulus;
rr[i] = 0;
}
rr.resize(deg_d);
return rr;
}
int main() {
vector<int>p = {1, 2, 3, 4};
vector<int>q = {1, 3, 5, 7};
vector<int>pq_nwc_attempt = ntt_mul_nwc_attempt(p, q);
for(auto i:pq_nwc_attempt)cout<<i<<" ";
return 0;
}
这绝对是错误的:
std::vector<int>rr_ntt;
for (int i = 0; i < pp.size(); i++)
{
rr_ntt[i] = ((pp_ntt[i] * qq_ntt[i]) % modulus);
}
rr_ntt
向量为空。在 i
循环中访问 rr_ntt
向量的索引 for
的任何尝试都是越界访问。
如果您使用
at()
而不是 []
来访问元素,则使用 这个示例 很容易看到错误,因为会抛出 std::out_of_range
异常。