I'm having trouble with the following code:
TestCode.py
from mpi4py import MPI
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

def classify_data(classifier, data):
    X_train, y_train, X_test = data
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    return y_pred

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
comm.bcast((X_train, y_train), root=0)

classifiers = {
    0: RandomForestClassifier(n_estimators=100),
    1: DecisionTreeClassifier()
}
classifier = classifiers.get(rank, None)
if classifier is not None:
    y_pred = classify_data(classifier, (X_train, y_train, X_test))

comm.Barrier()
if rank == 0:
    all_predictions = comm.gather(y_pred, root=0)
I run the code from the console with:
mpiexec -n 2 python -m mpi4py TestCode.py
When I run it, execution runs into trouble (it gets stuck) in the branch guarded by
rank == 0
Why does this happen?
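Thanks to @Giles-gouaillardet for the suggestion above. comm.gather is a collective operation, so every rank in the communicator must call it, not only the root. Indeed, the whole problem was solved by calling the following function on all tasks: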
all_predictions = comm.gather(y_pred, root=0)
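To see why calling comm.gather only inside if rank == 0 gets stuck, here is a minimal sketch that runs without an MPI runtime. It swaps in a hypothetical ToyComm class (not part of mpi4py) whose "ranks" are threads and whose gather(obj, rank, root) only completes once every rank has called it, as an MPI collective does. Unlike mpi4py, the rank is passed explicitly, and a timeout is added only so the broken case terminates instead of hanging forever as it would under real MPI.

```python
import threading


class ToyComm:
    """Hypothetical thread-based stand-in for an MPI communicator (NOT mpi4py).

    gather() is collective: it only completes once every rank has called it.
    """

    def __init__(self, size, timeout=None):
        self.size = size
        self._slots = [None] * size
        # The timeout exists only so the demo ends; real MPI would wait forever.
        self._barrier = threading.Barrier(size, timeout=timeout)

    def gather(self, obj, rank, root=0):
        self._slots[rank] = obj          # each rank contributes its object
        self._barrier.wait()             # block until every rank has contributed
        result = list(self._slots) if rank == root else None
        self._barrier.wait()             # keep slots stable until the root has copied them
        return result


def run_ranks(comm, body):
    """Run body(rank) on one thread per rank and collect each return value."""
    results = [None] * comm.size

    def worker(rank):
        try:
            results[rank] = body(rank)
        except threading.BrokenBarrierError:
            results[rank] = "stuck"      # barrier timed out: some rank never arrived

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(comm.size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Correct: every rank calls gather; only the root receives the list.
comm_ok = ToyComm(2)


def all_ranks(rank):
    return comm_ok.gather(f"pred-{rank}", rank, root=0)


ok = run_ranks(comm_ok, all_ranks)
print(ok)   # [['pred-0', 'pred-1'], None]

# Broken (as in the question): only rank 0 calls gather, so it waits for rank 1.
comm_bad = ToyComm(2, timeout=0.5)


def only_root(rank):
    if rank == 0:
        return comm_bad.gather("pred-0", rank, root=0)
    return "skipped gather"


bad = run_ranks(comm_bad, only_root)
print(bad)  # ['stuck', 'skipped gather']
```

With mpi4py the fix has the same shape: call all_predictions = comm.gather(y_pred, root=0) on every rank, and use all_predictions only inside if rank == 0, where it holds the list (the other ranks receive None). If the program is ever launched with more than two processes, y_pred would also need a default such as None on ranks that get no classifier.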
I also found another "small" issue, involving this line:
comm.bcast((X_train, y_train), root=0)
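This line is not even needed in my implementation, because the code already assigns the values on every task: each rank runs train_test_split itself with the same random_state. The call also throws its result away, since comm.bcast delivers the data through its return value. So, to make the broadcast actually do its job, I assigned the values as follows: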
if rank == 0:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
else:
    X_train, X_test, y_train, y_test = None, None, None, None
X_train = comm.bcast(X_train, root=0)
...
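The broadcast pitfall in the original line comm.bcast((X_train, y_train), root=0) can be illustrated with the same kind of toy model. Like mpi4py's comm.bcast, the hypothetical ToyComm.bcast below (not part of mpi4py; ranks are threads and the rank is passed explicitly) hands the root's object to every rank only through its return value, so a call whose result is not assigned leaves each rank's variable unchanged. The payload string is a placeholder rather than the real iris split, and with real MPI each rank would receive a pickled copy, whereas these threads share a single object.

```python
import threading


class ToyComm:
    """Hypothetical thread-based stand-in for an MPI communicator (NOT mpi4py).

    bcast() is collective: every rank calls it and gets the root's object
    back as the return value.
    """

    def __init__(self, size):
        self.size = size
        self._value = None
        self._barrier = threading.Barrier(size)

    def bcast(self, obj, rank, root=0):
        if rank == root:
            self._value = obj        # only the root's argument is used
        self._barrier.wait()         # wait until the root has published its object
        value = self._value
        self._barrier.wait()         # keep the value stable until every rank has read it
        return value


def run_ranks(comm, body):
    """Run body(rank) on one thread per rank and collect each return value."""
    results = [None] * comm.size

    def worker(rank):
        results[rank] = body(rank)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(comm.size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


comm = ToyComm(2)


def rank_body(rank):
    # Placeholder payload standing in for the X_train built on rank 0.
    X_train = "split built on rank 0" if rank == 0 else None

    comm.bcast(X_train, rank, root=0)            # result discarded: X_train unchanged
    before = X_train

    X_train = comm.bcast(X_train, rank, root=0)  # result assigned: every rank has it
    return before, X_train


results = run_ranks(comm, rank_body)
print(results)
# [('split built on rank 0', 'split built on rank 0'), (None, 'split built on rank 0')]
```

Rank 1 only sees the data after the second call, which is why the corrected code assigns X_train = comm.bcast(X_train, root=0) instead of calling comm.bcast for its side effect.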