MPI in Python: classification


I am having trouble with the following code:

TestCode.py

from mpi4py import MPI
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

def classify_data(classifier, data):

    X_train, y_train, X_test = data
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    return y_pred


comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

comm.bcast((X_train, y_train), root=0)
    
classifiers = {
    0: RandomForestClassifier(n_estimators=100),
    1: DecisionTreeClassifier()
}
   
classifier = classifiers.get(rank, None)

if classifier is not None:    
    y_pred = classify_data(classifier, (X_train, y_train, X_test))
    comm.Barrier()
    if rank == 0:
        all_predictions = comm.gather(y_pred, root=0)

I run the code from the console with:

mpiexec -n 2 python -m mpi4py TestCode.py

When I run the code, the computation runs into trouble on rank == 0. Why does this happen?

python mpi
1 answer

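Thanks to @Giles-gouaillardet for the suggestion above. The whole problem was solved by calling the function below on every task, not only on rank 0. comm.gather is a collective operation, so every rank in the communicator has to call it; when only rank 0 calls it, rank 0 waits for contributions that never arrive.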

all_predictions = comm.gather(y_pred, root=0)
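
As a rough illustration of that fix, the sketch below moves the gather call out of any rank-specific branch. It is not the original script: the predictions are placeholder lists rather than classifier output, and SerialComm and collect_predictions are hypothetical names introduced here. SerialComm is only a single-process stand-in so the control flow can be tried where mpi4py is not installed.

```python
# Sketch: every rank calls gather, including ranks that have no classifier.


class SerialComm:
    """Hypothetical single-process stand-in, NOT part of mpi4py."""

    def Get_rank(self):
        return 0

    def Get_size(self):
        return 1

    def gather(self, obj, root=0):
        # With one process, the root only receives its own contribution.
        return [obj]


try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
except ImportError:
    # Assumption: fall back to the stand-in when mpi4py is unavailable.
    comm = SerialComm()


def collect_predictions(comm, y_pred, root=0):
    """Call gather on every rank; returns a list on root, None elsewhere."""
    return comm.gather(y_pred, root=root)


rank = comm.Get_rank()

# Placeholder for classifier.predict(X_test). A rank without a classifier
# would pass None here instead of skipping the collective call.
y_pred = [rank, rank, rank]

# Unconditional: there is no `if rank == 0` around the collective call.
all_predictions = collect_predictions(comm, y_pred)

if rank == 0:
    # One entry per rank, ordered by rank.
    print(len(all_predictions), "prediction sets gathered")
```

Only the use of the result is guarded by rank == 0. On the other ranks gather returns None, so all_predictions should be read on the root only.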

I also found another "small" problem, in this call:

comm.bcast((X_train, y_train), root=0)

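In my implementation this line is not even needed, because every task already computes the same values itself (the split uses a fixed random_state). Note also that comm.bcast returns the broadcast object, so calling it without using the return value has no effect on the receiving ranks. To make the broadcast actually work, I assign the values as follows: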

if rank == 0:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
else:
    X_train, X_test, y_train, y_test = None, None, None, None

X_train = comm.bcast(X_train, root=0)
...
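
One possible way to shorten the pattern above is to bundle all four arrays into a single tuple and broadcast it once. This is a sketch, not the code from the answer: share_from_root, make_split, and SerialComm are hypothetical names, the data is a tiny placeholder for the train_test_split result, and SerialComm is a single-process stand-in used only when mpi4py cannot be imported.

```python
# Sketch: build the data on the root only, then broadcast it in one call.


class SerialComm:
    """Hypothetical single-process stand-in, NOT part of mpi4py."""

    def Get_rank(self):
        return 0

    def bcast(self, obj, root=0):
        # With one process, the root's object is simply returned.
        return obj


try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
except ImportError:
    # Assumption: fall back to the stand-in when mpi4py is unavailable.
    comm = SerialComm()


def share_from_root(comm, make_data, root=0):
    """Create data on root, pass None elsewhere, return bcast's result."""
    data = make_data() if comm.Get_rank() == root else None
    # Every rank calls bcast and keeps the return value.
    return comm.bcast(data, root=root)


def make_split():
    # Placeholder for train_test_split(X, y, ...); values are made up.
    X_train = [[5.1, 3.5], [6.2, 2.9]]
    X_test = [[5.9, 3.0]]
    y_train = [0, 1]
    y_test = [1]
    return X_train, X_test, y_train, y_test


# After this line every rank holds the same four objects.
X_train, X_test, y_train, y_test = share_from_root(comm, make_split)
```

The lowercase bcast pickles arbitrary Python objects, which is why a tuple can be sent as one message. Whether a single broadcast is preferable to several depends on the data sizes involved.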