我正在尝试使用 snrm2 在 Rust 中执行单精度浮点计算。我链接到 OSX 上的 Accelerate 框架,并使用 blas crate 作为 C 桥。无论随机生成的向量值如何(在本机代码中通常计算为 [12, 13]),snrm2 逻辑始终返回:
当我将所有内容切换到 f64 和 dnrm2 时,相同的代码可以正常工作。这是 Accelerate 中的错误,还是我没有满足其有关内存对齐或调用参数的一些内部假设?
use rand::prelude::*;
use blas::*;
type Vector = [f32; 2048];
// Native implementation
fn euclidean_distance_standard(a: &Vector, b: &Vector) -> f32 {
a.iter().zip(b.iter()).fold(0.0, |acc, (&x, &y)| {
let diff = x - y;
acc + diff * diff
}).sqrt()
}
// Using snrm2 from BLAS
fn euclidean_distance_blas(a: &Vector, b: &Vector) -> f32 {
let mut diff = [0.0; 2048];
for i in 0..2048 {
diff[i] = a[i] - b[i];
}
unsafe {
snrm2(diff.len() as i32, &diff, 1)
}
}
fn main() {
let mut rng = rand::thread_rng();
// Initialize a random vector
let mut vector_a = [0.0; 2048];
for i in 0..2048 {
vector_a[i] = rng.gen::<f32>();
}
let vector_b = [0.5; 2048];
let result_standard = euclidean_distance_standard(&vector_a, &vector_b);
let result_blas = euclidean_distance_blas(&vector_a, &vector_b);
println!("Standard Method Result: {}", result_standard);
println!("OpenBLAS Method Result: {}", result_blas);
}
# build.rs
fn main() {
println!("cargo:rustc-link-lib=framework=Accelerate");
}
# Cargo.toml
[dependencies]
rand = "0.8"
blas = "0.21"
blas-src = { version = "0.9", features = ["accelerate"] }
libc = "0.2.36"
$ cargo run --release
Finished release [optimized] target(s) in 0.01s
Running `target/release/similarity-comp`
Standard Method Result: 13.005363
OpenBLAS Method Result: -36893490000000000000
这是一个错误,但可能在包装器中,我尝试了几种不同的方法来计算它。
第一个假设是错误位于
sgrm2
,然后我尝试使用 sdot 计算平方范数,并观察到相同的行为。
然后我考虑了另一个例程,返回
f32
,1-范数,sasum,那里也观察到了错误的结果。
然后我尝试了一个函数,它接收
f32
向量并返回一个整数 isamax,它工作得很好。
使用功能
accelerate
或 openblas
都会发生这种情况。
所以我的假设是包装器在返回时出现问题
f32
。接下来我要做的是选择一个 2 级函数 sgemv 来计算 y = alpha * A * x + beta * y
,设置 alpha=1.0
、beta=0.0
和 A = x'
,我们得到 y = x'*x
,即范数的平方。不同之处在于,这将被写入可变的 [f32]
参数,而不是由函数返回。在这种情况下,我们得到了预期的结果。
这是您可以在自己身边测试的完整程序
extern crate blas;
use blas::*;
fn main() {
let v1: [f32; 6] = [2.0, 0.0, 5.0, 4.0, 2.0, 0.0];
let mut y1: [f32; 1] = [0.0];
let v1d: [f64; 6] = [2.0, 0.0, 5.0, 4.0, -2.0, 0.0];
let mut y1d: [f64; 1] = [0.0];
let n = v1.len() as i32;
unsafe {
sgemv('N' as u8, 1, n, 1.0, &v1, 1, &v1, 1, 0.0, &mut y1, 1);
};
println!("snrm2={:.3}", unsafe { snrm2(v1.len() as i32, &v1, 1) });
println!("|x|_1 via sasum={:.3}", unsafe { sasum(v1.len() as i32, &v1, 1) });
println!("|x|^2 via sdot={:.3}", unsafe { sdot(v1.len() as i32, &v1, 1, &v1, 1) });
println!("argmax isamax={}", unsafe { isamax(v1.len() as i32, &v1, 1) });
println!("|x|^2 via sgemv={:.3}", y1[0]);
unsafe {
dgemv('N' as u8, 1, n, 1.0, &v1d, 1, &v1d, 1, 0.0, &mut y1d, 1);
};
println!("dnrm2={:.3}", unsafe { dnrm2(v1d.len() as i32, &v1d, 1) });
println!("|x|_1 via dasum={:.3}", unsafe { dasum(v1d.len() as i32, &v1d, 1) });
println!("|x|^2 via ddot={:.3}", unsafe { ddot(v1d.len() as i32, &v1d, 1, &v1d, 1) });
println!("argmax via idamax={}", unsafe { idamax(v1d.len() as i32, &v1d, 1) });
println!("|x|^2 via dgemv={:.3}", y1d[0]);
}
输入向量具有 2-范数 7(平方范数 49)和 1-范数 13。
对于所有 fp64 例程,程序输出都是正确的,对于 fp32 例程,只有 isamax 和 sgemv 的行为符合预期。
snrm2=0.000
|x|_1 via sasum=0.000
|x|^2 via sdot=0.000
argmax isamax=3
|x|^2 via sgemv=49.000
dnrm2=7.000
|x|_1 via dasum=13.000
|x|^2 via ddot=49.000
argmax via idamax=3
|x|^2 via dgemv=49.000