Rust LAPACK 与 Numpy 的 QR 因式分解不一致

问题描述 投票:0回答:0

我有一个使用 LAPACK 计算 QR 分解的 Rust 片段,但我的代码和 Numpy 之间得到不同的答案。但是当我查看文档并比较两个实现时,我不清楚我做错了什么。

Numpy 代码:

import numpy as np
A = np.arange(1,25).reshape((6,4))
np.linalg.qr(A, mode="raw")[0]

防锈代码:

        let m = 6 as i32;
        let n = 4 as i32;
        let mut data = vec![1.0, 5.0,  9.0, 13.0, 17.0, 21.0,
                            2.0, 6.0, 10.0, 14.0, 18.0, 22.0,
                            3.0, 7.0, 11.0, 15.0, 19.0, 23.0,
                            4.0, 8.0, 12.0, 16.0, 20.0, 24.0];
        let buf = data.as_mut_slice();

        let tau_size = min(m, n) as usize;
        let mut tau = vec![0.0; tau_size];

        let lwork: i32 = -1;
        let mut info: i32 = 0;
        let mut work = vec![0.0; 1];

        // Get the worksize
        unsafe {
            dgeqrf(m, n, buf, m, tau.as_mut_slice(), work.as_mut_slice(), lwork, &mut info);
        }
        if info != 0 {
            return Err(Error::Arrow("failed to get work size for matrix".to_string()));
        }

        let lwork = work[0] as i32;
        let mut work = vec![0.0; lwork as usize];

        // Compute QR
        unsafe {
            dgeqrf(m, n, buf, m, tau.as_mut_slice(), work.as_mut_slice(), lwork, &mut info);
        }
        if info != 0 {
            return Err(Error::Arrow("failed to compute QR (dgeqrf) for matrix".to_string()));
        }

        println!("Raw {:?}", data);

右下 [2x4] 矩形的输出矩阵不同

Numpy 输出:

array([[-31.7, 0.153, 0.275, 0.397, 0.52, 0.642],
       [-33.8, -1.29, 0.0837, -0.123, -0.33, -0.537],
       [-35.9, -2.58, 1.29e-15, -0.2, 0.263, 0.397],
       [-38, -3.88, 2.87e-15, 5.22e-16, -0.106, 0.705]])

Rust 输出:

array([[-31.7, 0.153, 0.275, 0.397, 0.52, 0.642],
       [-33.8, -1.29, 0.0837, -0.123, -0.33, -0.537],
       [-35.9, -2.58, 1.14e-15, -0.631, -0.172, 0.484],
       [-38, -3.88, 2.22e-15, 1.55e-16, 0.541, 0.327]])

感觉我只是错过了一些明显的东西。感谢您的指点!

我查看了 numpy 代码、netlib 上的 LAPACK 文档等

numpy rust scipy linear-algebra lapack
© www.soinside.com 2019 - 2024. All rights reserved.