在 Rust 中查找 comptime 中的 SIMD 行数

问题描述 投票:0回答:1

我为点积编写了以下 simd 友好的代码:

pub fn scalar_product_simd<T>(a: &[T], b: &[T]) -> T
where
    T: Mul<Output = T> + Sum + Copy + Add<Output = T>,
{
    const CHUNK_SIZE: usize = 4;
    assert!(a.len() >= CHUNK_SIZE && b.len() >= CHUNK_SIZE && a.len() == b.len());

    let mut i = 0;
    let mut acc = a
        .chunks_exact(CHUNK_SIZE)
        .zip(b.chunks_exact(CHUNK_SIZE))
        .map(|(aa, bb)| {
            i += CHUNK_SIZE;
            aa.iter().zip(bb).map(|(&x, &y)| x * y).sum()
        })
        .sum();

    // handle remaining elements
    acc = acc + scalar_product(&a[i..], &b[i..]);
    acc
}

如何根据编译参数和T类型的大小来设置CHUNK_SIZE值。例如,对于 T:f64 和 AVX2 CHUNK_SIZE 将等于 4 ? 也许有一种正确的方法可以使用 std::simd 重写此代码?

rust simd
1个回答
0
投票

我认为我不完全理解你的问题,但是如果你想根据

CHUNK_SIZE
的类型拥有不同的常量
T
值,并且你想要支持的
T
数量有限,你可以创建一个辅助特征。

trait SimdHelper {
    const CHUNK_SIZE: usize;
}

impl SimdHelper for f64 {
    const CHUNK_SIZE: usize = 4;
}

pub fn scalar_product_simd<T>(a: &[T], b: &[T]) -> T
where
    T: Mul<Output = T> + Sum + Copy + Add<Output = T>,
    T: SimdHelper,
{
    const CHUNK_SIZE: usize = <T as SimdHelper>::CHUNK_SIZE;

    ...
}
© www.soinside.com 2019 - 2024. All rights reserved.