• cublas 矩阵相乘API详解



    #include "cuda_runtime.h"
    #include "device_launch_parameters.h"

    #include <stdio.h>
    #include <stdlib.h>
    #include "cublas_v2.h"

    /*
     * CPU reference for the matrix product C = A * B, printing A, B and C.
     *
     * All three matrices are dense and row-major: element (row, col) of an
     * H x W matrix lives at index row*W + col.
     *
     *   c      : output buffer, must hold aH*bW floats; overwritten with A*B
     *   a      : input matrix A, aH rows by aW columns
     *   b      : input matrix B, bH rows by bW columns
     *   aH, aW : height and width of A
     *   bH, bW : height and width of B
     *
     * The product is only defined when aW == bH. On a mismatch a diagnostic is
     * written to stderr and the function returns without printing anything and
     * without touching c.
     */
    void multiCPU(float *c, float *a, float *b, unsigned int aH, unsigned int aW, unsigned int bH, unsigned int bW)
    {
      /* The inner product below walks aW rows of B; if B has fewer rows than
         that, the reads would run past the end of b. */
      if (aW != bH)
      {
        fprintf(stderr, "multiCPU: dimension mismatch, A is <%u,%u> but B is <%u,%u>\n", aH, aW, bH, bW);
        return;
      }

      /* Dump A row by row. */
      printf(" ");
      printf("matrix A<%2u,%2u> = ",aH,aW);
      for(unsigned int y=0; y<aH; ++y)
      {
        for(unsigned int x=0; x<aW; ++x)
        {
          printf("%8.1f",a[(size_t)y*aW + x]);
        }
        printf(" ");
      }
      printf(" ");

      /* Dump B row by row. */
      printf("matrix B<%2u,%2u> = ",bH,bW);
      for(unsigned int y=0; y<bH; ++y)
      {
        for(unsigned int x=0; x<bW; ++x)
        {
          printf("%8.1f",b[(size_t)y*bW + x]);
        }
        printf(" ");
      }
      printf(" ");

      /* Compute and dump the product: c(y,x) = sum over i of a(y,i) * b(i,x). */
      printf("matrix A*B<%2u,%2u> = ",aH,bW);
      for(unsigned int y=0; y<aH; ++y)
      {
        for(unsigned int x=0; x<bW; ++x)
        {
          float sum = 0.0f;
          for(unsigned int i=0; i<aW; ++i)
          {
            sum += a[(size_t)y*aW + i]*b[(size_t)i*bW + x];
          }
          c[(size_t)y*bW + x] = sum;
          printf("%8.1f",sum);
        }
        printf(" ");
      }
      printf(" ");
    }

    /*
     * Transposes a dense row-major matrix in place and prints the result.
     *
     *   a  : on entry, aH x aW row-major matrix (element (y,x) at y*aW + x);
     *        on return, the aW x aH row-major transpose
     *   aH : number of rows of the input
     *   aW : number of columns of the input
     *
     * A scratch buffer of aH*aW floats is allocated for the shuffle. If that
     * allocation fails a diagnostic is written to stderr and the function
     * returns with a unchanged and nothing printed.
     */
    void trans(float *a, unsigned int aH, unsigned int aW )
    {
      const size_t total = (size_t)aH * aW;

      /* An empty matrix needs no scratch buffer; skipping the allocation also
         avoids mistaking a legitimate NULL from malloc(0) for a failure. */
      if (total > 0)
      {
        float* tr = (float*)malloc(sizeof(float)*total);
        if (tr == NULL)
        {
          fprintf(stderr, "trans: out of memory for %u x %u matrix\n", aH, aW);
          return;
        }

        /* Read the input column by column; writing those values out
           sequentially yields the transpose in row-major order. */
        size_t count = 0;
        for(unsigned int x = 0; x < aW; ++x)
        {
          for(unsigned int y = 0; y < aH; ++y)
          {
            tr[count] = a[(size_t)y*aW + x];
            count++;
          }
        }

        for(size_t i = 0; i < total; i++)
        {
          a[i] = tr[i];
        }
        free(tr);
      }

      /* Print the transposed matrix, which now has aW rows of aH values. */
      for(unsigned int y = 0; y < aW; ++y)
      {
        for(unsigned int x = 0; x < aH; ++x)
        {
          printf("%8.1f",a[(size_t)y*aH + x]);
        }
        printf(" ");
      }
      printf(" ");
    }

    /*
     * Demo: multiply a 3x5 matrix by a 5x4 matrix on the CPU and with
     * cublasSgemm, then print both results.
     *
     * The host arrays are row-major, whereas cuBLAS treats every buffer as
     * column-major. Passing CUBLAS_OP_T for both inputs (with leading dimensions
     * aWidth and bWidth) makes cuBLAS read them as the intended A and B. The
     * output is written column-major with leading dimension aHight, so on the
     * host it reads as the bWidth x aHight transpose of A*B; trans() converts it
     * back to row-major for the final print.
     *
     * Returns 0 on success and 1 if any CUDA or cuBLAS call fails.
     */
    int main()
    {
      const int aHight = 3, aWidth =5;
      const int bHight = 5, bWidth =4;
      float a[aHight*aWidth] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
      float b[bHight*bWidth] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
      float c[aHight*bWidth] = { 0 };
      float c_cuBlas[aHight*bWidth] = { 0 };

      /* Everything that is initialized at its declaration is declared before
         the first `goto Error`: C++ forbids jumping over an initialized
         declaration into the scope where it is live. */
      float *gpu_a = 0;
      float *gpu_b = 0;
      float *gpu_c = 0;
      cublasHandle_t handle = NULL;   /* NULL until cublasCreate succeeds */
      cublasStatus_t ret;
      cudaError_t cudaStatus;
      const float alpha = 1.0f;
      const float beta = 0.0f;
      int exitCode = 1;               /* pessimistic; cleared only on full success */

      /* CPU reference result (also prints A, B and A*B). */
      multiCPU(c, a, b, aHight,aWidth, bHight, bWidth);

      cudaStatus = cudaSetDevice(0);
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaSetDevice failed! Do you have a CUDA-capable GPU installed?");
        goto Error;
      }

      cudaStatus = cudaMalloc((void**)&gpu_a,aHight*aWidth*sizeof(float));
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaMalloc failed!");
        goto Error;
      }

      cudaStatus = cudaMalloc((void**)&gpu_b,bHight*bWidth*sizeof(float));
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaMalloc failed!");
        goto Error;
      }

      cudaStatus = cudaMalloc((void**)&gpu_c,aHight*bWidth*sizeof(float));
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaMalloc failed!");
        goto Error;
      }

      cudaStatus = cudaMemcpy(gpu_a, a, aHight*aWidth*sizeof(float), cudaMemcpyHostToDevice);
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaMemcpy failed!");
        goto Error;
      }

      cudaStatus = cudaMemcpy(gpu_b, b,bHight*bWidth*sizeof(float), cudaMemcpyHostToDevice);
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaMemcpy failed!");
        goto Error;
      }

      //printf("Computing result using CUBLAS... ");

      ret = cublasCreate(&handle);
      if (ret != CUBLAS_STATUS_SUCCESS){
        printf("cublasCreate returned error code %d, line(%d) ", ret, __LINE__);
        handle = NULL;   /* make sure cleanup does not destroy an invalid handle */
        goto Error;
      }

      /* m = aHight, n = bWidth, k = aWidth. With CUBLAS_OP_T, gpu_a is read as a
         column-major aWidth x aHight buffer (lda = aWidth) and transposed, and
         likewise gpu_b (ldb = bWidth). gpu_c is written column-major with
         ldc = aHight. */
      ret = cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_T, aHight, bWidth, aWidth, &alpha, gpu_a, aWidth, gpu_b, bWidth, &beta, gpu_c, aHight);
      if (ret != CUBLAS_STATUS_SUCCESS){
        printf("cublasSgemm returned error code %d, line(%d) ", ret, __LINE__);
        goto Error;
      }

      /* Blocking copy on the default stream; it also waits for the GEMM. */
      cudaStatus = cudaMemcpy(c_cuBlas, gpu_c, aHight*bWidth*sizeof(float), cudaMemcpyDeviceToHost);
      if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaMemcpy failed!");
        goto Error;
      }
      /*
      trans(b,bHight,bWidth);
      trans(a,aHight,aWidth);
      multiCPU(c, b, a, bWidth, bHight, aWidth, aHight);
      */

      /* Raw cuBLAS output: column-major A*B, which read row by row on the host
         is the bWidth x aHight transpose. */
      printf(" cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_T, aHight, bWidth, aWidth, &alpha, gpu_a, aWidth, gpu_b, bWidth, &beta, gpu_c, aHight); ");
      printf("c_cuBlas<%2d,%2d> = ",bWidth,aHight);
      for(int y=0; y<bWidth; ++y)
      {
        for(int x=0; x<aHight ;++x)
        {
          int index = y*aHight + x;
          printf("%8.1f",c_cuBlas[index]);
        }

        printf(" ");
      }
      printf(" ");

      /* Transpose back to row-major aHight x bWidth; trans() prints the result. */
      printf("After trans: c_cuBlas<%2d,%2d> = ",aHight,bWidth);
      trans(c_cuBlas,bWidth,aHight);
      printf(" ");

      exitCode = 0;

    Error:
      /* Shared cleanup for the success and failure paths. cudaFree accepts a
         null pointer, so buffers that were never allocated are harmless here. */
      if (handle != NULL) {
        cublasDestroy(handle);
      }
      cudaFree(gpu_a);
      cudaFree(gpu_b);
      cudaFree(gpu_c);
      return exitCode;
    }

     CUBLAS_OP_T是转置标志,转置的意思是把你原来按行存储的方式改为按照他需要的列方式存储然后进行运算,并不是改变你要运算的维度。

    那什么时候用CUBLAS_OP_N呢??

    ret = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, aHight, bWidth, aWidth, &alpha, gpu_a, aHight, gpu_b, bHight, &beta, gpu_c, aHight);

    如果这样用,CUBLAS_OP_N是不转置标志,就是默认你的存储方式就是按照列来存储的!

    ps:这里aHight = cHight 应该不难理解,因为在cublas里面矩阵都是按列方式存储,所以矩阵C的leading dimension自然是cHight,即aHight。

    按照我的理解,leading dimension(主维度)是指矩阵在内存中相邻两列首元素之间的间隔(以元素个数计);对于按列存储且没有额外填充的矩阵,它等于矩阵的行数。

    用CUBLAS_OP_T 还是CUBLAS_OP_N就是要看你用存储的矩阵是不是符合cublas想要的方式了,如果你是按照行存储的又想做矩阵相乘,那就要用CUBLAS_OP_T。

     
    20140824++++++++++++++++++++++++++++++++++++++++++++++++++++++

    两个按列存储的矩阵:c=a*bT

    如何通过CUBLAS的乘法接口完成?

    将b矩阵设置为按行存储,并将b的列设为bHight,行设置为bWidth,主维度为bHight,即完成了b的转置

    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, aHight, bHight, aWidth, &alpha, gpu_a, aHight, gpu_b, bHight, &beta, gpu_c, aHight);

    输出结果为按列排列(列主序存储):

  • 相关阅读:
    MVC Form
    The way to learn english
    Test FastThree
    C#中Trim()、TrimStart()、TrimEnd()的用法
    c# Dictionary 简介
    visual studio快捷键大全
    ASP.NET MVC 中 ActionResult
    MVC4中使用 Ninject
    MVC Chapter 12 Overview of MVC Projects
    ASP.NET Razor
  • 原文地址:https://www.cnblogs.com/huangshan/p/3917153.html
Copyright © 2020-2023  润新知