Python 中使用多处理池方法的矩阵乘法

问题描述 投票:0回答:1

使用多处理池方法进行矩阵乘法。

import numpy as np
from multiprocessing import Pool

def matrix_multiply(args):
    matrix_a, matrix_b = args
    return np.dot(matrix_a, matrix_b)

if __name__ == "__main__":
    matrix_size = 4
    matrix_a = np.random.rand(matrix_size, matrix_size)
    matrix_b = np.random.rand(matrix_size, matrix_size)
    print(matrix_a)
    print(matrix_b)

    num_processes = 2
    chunk_size = matrix_size // num_processes
    matrix_chunks_a = [matrix_a[i * chunk_size: (i + 1) * chunk_size] for 
      i in range(num_processes)]
    print('Durga1a',matrix_chunks_a)
    matrix_chunks_b = [matrix_b[i * chunk_size: (i + 1) * chunk_size] for 
     i in range(num_processes)]
    print('Durga2',matrix_chunks_b)


   with Pool(num_processes) as pool:
   result_chunks = pool.starmap(matrix_multiply, zip(matrix_chunks_a, 
       matrix_chunks_b))
   print('Durga3')

   result_matrix = np.vstack(result_chunks)
   print('Durga4')
   print("Matrix multiplication result:")
   print(result_matrix)

pool.starmap() 的输出错误:TypeError:matrix_multiply() 需要 1 个位置参数,但给出了 2 个。 pool.map() 的输出错误:ValueError:形状 (2,4) 和 (2,4) 未对齐:4 (dim 1) != 2 (dim 0)

python numpy error-handling python-multiprocessing matrix-multiplication
1个回答
0
投票

第23行: 尝试使用“地图”功能

result_chunks = pool.map(multiply_matrices,zip(matrix_chunks_a,matrix_chunks_b))

另一个问题是第一个矩阵中的列数与第二个矩阵中的行数不匹配

© www.soinside.com 2019 - 2024. All rights reserved.