#include "fmm.h"

#include <stdlib.h>
// Slow fmm :)
/*
 * Classical O(n^3) product of two row-major n x n int matrices:
 * result = m1 * m2.
 *
 * col1 is caller-provided scratch with room for at least n ints. For each
 * output column, the matching column of m2 is first copied into col1 so the
 * inner dot product reads both operands contiguously.
 */
void basefmm(int n, int* m1, int* m2, int* result, int* col1) {
    for (int col = 0; col < n; ++col) {
        /* Gather column `col` of m2 into the contiguous scratch buffer. */
        for (int r = 0; r < n; ++r) {
            col1[r] = m2[r * n + col];
        }

        /* Dot every row of m1 with the gathered column. */
        for (int row = 0; row < n; ++row) {
            const int row_start = row * n;
            const int* lhs_row = m1 + row_start;
            int* cell = result + row_start + col;

            *cell = 0;
            for (int t = 0; t < n; ++t) {
                *cell += lhs_row[t] * col1[t];
            }
        }
    }
}
/*
 * Element-wise sum of two row-major n x n int matrices.
 *
 * Returns a newly malloc'd buffer of n*n ints holding m1 + m2; the caller
 * owns it and must free() it. Returns NULL if the allocation fails. For
 * n <= 0 nothing is read or written and the result of malloc(0) is returned
 * (which may itself be NULL).
 */
int* add(int n, int* m1, int* m2) {
    /* Do the size arithmetic in size_t so n * n cannot overflow int. */
    size_t side = n > 0 ? (size_t)n : 0;
    size_t count = side * side;

    int* res = malloc(count * sizeof *res);
    if (res == NULL) {
        return NULL;
    }

    /* The matrices are contiguous, so one flat pass covers every element. */
    for (size_t k = 0; k < count; k++) {
        res[k] = m1[k] + m2[k];
    }
    return res;
}
/*
 * Element-wise difference of two row-major n x n int matrices.
 *
 * Returns a newly malloc'd buffer of n*n ints holding m1 - m2; the caller
 * owns it and must free() it. Returns NULL if the allocation fails. For
 * n <= 0 nothing is read or written and the result of malloc(0) is returned
 * (which may itself be NULL).
 */
int* sub(int n, int* m1, int* m2) {
    /* Do the size arithmetic in size_t so n * n cannot overflow int. */
    size_t side = n > 0 ? (size_t)n : 0;
    size_t count = side * side;

    int* res = malloc(count * sizeof *res);
    if (res == NULL) {
        return NULL;
    }

    /* The matrices are contiguous, so one flat pass covers every element. */
    for (size_t k = 0; k < count; k++) {
        res[k] = m1[k] - m2[k];
    }
    return res;
}
/*
 * FMM_BASE_CUTOFF: dimension below which fmm() hands the product to basefmm().
 * FMM_POOL_BLOCKS: number of (n/2 x n/2) blocks fmm() carves out of its single
 *                  scratch allocation: 8 quadrants + 7 products + 2 operands.
 */
enum { FMM_BASE_CUTOFF = 64, FMM_POOL_BLOCKS = 17 };

/*
 * result = m1 * m2 for row-major n x n matrices, using no scratch memory.
 * Used by fmm() when the Strassen split is not applicable (odd n) or when
 * its scratch allocation fails. Each output cell is accumulated from zero
 * over ascending k, the same order basefmm() uses.
 */
static void fmm_classical(int n, const int* m1, const int* m2, int* result) {
    for (int i = 0; i < n; i++) {
        int* out_row = result + (size_t)i * n;
        const int* lhs_row = m1 + (size_t)i * n;

        for (int j = 0; j < n; j++) {
            out_row[j] = 0;
        }
        for (int k = 0; k < n; k++) {
            int a = lhs_row[k];
            const int* rhs_row = m2 + (size_t)k * n;
            for (int j = 0; j < n; j++) {
                out_row[j] += a * rhs_row[j];
            }
        }
    }
}

/* out[k] = a[k] + b[k] for k in [0, count); writes into a caller buffer. */
static void fmm_add_into(size_t count, const int* a, const int* b, int* out) {
    for (size_t k = 0; k < count; k++) {
        out[k] = a[k] + b[k];
    }
}

/* out[k] = a[k] - b[k] for k in [0, count); writes into a caller buffer. */
static void fmm_sub_into(size_t count, const int* a, const int* b, int* out) {
    for (size_t k = 0; k < count; k++) {
        out[k] = a[k] - b[k];
    }
}

/*
 * Multiply two row-major n x n int matrices with Strassen's algorithm:
 * result = m1 * m2. result must not overlap m1 or m2.
 *
 * - n < FMM_BASE_CUTOFF: delegated to basefmm() with a stack scratch column.
 * - odd n >= FMM_BASE_CUTOFF: the matrix cannot be split into four equal
 *   quadrants, so the classical product is used instead.
 * - otherwise: one Strassen step (seven recursive half-size products).
 *
 * All scratch memory comes from a single allocation that is released before
 * returning, so the function does not leak. If that allocation fails the
 * classical product is computed instead, so result is always filled in.
 */
void fmm(int n, int* m1, int* m2, int* result) {
    if (n < FMM_BASE_CUTOFF) {
        /* basefmm() writes col1[0..n-1]; n < FMM_BASE_CUTOFF, so this fits. */
        int col1[FMM_BASE_CUTOFF];
        basefmm(n, m1, m2, result, col1);
        return;
    }
    if (n % 2 != 0) {
        fmm_classical(n, m1, m2, result);
        return;
    }

    int half = n / 2;
    size_t block = (size_t)half * (size_t)half;

    /* One pooled allocation instead of many; guard the size computation. */
    int* pool = NULL;
    if (block <= (size_t)-1 / (FMM_POOL_BLOCKS * sizeof(int))) {
        pool = malloc(FMM_POOL_BLOCKS * block * sizeof *pool);
    }
    if (pool == NULL) {
        fmm_classical(n, m1, m2, result);
        return;
    }

    /* Quadrants: m1 = [A B; C D], m2 = [E F; G H]. */
    int* A = pool;
    int* B = A + block;
    int* C = B + block;
    int* D = C + block;
    int* E = D + block;
    int* F = E + block;
    int* G = F + block;
    int* H = G + block;
    /* The seven Strassen products. */
    int* p1 = H + block;
    int* p2 = p1 + block;
    int* p3 = p2 + block;
    int* p4 = p3 + block;
    int* p5 = p4 + block;
    int* p6 = p5 + block;
    int* p7 = p6 + block;
    /* Reusable operand buffers for the quadrant sums/differences. */
    int* lhs = p7 + block;
    int* rhs = lhs + block;

    /* Copy each quadrant into its own contiguous half x half matrix. */
    for (int i = 0; i < half; i++) {
        const int* top1 = m1 + (size_t)i * n;
        const int* bot1 = m1 + (size_t)(i + half) * n;
        const int* top2 = m2 + (size_t)i * n;
        const int* bot2 = m2 + (size_t)(i + half) * n;
        size_t row = (size_t)i * half;

        for (int j = 0; j < half; j++) {
            A[row + j] = top1[j];
            B[row + j] = top1[j + half];
            C[row + j] = bot1[j];
            D[row + j] = bot1[j + half];
            E[row + j] = top2[j];
            F[row + j] = top2[j + half];
            G[row + j] = bot2[j];
            H[row + j] = bot2[j + half];
        }
    }

    /* p1 = (A + D)(E + H) */
    fmm_add_into(block, A, D, lhs);
    fmm_add_into(block, E, H, rhs);
    fmm(half, lhs, rhs, p1);

    /* p2 = D(G - E) */
    fmm_sub_into(block, G, E, rhs);
    fmm(half, D, rhs, p2);

    /* p3 = (A + B)H */
    fmm_add_into(block, A, B, lhs);
    fmm(half, lhs, H, p3);

    /* p4 = (B - D)(G + H) */
    fmm_sub_into(block, B, D, lhs);
    fmm_add_into(block, G, H, rhs);
    fmm(half, lhs, rhs, p4);

    /* p5 = A(F - H) */
    fmm_sub_into(block, F, H, rhs);
    fmm(half, A, rhs, p5);

    /* p6 = (C + D)E */
    fmm_add_into(block, C, D, lhs);
    fmm(half, lhs, E, p6);

    /* p7 = (A - C)(E + F) */
    fmm_sub_into(block, A, C, lhs);
    fmm_add_into(block, E, F, rhs);
    fmm(half, lhs, rhs, p7);

    /*
     * Combine the products straight into the four quadrants of result:
     *   C11 = p1 + p2 - p3 + p4    C12 = p5 + p3
     *   C21 = p6 + p2              C22 = p5 + p1 - p6 - p7
     */
    for (int i = 0; i < half; i++) {
        int* top = result + (size_t)i * n;
        int* bot = result + (size_t)(i + half) * n;
        size_t row = (size_t)i * half;

        for (int j = 0; j < half; j++) {
            size_t k = row + j;
            top[j] = p1[k] + p2[k] - p3[k] + p4[k];
            top[j + half] = p5[k] + p3[k];
            bot[j] = p6[k] + p2[k];
            bot[j + half] = p5[k] + p1[k] - p6[k] - p7[k];
        }
    }

    /* Every scratch pointer above aliases this one allocation. */
    free(pool);
}
这是我的代码。在作业中,我被要求尽可能快地计算两个 n×n 矩阵的乘积(n 是 2 的幂)。目前这段代码在我的计算机上运行需要 400 毫秒,有什么办法能让它运行得更快吗?(在这段代码中,a_ij = a[i*n+j]。)
我正在使用 Strassen 算法,并尝试优化基础情形(base case),但不知道还能做些什么来让这段代码运行得更快。如果你能提供一些提示,那将会有很大帮助。
这段代码需要更好的内存管理。一个小改进:减少
malloc()
的调用次数。
// int* A = malloc(size);
// ...
// int* H = malloc(size);
size_t size = nsquare * sizeof(int);
int* A = malloc(size * 8);
int* B = A + nsquare;
int* C = B + nsquare;
...
int* H = G + nsquare;
// Be sure to only `free(A)`
如果可能,尽量改进算法的时间复杂度 O()。
介绍其他潜在想法: