我正在使用 mpi4py 在 python 中并行化我的代码。 comm.Gather 自动将所有核心的 numpy 数组收集到一个数组中。我已经使用了它并且它有效,但是当我尝试收集不同形状的数组(可以很容易地连接)时,代码失败了。我的代码:
from mpi4py import MPI
import numpy as np
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
a = np.zeros((2 if rank==1 else 5, 3),dtype=float)+rank #size=3: shape[0] is 5,2,5 for ranks 0,1,2
#a = np.zeros((6,3),dtype=float)+rank #size=2: shape[0] is 6,6 for ranks 0,1 (this works)
print(rank,a)
b = np.zeros((12,3),dtype=float)-1
comm.Gather(a,b,root=0)
if rank==0:
print(b)
我现在的问题是如何编写代码来连接不同形状的数组以使其工作。我不想要的是拆分数组,以便每个数组都具有相同的形状,然后 comm.send 手动发送其余部分并手动连接。
您可以使用
comm.gather
(小写)而不是 comm.Gather
来获取可以连接的数组元组(这适用于通用 Python 对象)。请参阅此处:https://mpi4py.readthedocs.io/en/stable/tutorial.html
import numpy as np
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
a = np.zeros((2 if rank == 1 else 5, 3), dtype=float) + rank
print(rank, a, a.shape)
b = comm.gather(a, root=0)
if rank == 0:
b = np.concatenate(b)
print(rank, b)
另一种方法是使用
comm.Gatherv
直接收集到 numpy 数组中:
import numpy as np
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
assert size <= 2
if rank == 0:
a = np.zeros((5, 3), dtype=float)
else:
a = np.zeros((2, 3), dtype=float) + rank
print(rank, a, a.shape)
# number of global rows
n_global = 7
if rank == 0:
b = np.zeros((n_global, a.shape[1]), dtype=float)
else:
b = None
comm.Gatherv(a, (b, (5 * 3, 2 * 3), (0, 5 * 3), MPI.DOUBLE), root=0)
if rank == 0:
print(b)
Gatherv
的接收缓冲区由元组(b, (5 * 3, 2 * 3), (0, 5 * 3), MPI.DOUBLE)
指定,其中
b
中的索引,在其后插入数组(0 和 5 * 3:在索引 0 处插入 Rank 0 数组,在索引 15 处插入 Rank 1 数组,即在最后一个元素之后第一个数组)我发现这很有帮助:https://materials.jeremybejarano.com/MPIwithPython/collectiveCom.html