▶ 各种稀疏矩阵数据结构下 y(n,1) = A(n,m) * x(m,1) 的实现,CPU版本
● MAT 乘法
1 int dotCPU(const MAT *a, const MAT *x, MAT *y) 2 { 3 checkNULL(a); checkNULL(x); checkNULL(y); 4 if (a->col != x->row) 5 { 6 printf("dotMATCPU dimension mismatch! "); 7 return 1; 8 } 9 10 y->row = a->row; 11 y->col = x->col; 12 for (int i = 0; i < a->row; i++) 13 { 14 format sum = 0; 15 for (int j = 0; j < a->col; j++) 16 sum += a->data[i * a->col + j] * x->data[j]; 17 y->data[i] = sum; 18 } 19 COUNT_MAT(y); 20 return 0; 21 }
● CSR 乘法
1 int dotCPU(const CSR *a, const MAT *x, MAT *y) 2 { 3 checkNULL(a); checkNULL(x); checkNULL(y); 4 if (a->col != x->row) 5 { 6 printf("dotCSRCPU dimension mismatch! "); 7 return 1; 8 } 9 10 y->row = a->row; 11 y->col = x->col; 12 for (int i = 0; i < a->row; i++) // i 遍历 ptr,j 遍历行内数据,A 中为 0 的元素不参加乘法 13 { 14 format sum = 0; 15 for (int j = a->ptr[i]; j < a->ptr[i + 1]; j++) 16 sum += a->data[j] * x->data[a->index[j]]; 17 y->data[i] = sum; 18 } 19 COUNT_MAT(y); 20 return 0; 21 }
● ELL 乘法
1 int dotCPU(const ELL *a, const MAT *x, MAT *y) // CPU ELL乘法 2 { 3 checkNULL(a); checkNULL(x); checkNULL(y); 4 if (a->colOrigin != x->row) 5 { 6 printf("dotELLCPU dimension mismatch! "); 7 return 1; 8 } 9 10 y->row = a->col; 11 y->col = x->col; 12 for (int i = 0; i<a->col; i++) 13 { 14 format sum = 0; 15 for (int j = 0; j < a->row; j++) 16 { 17 int temp = a->index[j * a->col + i]; 18 if (temp < 0) // 跳过无效元素 19 continue; 20 sum += a->data[j * a->col + i] * x->data[temp]; 21 } 22 y->data[i] = sum; 23 } 24 COUNT_MAT(y); 25 return 0; 26 }
● COO 乘法
1 int dotCPU(const COO *a, const MAT *x, MAT *y) 2 { 3 checkNULL(a); checkNULL(x); checkNULL(y); 4 if (a->col != x->row) 5 { 6 printf("dotCOOCPU null! "); 7 return 1; 8 } 9 10 y->row = a->row; 11 y->col = x->col; 12 for (int i = 0; i<a->count; i++) 13 y->data[a->rowIndex[i]] += a->data[i] * x->data[a->colIndex[i]]; 14 COUNT_MAT(y); 15 return 0; 16 }
● DIA 乘法
1 int dotCPU(const DIA *a, const MAT *x, MAT *y) 2 { 3 checkNULL(a); checkNULL(x); checkNULL(y); 4 if (a->colOrigin != x->row) 5 { 6 printf("dotDIACPU null! "); 7 return 1; 8 } 9 y->row = a->row; 10 y->col = x->col; 11 int * inverseIndex = (int *)malloc(sizeof(int) * a->col); 12 for (int i = 0, j = 0; i < a->row + a->col - 1; i++) 13 { 14 if (a->index[i] == 1) 15 { 16 inverseIndex[j] = i; 17 j++; 18 } 19 } 20 for (int i = 0; i < a->row; i++) 21 { 22 format sum = 0; 23 for (int j = 0; j < a->col; j++) 24 { 25 if (i < a->row - 1 - inverseIndex[j] || i > inverseIndex[a->col - 1] - inverseIndex[j]) 26 continue; 27 sum += a->data[i * a->col + j] * x->data[i + inverseIndex[j] - a->row + 1]; 28 } 29 y->data[i] = sum; 30 } 31 COUNT_MAT(y); 32 free(inverseIndex); 33 return 0; 34 }