在 Rust 中线程完成后从闭包中返回所有权

问题描述 投票:0回答:1

我正在学习 Rust,现在正在研究线程。我正在尝试将以下代码片段转换为并行执行:

fn main() {
    let mut data: Vec<usize> = (1..22).collect();

    for d in &mut data {
        *d *= 99;
    }
}

由于每个元素的处理完全独立于其他元素,因此应该很容易并行化。所以我编写了以下代码,将向量分成

N_THREADS
部分,然后将每个部分交给线程进行处理:

use std::thread;

const N_THREADS: usize = 8;
const N_ELEMENTS: usize = 2*N_THREADS;

fn main() {
    let mut data: Vec<usize> = (1..N_ELEMENTS).collect();

    let mut threads = vec![];
    for n_thread in 0..N_THREADS {
        let thread_fn = move || {
            for d in &mut data[n_thread*(N_ELEMENTS/N_THREADS)..(n_thread+1)*(N_ELEMENTS/N_THREADS)] {
                *d *= 99
            }
        };
        let thread_handle = thread::spawn(thread_fn);
        threads.push(thread_handle);
    }
    // Now wait for the threads to finish.
    for t in threads {
        t.join().unwrap();
    } // I need each `t` to return the ownership of its respective piece of `data` here!
    
    do_something_else(data);
}

如何在所有线程完成后重新获得

data
的所有权,这样我就可以
do_something_else(data)

multithreading rust ownership
1个回答
1
投票

暂且不谈,只使用 Rayon,它可以为你做到这一点。

所以我编写了以下代码,将向量分成 N_THREADS 个部分,然后将每个部分交给一个线程来进行处理:

您是否尝试过编译您的代码?因为不可行,所以在第一次迭代时

data
会移入您创建的第一个线程,然后第二个线程尝试获取移出值,这是非法的。

即使that有效,你也不能

Index
那样输出多个可变切片,编译器无法知道这些切片是不重叠的。这就是为什么切片有一堆
split*_mut
方法来分割可变借用。常规的非作用域线程也不能借用它们的词法作用域。

你可以做的是

Vec::split_off
将源向量转换为子向量,这些子向量可以移动到工作人员中,尽管这会为每个工作人员分配一个向量。

如何在所有线程完成后拿回数据的所有权,以便我可以 do_something_else(data)?

好吧

join
将返回线程回调返回的任何内容,因此您可以使用它,让每个线程返回其(更新的)子向量,然后将向量连接回来。

一个更好的选择(再次忽略人造丝)是绕过整个事情,使用范围线程,并借用原来的。这需要使用某种形式的 mut 分割来为每个线程提供一个可变的非重叠切片:

use std::thread;

const N_THREADS: usize = 8;
const PER_THREAD: usize = 2;
const N_ELEMENTS: usize = PER_THREAD*N_THREADS;

fn main() {
    let mut data: Vec<usize> = (1..N_ELEMENTS).collect();
    
    println!("{:?}", data);
    thread::scope(|s| {
        for t in data.chunks_mut(PER_THREAD) {
            s.spawn(move || {
                for d in t {
                    *d *= 99;
                }
            });
        }
        // no need to join by hand, because `scope` implicitly joins the scoped threads and we work in-place
    });
    println!("{:?}", data);
}

PS:请注意,

data
不是
N_ELEMENTS
长,而是
N_ELEMENTS-1
(又名它只有15个元素):Rust的正常范围是半开的,所以
1..N_ELEMENTS
从1到15而不是1到16。

PPS:这是用 Rayon 代替的代码:

use rayon::prelude::*;

const N_THREADS: usize = 8;
const PER_THREAD: usize = 2;
const N_ELEMENTS: usize = PER_THREAD*N_THREADS;

fn main() {
    let mut data: Vec<usize> = (1..N_ELEMENTS).collect();
    
    println!("{:?}", data);
    data.par_iter_mut().for_each(|d| *d *= 99);
    println!("{:?}", data);
}
© www.soinside.com 2019 - 2024. All rights reserved.