我的老师写道:
为矩阵乘积实现 CUDA 内核作为外积向量。在此版本中,每个 K 个线程块通过实现矩阵外积公式来计算大小为 KxK 的结果矩阵的一个正方形。内核使用共享内存来存储矩阵 A 和矩阵 B 中相应的列向量,并存储结果数组的相应片段。
据我从本文中了解到,他希望我使用向量外积进行矩阵乘法,其中还包含平铺。这就是我想出这个的原因。
假设我想使用
kxk=2x2
块来乘以以下矩阵:
A = B = [[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
乘法结果将是:
平铺乘法是这样工作的吗?
或者,我错过了什么吗?
===========================================================================================
k r i j a[i][r] b[r][j] a[i][r]*b[r][j] c[i][j]
===========================================================================================
2 a[0][0]=1 b[0][0]=1 1*1=1 c[0][0]=1
a[0][0]=1 b[0][1]=2 1*2=2 c[0][1]=2
a[1][0]=5 b[0][0]=1 5*1=5 c[1][0]=5
a[1][0]=5 b[0][1]=2 5*2=10 c[1][1]=10
—------------------------------------------------------------------------------------------
a[0][1]=2 b[1][0]=5 2*5=10 c[0][0]=(1+10)=11
a[0][1]=2 b[1][1]=6 2*6=12 c[0][1]=(2+12)=14
a[1][1]=6 b[1][0]=5 6*5=30 c[1][0]=(5+30)=35
a[1][1]=6 b[1][1]=6 6*6=36 c[1][1]=(10+36)=46
—------------------------------------------------------------------------------------------
2 a[0][2]=3 b[2][0]=9 3*9=27 c[0][0]=(11+27)=38
a[0][2]=3 b[2][1]=10 3*10=30 c[0][1]=(14+30)=44
a[1][2]=7 b[2][0]=9 7*9=63 c[1][0]=(35+63)=98
a[1][2]=7 b[2][1]=10 7*10=70 c[1][1]=(46+70)=116
—------------------------------------------------------------------------------------------
a[0][3]=4 b[3][0]=13 4*13=52 c[0][0]=(38+52)=90
a[0][3]=4 b[3][1]=14 4*14=56 c[0][1]=(44+56)=100
a[1][3]=8 b[3][0]=13 8*13=104 c[1][0]=(98+104)=202
a[1][3]=8 b[3][1]=14 8*14=112 c[1][1]=(116+112)=228
===========================================================================================
2 a[0][0]=1 b[0][2]=3 1*3=3 c[0][2]=3
a[0][0]=1 b[0][3]=4 1*4=4 c[0][3]=4
a[1][0]=5 b[0][2]=3 5*3=15 c[1][2]=15
a[1][0]=5 b[0][3]=4 5*4=20 c[1][3]=20
—------------------------------------------------------------------------------------------
a[0][1]=2 b[1][2]=7 2*7=14 c[0][2]=(3+14)=17
a[0][1]=2 b[1][3]=8 2*8=16 c[0][3]=(4+16)=20
a[1][1]=6 b[1][2]=7 6*7=42 c[1][2]=(15+42)=57
a[1][1]=6 b[1][3]=8 6*8=48 c[1][3]=(20+48)=68
—------------------------------------------------------------------------------------------
2 a[0][2]=3 b[2][2]=11 3*11=33 c[0][2]=(17+33)=50
a[0][2]=3 b[2][3]=12 3*12=36 c[0][3]=(20+36)=56
a[1][2]=7 b[2][2]=11 7*11=77 c[1][2]=(57+77)=134
a[1][2]=7 b[2][3]=12 7*12=84 c[1][3]=(68+84)=152
—------------------------------------------------------------------------------------------
a[0][3]=4 b[3][2]=15 4*15=60 c[0][2]=(50+60)=110
a[0][3]=4 b[3][3]=16 4*16=64 c[0][3]=(56+64)=120
a[1][3]=8 b[3][2]=15 8*15=120 c[1][2]=(134+120)=254
a[1][3]=8 b[3][3]=16 8*16=128 c[1][3]=(152+128)=280
===========================================================================================
2 a[2][0]=9 b[0][0]=1 9*1=9 c[2][0]=(0+9)=9
a[2][0]=9 b[0][1]=2 9*2=18 c[2][1]=(0+18)=18
a[3][0]=13 b[0][0]=1 13*1=13 c[3][0]=(0+13)=13
a[3][0]=13 b[0][1]=2 13*2=26 c[3][1]=(0+26)=26
—------------------------------------------------------------------------------------------
a[2][1]=10 b[1][0]=5 10*5=50 c[2][0]=(9+50)=59
a[2][1]=10 b[1][1]=6 10*6=60 c[2][1]=(18+60)=78
a[3][1]=14 b[1][0]=5 14*5=70 c[3][0]=(13+70)=83
a[3][1]=14 b[1][1]=6 14*6=84 c[3][1]=(26+84)=110
—------------------------------------------------------------------------------------------
2 a[2][2]=11 b[2][0]=9 11*9=99 c[2][0]=(59+99)=158
a[2][2]=11 b[2][1]=10 11*10=110 c[2][1]=(78+110)=198
a[3][2]=15 b[2][0]=9 15*9=135 c[3][0]=(83+135)=218
a[3][2]=15 b[2][1]=10 15*10=150 c[3][1]=(110+150)=260
—------------------------------------------------------------------------------------------
2 a[2][3]=12 b[3][0]=13 12*13=156 c[2][0]=(158+156)=314
a[2][3]=12 b[3][1]=14 12*14=168 c[2][1]=(188+168)=356
a[3][3]=16 b[3][0]=13 16*13=208 c[3][0]=(218+208)=426
a[3][3]=16 b[3][1]=14 16*14=224 c[3][1]=(260+224)=484
===========================================================================================
2 a[2][0]=9 b[0][2]=3 9*3=27 c[2][2]=(0+27)=27
a[2][0]=9 b[0][3]=4 9*4=36 c[2][3]=(0+36)=36
a[3][0]=13 b[0][2]=3 13*3=39 c[3][2]=(0+39)=39
a[3][0]=13 b[0][3]=4 13*4=52 c[3][3]=(0+52)=52
—------------------------------------------------------------------------------------------
a[2][1]=10 b[1][2]=7 10*7=70 c[2][2]=(27+70)=97
a[2][1]=10 b[1][3]=8 10*8=80 c[2][3]=(36+80)=116
a[3][1]=14 b[1][2]=7 14*7=98 c[3][2]=(39+98)=137
a[3][1]=14 b[1][3]=8 14*8=112 c[3][3]=(52+112)=164
—------------------------------------------------------------------------------------------
2 a[2][2]=11 b[2][2]=11 11*11=121 c[2][2]=(97+121)=218
a[2][2]=11 b[2][3]=12 11*12=132 c[2][3]=(116+132)=248
a[3][2]=15 b[2][2]=11 15*11=165 c[3][2]=(137+165)=302
a[3][2]=15 b[2][3]=12 15*12=180 c[3][3]=(164+180)=344
—------------------------------------------------------------------------------------------
2 a[2][3]=12 b[3][2]=15 12*15=180 c[2][2]=(218+180)=398
a[2][3]=12 b[3][3]=16 12*16=192 c[2][3]=(248+192)=440
a[3][3]=16 b[3][2]=15 16*15=240 c[3][2]=(302+240)=542
a[3][3]=16 b[3][3]=16 16*16=256 c[3][3]=(344+256)=600
===========================================================================================
一般来说,“平铺”矩阵乘法意味着将矩阵点积重构为 block 矩阵的乘积,因此
A dot B
可以(作为一个示例)表示为:
+——--+—-——+ +——--+—-——+ +——————————---+—-—————————-—+
| A1 | A2 | | B1 | B2 | | A1.B1+A2.B3 | A1.B2+A2.B4 |
+——--+—-——+ dot +——--+—-——+ = +——————————---+—-—————————-—+
| A3 | A4 | | B3 | B4 | | A3.B1+A4.B3 | A3.B2+A4.B4 |
+——--+—-——+ +——--+—-——+ +——————————---+—-—————————-—+
对于您询问的 4x4 情况,这种 2x2 块结构意味着
A
和 B
中的每个子矩阵都是 2x2。
如果您选择通过外积展开来执行每个子矩阵乘积,而不是通过您已经了解的一组内积,则可以按如下方式执行,以
A1.B1
为例:
A1.B1 = sum(outer(A1[:,1],B1[1,:]), outer(A1[:,2],B1[2,:])
这是
outer([a11 a21],[b11,b12]) + outer([a12 a22],[b21,b22])
或
| a11*b11 a11*b12 | + | a12*b21 a12*b22 | = | a11*b11+a12*b21 a11*b12+a12*b22 |
| a21*b11 a21*b12 | | a22*b21 a22*b22 | | a21*b11+a22*b21 a21*b12+a22*b22 |
确认结果 RHS 中的项与通过计算两个子矩阵的内积集获得的项相同应该很简单。
您对其他七个子矩阵乘积重复此过程,并累积结果以产生完整的矩阵乘法。