两个矩阵的平铺乘法如何工作?

问题描述 投票:0回答:1

我的老师写道:

为矩阵乘积实现 CUDA 内核作为外积向量。在此版本中,每个 K 个线程块通过实现矩阵外积公式来计算大小为 KxK 的结果矩阵的一个正方形。内核使用共享内存来存储矩阵 A 和矩阵 B 中相应的列向量,并存储结果数组的相应片段。

据我从本文中了解到,他希望我使用向量外积进行矩阵乘法,其中还包含平铺。这就是我想出这个的原因。

假设我想使用

kxk=2x2
块来乘以以下矩阵:

A = B = [[ 1  2  3  4]
         [ 5  6  7  8]
         [ 9 10 11 12]
         [13 14 15 16]]

乘法结果将是:

平铺乘法是这样工作的吗?

或者,我错过了什么吗?

===========================================================================================
k   r   i   j           a[i][r]     b[r][j]       a[i][r]*b[r][j]  c[i][j]    
===========================================================================================           
2                       a[0][0]=1   b[0][0]=1     1*1=1            c[0][0]=1     
                        a[0][0]=1   b[0][1]=2     1*2=2            c[0][1]=2
                        a[1][0]=5   b[0][0]=1     5*1=5            c[1][0]=5
                        a[1][0]=5   b[0][1]=2     5*2=10           c[1][1]=10
—------------------------------------------------------------------------------------------              
                        a[0][1]=2   b[1][0]=5     2*5=10           c[0][0]=(1+10)=11 
                        a[0][1]=2   b[1][1]=6     2*6=12           c[0][1]=(2+12)=14
                        a[1][1]=6   b[1][0]=5     6*5=30           c[1][0]=(5+30)=35
                        a[1][1]=6   b[1][1]=6     6*6=36           c[1][1]=(10+36)=46
—------------------------------------------------------------------------------------------
2                       a[0][2]=3   b[2][0]=9     3*9=27           c[0][0]=(11+27)=38      
                        a[0][2]=3   b[2][1]=10    3*10=30          c[0][1]=(14+30)=44
                        a[1][2]=7   b[2][0]=9     7*9=63           c[1][0]=(35+63)=98
                        a[1][2]=7   b[2][1]=10    7*10=70          c[1][1]=(46+70)=116
—------------------------------------------------------------------------------------------              
                        a[0][3]=4   b[3][0]=13    4*13=52          c[0][0]=(38+52)=90 
                        a[0][3]=4   b[3][1]=14    4*14=56          c[0][1]=(44+56)=100
                        a[1][3]=8   b[3][0]=13    8*13=104         c[1][0]=(98+104)=202
                        a[1][3]=8   b[3][1]=14    8*14=112         c[1][1]=(116+112)=228
===========================================================================================
2                       a[0][0]=1   b[0][2]=3     1*3=3            c[0][2]=3     
                        a[0][0]=1   b[0][3]=4     1*4=4            c[0][3]=4
                        a[1][0]=5   b[0][2]=3     5*3=15           c[1][2]=15
                        a[1][0]=5   b[0][3]=4     5*4=20           c[1][3]=20
—------------------------------------------------------------------------------------------              
                        a[0][1]=2   b[1][2]=7     2*7=14           c[0][2]=(3+14)=17 
                        a[0][1]=2   b[1][3]=8     2*8=16           c[0][3]=(4+16)=20
                        a[1][1]=6   b[1][2]=7     6*7=42           c[1][2]=(15+42)=57
                        a[1][1]=6   b[1][3]=8     6*8=48           c[1][3]=(20+48)=68
—------------------------------------------------------------------------------------------
2                       a[0][2]=3   b[2][2]=11    3*11=33          c[0][2]=(17+33)=50      
                        a[0][2]=3   b[2][3]=12    3*12=36          c[0][3]=(20+36)=56
                        a[1][2]=7   b[2][2]=11    7*11=77          c[1][2]=(57+77)=134
                        a[1][2]=7   b[2][3]=12    7*12=84          c[1][3]=(68+84)=152
—------------------------------------------------------------------------------------------              
                        a[0][3]=4   b[3][2]=15    4*15=60          c[0][2]=(50+60)=110 
                        a[0][3]=4   b[3][3]=16    4*16=64          c[0][3]=(56+64)=120
                        a[1][3]=8   b[3][2]=15    8*15=120         c[1][2]=(134+120)=254
                        a[1][3]=8   b[3][3]=16    8*16=128         c[1][3]=(152+128)=280
===========================================================================================
2                       a[2][0]=9   b[0][0]=1     9*1=9             c[2][0]=(0+9)=9     
                        a[2][0]=9   b[0][1]=2     9*2=18            c[2][1]=(0+18)=18
                        a[3][0]=13  b[0][0]=1     13*1=13           c[3][0]=(0+13)=13
                        a[3][0]=13  b[0][1]=2     13*2=26           c[3][1]=(0+26)=26
—------------------------------------------------------------------------------------------  
                        a[2][1]=10  b[1][0]=5     10*5=50           c[2][0]=(9+50)=59     
                        a[2][1]=10  b[1][1]=6     10*6=60           c[2][1]=(18+60)=78
                        a[3][1]=14  b[1][0]=5     14*5=70           c[3][0]=(13+70)=83
                        a[3][1]=14  b[1][1]=6     14*6=84           c[3][1]=(26+84)=110
—------------------------------------------------------------------------------------------ 
2                       a[2][2]=11  b[2][0]=9     11*9=99           c[2][0]=(59+99)=158     
                        a[2][2]=11  b[2][1]=10    11*10=110         c[2][1]=(78+110)=198
                        a[3][2]=15  b[2][0]=9     15*9=135          c[3][0]=(83+135)=218
                        a[3][2]=15  b[2][1]=10    15*10=150         c[3][1]=(110+150)=260
—------------------------------------------------------------------------------------------
2                       a[2][3]=12  b[3][0]=13    12*13=156         c[2][0]=(158+156)=314     
                        a[2][3]=12  b[3][1]=14    12*14=168         c[2][1]=(188+168)=356 
                        a[3][3]=16  b[3][0]=13    16*13=208         c[3][0]=(218+208)=426
                        a[3][3]=16  b[3][1]=14    16*14=224         c[3][1]=(260+224)=484 
===========================================================================================
2                       a[2][0]=9   b[0][2]=3     9*3=27            c[2][2]=(0+27)=27     
                        a[2][0]=9   b[0][3]=4     9*4=36            c[2][3]=(0+36)=36
                        a[3][0]=13  b[0][2]=3     13*3=39           c[3][2]=(0+39)=39
                        a[3][0]=13  b[0][3]=4     13*4=52           c[3][3]=(0+52)=52
—------------------------------------------------------------------------------------------  
                        a[2][1]=10  b[1][2]=7     10*7=70           c[2][2]=(27+70)=97     
                        a[2][1]=10  b[1][3]=8     10*8=80           c[2][3]=(36+80)=116
                        a[3][1]=14  b[1][2]=7     14*7=98           c[3][2]=(39+98)=137
                        a[3][1]=14  b[1][3]=8     14*8=112          c[3][3]=(52+112)=164
—------------------------------------------------------------------------------------------ 
2                       a[2][2]=11  b[2][2]=11    11*11=121         c[2][2]=(97+121)=218     
                        a[2][2]=11  b[2][3]=12    11*12=132         c[2][3]=(116+132)=248
                        a[3][2]=15  b[2][2]=11    15*11=165         c[3][2]=(137+165)=302
                        a[3][2]=15  b[2][3]=12    15*12=180         c[3][3]=(164+180)=344
—------------------------------------------------------------------------------------------
2                       a[2][3]=12  b[3][2]=15    12*15=180         c[2][2]=(218+180)=398     
                        a[2][3]=12  b[3][3]=16    12*16=192         c[2][3]=(248+192)=440 
                        a[3][3]=16  b[3][2]=15    16*15=240         c[3][2]=(302+240)=542
                        a[3][3]=16  b[3][3]=16    16*16=256         c[3][3]=(344+256)=600   
===========================================================================================
matrix-multiplication
1个回答
0
投票

一般来说,“平铺”矩阵乘法意味着将矩阵点积重构为 block 矩阵的乘积,因此

A dot B
可以(作为一个示例)表示为:


+——--+—-——+     +——--+—-——+     +——————————---+—-—————————-—+
| A1 | A2 |     | B1 | B2 |     | A1.B1+A2.B3 | A1.B2+A2.B4 |
+——--+—-——+ dot +——--+—-——+  =  +——————————---+—-—————————-—+
| A3 | A4 |     | B3 | B4 |     | A3.B1+A4.B3 | A3.B2+A4.B4 |
+——--+—-——+     +——--+—-——+     +——————————---+—-—————————-—+

对于您询问的 4x4 情况,这种 2x2 块结构意味着

A
B
中的每个子矩阵都是 2x2。

如果您选择通过外积展开来执行每个子矩阵乘积,而不是通过您已经了解的一组内积,则可以按如下方式执行,以

A1.B1
为例:

A1.B1 = sum(outer(A1[:,1],B1[1,:]), outer(A1[:,2],B1[2,:])

这是

outer([a11 a21],[b11,b12]) + outer([a12 a22],[b21,b22])

| a11*b11 a11*b12 | + | a12*b21 a12*b22 | = | a11*b11+a12*b21 a11*b12+a12*b22 |
| a21*b11 a21*b12 |   | a22*b21 a22*b22 |   | a21*b11+a22*b21 a21*b12+a22*b22 |

确认结果 RHS 中的项与通过计算两个子矩阵的内积集获得的项相同应该很简单。

您对其他七个子矩阵乘积重复此过程,并累积结果以产生完整的矩阵乘法。

© www.soinside.com 2019 - 2024. All rights reserved.