这个问题源于我研究如何在不支持 `double` 计算(或 `double` 计算很慢)的 32 位处理器上高效实现 MRG32k3a PRNG。我对 ARM、RISC-V 和 GPU 特别感兴趣。MRG32k3a 是一种质量非常高的 PRNG,因此尽管它可以追溯到 20 世纪 90 年代末,至今仍被广泛使用:
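P. L'Ecuyer,“Good Parameters and Implementations for Combined Multiple Recursive Random Number Generators”(组合多重递归随机数生成器的良好参数与实现),*Operations Research*,第 47 卷第 1 期,1999 年 1–2 月,第 159–164 页。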
MRG32k3a 组合了两个形如 $(c_0 \cdot \mathit{state}_0 - c_1 \cdot \mathit{state}_1) \bmod m$ 的递归,其中 $\mathit{state}_i < m$。MRG32k3a 中的常量和状态变量均为不超过 32 位的正整数,计算中所有中间表达式的绝对值均小于 $2^{53}$。这是有意的设计,因为参考实现使用 IEEE-754 `double` 进行存储和计算。数学上的取模运算 $\bmod$ 与 ISO-C 的 `%` 运算符不同,它总是给出非负结果。下面代码中的第一个变体是使用 `double` 的、略加现代化的参考实现。
在 MRG32k3a 的纯整数实现中,状态分量使用 32 位变量,中间计算使用 64 位算术。只要保证被除数非负,就可以直接用 `%` 计算取模:$(c_0 \cdot \mathit{state}_0 - c_1 \cdot \mathit{state}_1) \bmod m = (c_0 \cdot \mathit{state}_0 - c_1 \cdot \mathit{state}_1 + c_1 \cdot m) \,\%\, m$。在 32 位处理器上计算 `% m` 开销很大,通常会变成一次库函数调用。这可以用标准的“除以常数”优化轻松解决,其中最昂贵的部分是计算 64 位乘积的高位部分(见下面代码中 `GENERIC_MOD=1` 的变体)。
当被除数的大小受限,且 $m = 2^n - d$(d 较小)时,还可以实现更快的取模计算。设 $lo = x \bmod 2^n$,$hi = \lfloor x / 2^n \rfloor$,$t = hi \cdot d + lo$(由于 $2^n \equiv d \pmod{m}$,有 $t \equiv x \pmod{m}$)。只要 $t < 2m$,就有 $x \bmod m = (t \ge m)\ ?\ (t - m) : t$。当 $x < 2^{n + (n - \lceil \log_2(d+1) \rceil)}$ 时该条件成立。这对 MRG32k3a 的第一个递归非常有效:其中 $n = 32$、$d = 209$,要求 $x < 2^{56}$,很容易满足。但对于第二个递归,$n = 32$、$d = 22853$,要求 $x < 2^{49}$。而为保证 $x$ 非负而加上偏移量 $c_1 \cdot m$ 之后,此时 $x$ 可以大到 $8.15 \cdot 10^{15}$,仅略小于 $2^{53} \approx 9 \cdot 10^{15}$。
我目前的解决办法是:在计算 $x \bmod m$ 之前,根据状态变量的值加上一个偏移量,以保证 $x$ 为正,同时保持 $x < 2^{44}$。但从下面摘录的相关代码行可以看出,这种方法开销相当大,其中包括一次 32 位除法(除数是常数,因此可以被优化,但仍会带来不希望的开销)。
prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
对于 MRG32k3a 的第二个递归,是否有其他开销更低的缓解策略,能让取模计算更高效?
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>
// Select which integer implementation of MRG32k3a_i() is compiled:
//   BUILTIN_64BIT != 0 : native 64-bit '%' operator
//   GENERIC_MOD   != 0 : multiply-by-reciprocal ("magic number") modulo
//   both zero          : special fast modulo for moduli of the form 2**32 - d
#define BUILTIN_64BIT (0)
#define GENERIC_MOD (0) // applies only when BUILTIN_64BIT == 0
static double MRG32k3a_s10, MRG32k3a_s11, MRG32k3a_s12;
static double MRG32k3a_s20, MRG32k3a_s21, MRG32k3a_s22;
/* SIMD vectorized by Clang with -ffp-model=precise on x86-84 and AArch64
SIMD vectorized by Intel compiler with -fp-model=precise -march=core-avx2
*/
double MRG32k3a (void)
{
const double norm = 2.328306549295728e-10;
const double m1 = 4294967087.0;
const double m2 = 4294944443.0;
const double a12 = 1403580.0;
const double a13n = 810728.0;
const double a21 = 527612.0;
const double a23n = 1370589.0;
double k, p1, p2;
/* Component 1 */
p1 = a12 * MRG32k3a_s11 - a13n * MRG32k3a_s10;
k = floor (p1 / m1);
p1 -= k * m1;
MRG32k3a_s10 = MRG32k3a_s11; MRG32k3a_s11 = MRG32k3a_s12; MRG32k3a_s12 = p1;
/* Component 2 */
p2 = a21 * MRG32k3a_s22 - a23n * MRG32k3a_s20;
k = floor (p2 / m2);
p2 -= k * m2;
MRG32k3a_s20 = MRG32k3a_s21; MRG32k3a_s21 = MRG32k3a_s22; MRG32k3a_s22 = p2;
/* Combination */
return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
/* State of the integer implementation MRG32k3a_i(): three 32-bit words for
   each of the two component recurrences (moduli m1 and m2). */
static uint32_t MRG32k3a_s10i, MRG32k3a_s11i, MRG32k3a_s12i;
static uint32_t MRG32k3a_s20i, MRG32k3a_s21i, MRG32k3a_s22i;
#if BUILTIN_64BIT
/* Integer MRG32k3a using the native 64-bit '%' operator. Each component adds
   a multiple of its modulus so the 64-bit dividend is non-negative before '%'
   is applied. On 32-bit targets the 64-bit '%' typically becomes a library
   call. */
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    uint64_t prod;
    uint32_t p1, p2;
    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    p1 = (uint32_t)(prod % m1);
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
    p2 = (uint32_t)(prod % m2);
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#elif GENERIC_MOD
/* Returns the upper 64 bits of the full 128-bit product a * b, assembled
   from four 32x32->64 bit partial products (no 128-bit type required). */
uint64_t umul64hi (uint64_t a, uint64_t b)
{
    uint32_t alo = (uint32_t)a;
    uint32_t ahi = (uint32_t)(a >> 32);
    uint32_t blo = (uint32_t)b;
    uint32_t bhi = (uint32_t)(b >> 32);
    uint64_t p0 = (uint64_t)alo * blo;
    uint64_t p1 = (uint64_t)alo * bhi;
    uint64_t p2 = (uint64_t)ahi * blo;
    uint64_t p3 = (uint64_t)ahi * bhi;
    // the middle sum cannot overflow 64 bits: < 2**32 + 2**32 + (2**32-1)**2
    return (p1 >> 32) + (((p0 >> 32) + (uint64_t)(uint32_t)p1 + p2) >> 32) + p3;
}

/* Integer MRG32k3a computing 'prod % m' via multiplication by a precomputed
   reciprocal ("magic number") and a shift instead of a division. */
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    const uint32_t neg_m1 = 0 - m1; // 209
    const uint32_t neg_m2 = 0 - m2; // 22853
    const uint64_t magic_mul_m1 = 0x8000006880005551ull;
    const uint64_t magic_mul_m2 = 0x4000165147c845ddull;
    const uint32_t shft_m1 = 31;
    const uint32_t shft_m2 = 30;
    uint64_t prod;
    uint32_t p1, p2;
    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    // q = prod / m1; remainder prod - q * m1 is formed modulo 2**32 as
    // q * (2**32 - m1) + prod
    p1 = (uint32_t)((umul64hi (prod, magic_mul_m1) >> shft_m1) * neg_m1 + prod);
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
    p2 = (uint32_t)((umul64hi (prod, magic_mul_m2) >> shft_m2) * neg_m2 + prod);
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#else // !BUILTIN_64BIT && !GENERIC_MOD --> special fast modulo computation
/* Integer MRG32k3a exploiting m = 2**32 - d with small d. Writing
   x = hi * 2**32 + lo, we have x == hi * d + lo (mod m) because
   2**32 == d (mod m). Each such "fold" shrinks x; once the folded value t
   is below 2 * m, a single conditional subtraction of m completes the
   reduction.

   Component 1 (d = 209): x < 2**54, a single fold suffices (needs x < 2**56).
   Component 2 (d = 22853): a fixed offset m2 * 2**22 exceeds
   a23n * (2**32 - 1), so x is non-negative for any 32-bit state without a
   state-dependent (division-based) offset. Then x < 2**55, so two folds are
   used: the first in 64-bit arithmetic (result < 2**37), the second in
   32-bit arithmetic (hi <= 26, so t < 2 * m2). */
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    const uint32_t neg_m1 = 0 - m1; // 209   == 2**32 mod m1
    const uint32_t neg_m2 = 0 - m2; // 22853 == 2**32 mod m2
    uint64_t prod, fold;
    uint32_t p1, p2, prod_lo, prod_hi;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    // ! special modulo computation: prod must be < 2**56 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p1 = prod_hi * neg_m1 + prod_lo;
    // t = prod_hi * 209 + prod_lo lies in [0, 2 * m1). If t >= m1, subtract m1.
    // If t overflowed 32 bits (p1 < prod_lo), adding neg_m1 == subtracting m1.
    if ((p1 >= m1) || (p1 < prod_lo)) p1 += neg_m1;
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;

    /* Component 2 */
    // fixed offset m2 * 2**22 > a23n * (2**32 - 1): never negative, < 2**55
    prod = ((uint64_t)a21) * MRG32k3a_s22i + (((uint64_t)m2) << 22)
         - ((uint64_t)a23n) * MRG32k3a_s20i;
    // first fold in 64-bit arithmetic: prod_hi < 2**23, so fold < 2**37
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    fold = ((uint64_t)prod_hi) * neg_m2 + prod_lo;
    // second fold in 32-bit arithmetic: prod_hi <= 26, so t < 2 * m2
    prod_lo = (uint32_t)fold;
    prod_hi = (uint32_t)(fold >> 32);
    p2 = prod_hi * neg_m2 + prod_lo;
    // same final correction as component 1 (handles t >= m2 and 32-bit wrap)
    if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;

    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#endif // BUILTIN_64BIT
/*
http://www.burtleburtle.net/bob/hash/doobs.html
By Bob Jenkins, 1996. bob_jenkins@burtleburtle.net. You may use this
code any way you wish, private, educational, or commercial. It's free.
*/
/* mix(a,b,c): Bob Jenkins' three-word mixing step, written as a single
   comma expression so it can be used inside a larger expression (main()
   writes 'mix (a, b, c) % m1'). It modifies a, b and c in place and
   evaluates to the final value of c. Every argument is read and assigned
   many times, so pass plain lvalues without side effects. The step order
   and shift amounts are significant; presumably intended for 32-bit unsigned
   operands, as used in main(). */
#define mix(a,b,c) \
(a -= b, a -= c, a ^= (c>>13), \
b -= c, b -= a, b ^= (a<<8), \
c -= a, c -= b, c ^= (b>>13), \
a -= b, a -= c, a ^= (c>>12), \
b -= c, b -= a, b ^= (a<<16), \
c -= a, c -= b, c ^= (b>>5), \
a -= b, a -= c, a ^= (c>>3), \
b -= c, b -= a, b ^= (a<<10), \
c -= a, c -= b, c ^= (b>>15))
/* Test driver: seeds the double reference MRG32k3a() and the integer
   MRG32k3a_i() with identical state, then compares their outputs call by
   call and reports the first mismatch.
   The reference output is never exactly 0: the combination step maps
   p1 <= p2 to p1 - p2 + m1 >= m1 - m2 + 1 > 0. So the loop in practice runs
   until a mismatch is found or the program is stopped.
   Returns EXIT_FAILURE on the first mismatch. */
int main (void)
{
    uint32_t m1 = 4294967087u;
    uint32_t m2 = 4294944443u;
    uint32_t a, b, c;
    a = 3141592654u;
    b = 2718281828u;
    // OR-ing in a distinct bit guarantees every seed word is non-zero
    c = 10; MRG32k3a_s10 = MRG32k3a_s10i = (1u << 10) | (mix (a, b, c) % m1);
    c = 11; MRG32k3a_s11 = MRG32k3a_s11i = (1u << 11) | (mix (a, b, c) % m1);
    c = 12; MRG32k3a_s12 = MRG32k3a_s12i = (1u << 12) | (mix (a, b, c) % m1);
    c = 20; MRG32k3a_s20 = MRG32k3a_s20i = (1u << 20) | (mix (a, b, c) % m2);
    c = 21; MRG32k3a_s21 = MRG32k3a_s21i = (1u << 21) | (mix (a, b, c) % m2);
    c = 22; MRG32k3a_s22 = MRG32k3a_s22i = (1u << 22) | (mix (a, b, c) % m2);
    double res, ref;
    uint64_t count = 0;
    do {
        res = MRG32k3a_i();
        ref = MRG32k3a();
        if (res != ref) {
            // arguments now match their labels; %llu needs unsigned long long
            printf("\ncount=%llu ref=%23.16e res=%23.16e\n",
                   (unsigned long long)count, ref, res);
            return EXIT_FAILURE;
        }
        count++;
        if ((count & 0xfffffff) == 0) {
            printf ("\rcount = %llu ", (unsigned long long)count);
            fflush (stdout); // '\r' does not flush a line-buffered stream
        }
    } while (ref != 0);
    return EXIT_SUCCESS;
}
我还没有对此进行测试或基准测试,但如果我正确理解了你的做法,我认为另一种选择是加上 m2 的一个固定倍数,然后进行两轮约简。
prod = ((uint64_t)a21) * MRG32k3a_s22i + ((uint64_t)m2 << 22) - ((uint64_t)a23n) * MRG32k3a_s20i; // 54 bits
然后你可以省略下面两行。
adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
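然后使用 64 位中间变量 p2_64 执行两轮约简。注意:第二轮之后仍需像原代码一样做一次条件修正 `if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;`,因为第二轮的结果可能 ≥ m2,或超出 32 位而回绕。另外,prod 最大约为 $2.03 \cdot 10^{16}$,实际可达 55 位,prod_hi 可达 23 位;这不影响正确性,因为第一轮之后的结果仍小于 $2^{37}$: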
prod_lo = (uint32_t)prod;
prod_hi = (uint32_t)(prod >> 32); // 22 bits
p2_64 = (uint64_t)prod_hi * neg_m2 + prod_lo; // 38 bits
prod_lo = (uint32_t)p2_64;
prod_hi = (uint32_t)(p2_64 >> 32);
p2 = prod_hi * neg_m2 + prod_lo;