Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <wb.h>
// Evaluates a CUDA runtime call once and, if it did not return cudaSuccess,
// logs the failing statement and the CUDA error string, then makes the
// ENCLOSING function `return -1`. It can therefore only be used inside
// functions whose return type accepts -1 (such as main).
//
// The argument is parenthesized and the temporary has a macro-private name so
// that a call such as `wbCheck(err)` does not expand to `err = err`.
#define wbCheck(stmt)                                                        \
  do {                                                                       \
    cudaError_t wbCheckErr_ = (stmt);                                        \
    if (wbCheckErr_ != cudaSuccess) {                                        \
      wbLog(ERROR, "Failed to run stmt ", #stmt);                            \
      wbLog(ERROR, "Got CUDA error ... ", cudaGetErrorString(wbCheckErr_));  \
      return -1;                                                             \
    }                                                                        \
  } while (0)
// Sparse matrix-vector product for a matrix stored in JDS layout.
//
// Launch layout: 1D grid of 1D blocks, one thread per (permuted) matrix row;
// threads whose global index is >= dim do nothing. No shared memory is used.
//
// Parameters (all pointers are dereferenced on the device):
//   out         - result vector; element matRowPerm[row] receives row's sum
//   matColStart - matColStart[i] is the offset of the i-th slice in
//                 matCols / matData
//   matCols     - column index of each stored element
//   matRowPerm  - maps the thread's row index to the output index
//   matRows     - matRows[row] is the number of stored elements for that row
//   matData     - stored element values
//   vec         - dense input vector, indexed by column
//   dim         - number of rows to process
__global__ void spmvJDSKernel(float *out, int *matColStart, int *matCols,
                              int *matRowPerm, int *matRows,
                              float *matData, float *vec, int dim) {
  const unsigned int row = blockIdx.x * blockDim.x + threadIdx.x;

  // Guard the grid tail; the dim <= 0 test keeps the unsigned comparison safe.
  if (dim <= 0 || row >= static_cast<unsigned int>(dim)) {
    return;
  }

  // Read the row length once instead of re-loading it on every iteration.
  const int rowLength = matRows[row];

  float sum = 0.0f;
  for (int i = 0; i < rowLength; i++) {
    // Consecutive threads (rows) read consecutive elements of the slice.
    const int idx = matColStart[i] + row;
    // The value must stay a float: storing it in an integer would truncate
    // fractions and mangle negative entries.
    const float val = matData[idx];
    sum += val * vec[matCols[idx]];
  }

  // No __syncthreads() is needed: each thread only touches its own row and
  // writes a distinct output element, so threads never exchange data.
  out[matRowPerm[row]] = sum;
}
// Host-side launcher for spmvJDSKernel.
//
// All pointer arguments are forwarded unchanged to the kernel, so they must
// be device-accessible. One thread is launched per row, in blocks of 512
// threads. The launch is asynchronous and this function does not consume the
// CUDA error state: the caller should call cudaGetLastError() and then
// synchronize to observe launch or execution failures.
static void spmvJDS(float *out, int *matColStart, int *matCols,
                    int *matRowPerm, int *matRows, float *matData,
                    float *vec, int dim) {
  // A grid of zero blocks is an invalid launch configuration, and there is
  // nothing to compute for an empty matrix anyway.
  if (dim <= 0) {
    return;
  }

  const unsigned int threadsPerBlock = 512;
  // Integer ceil-division; avoids the round trip through double precision.
  const unsigned int numBlocks =
      (static_cast<unsigned int>(dim) + threadsPerBlock - 1) / threadsPerBlock;

  //@@ invoke spmv kernel for jds format
  spmvJDSKernel<<<numBlocks, threadsPerBlock>>>(out, matColStart, matCols,
                                                matRowPerm, matRows, matData,
                                                vec, dim);
}
// Program entry point: imports a CSR matrix and a dense vector, converts the
// matrix to JDS with CSRToJDS, multiplies it with the vector on the GPU via
// spmvJDS, and hands the result to wbSolution.
//
// Input files (by index in the argument list):
//   0 - CSR column indices, 1 - CSR row array, 2 - CSR values, 3 - vector.
//
// Returns 0 on success and -1 if the inputs are unusable or any CUDA call
// fails (wbCheck returns from main immediately; buffers are then reclaimed
// by process exit rather than freed explicitly).
int main(int argc, char **argv) {
  wbArg_t args;

  // Host-side CSR input.
  int *hostCSRCols;
  int *hostCSRRows;
  float *hostCSRData;

  // Host-side JDS representation produced by CSRToJDS.
  int *hostJDSColStart;
  int *hostJDSCols;
  int *hostJDSRowPerm;
  int *hostJDSRows;
  float *hostJDSData;

  float *hostVector;
  float *hostOutput;

  // Device-side copies.
  int *deviceJDSColStart;
  int *deviceJDSCols;
  int *deviceJDSRowPerm;
  int *deviceJDSRows;
  float *deviceJDSData;
  float *deviceVector;
  float *deviceOutput;

  int dim, ncols, nrows, ndata;
  int maxRowNNZ;

  args = wbArg_read(argc, argv);

  // Import data and create memory on host
  wbTime_start(Generic, "Importing data and creating memory on host");
  hostCSRCols = (int *)wbImport(wbArg_getInputFile(args, 0), &ncols, "Integer");
  hostCSRRows = (int *)wbImport(wbArg_getInputFile(args, 1), &nrows, "Integer");
  hostCSRData = (float *)wbImport(wbArg_getInputFile(args, 2), &ndata, "Real");
  hostVector = (float *)wbImport(wbArg_getInputFile(args, 3), &dim, "Real");

  // Reject missing or empty input before it is dereferenced: the code below
  // reads hostJDSRows[0], which requires at least one row.
  if (hostCSRCols == NULL || hostCSRRows == NULL || hostCSRData == NULL ||
      hostVector == NULL || dim <= 0) {
    wbLog(ERROR, "Failed to import input data, or the input vector is empty");
    return -1;
  }

  hostOutput = (float *)malloc(sizeof(float) * dim);
  if (hostOutput == NULL) {
    wbLog(ERROR, "Failed to allocate host output buffer");
    return -1;
  }
  wbTime_stop(Generic, "Importing data and creating memory on host");

  // converting from CSR to JDS
  CSRToJDS(dim, hostCSRRows, hostCSRCols, hostCSRData, &hostJDSRowPerm,
           &hostJDSRows, &hostJDSColStart, &hostJDSCols, &hostJDSData);
  // The first JDS row is taken as the longest one; its length is used as the
  // number of entries in the column-start array.
  maxRowNNZ = hostJDSRows[0];

  // Allocate GPU memory.
  wbCheck(cudaMalloc((void **)&deviceJDSColStart, sizeof(int) * maxRowNNZ));
  wbCheck(cudaMalloc((void **)&deviceJDSCols, sizeof(int) * ndata));
  wbCheck(cudaMalloc((void **)&deviceJDSRowPerm, sizeof(int) * dim));
  wbCheck(cudaMalloc((void **)&deviceJDSRows, sizeof(int) * dim));
  wbCheck(cudaMalloc((void **)&deviceJDSData, sizeof(float) * ndata));
  wbCheck(cudaMalloc((void **)&deviceVector, sizeof(float) * dim));
  wbCheck(cudaMalloc((void **)&deviceOutput, sizeof(float) * dim));

  // Copy input memory to the GPU.
  wbCheck(cudaMemcpy(deviceJDSColStart, hostJDSColStart,
                     sizeof(int) * maxRowNNZ, cudaMemcpyHostToDevice));
  wbCheck(cudaMemcpy(deviceJDSCols, hostJDSCols, sizeof(int) * ndata,
                     cudaMemcpyHostToDevice));
  wbCheck(cudaMemcpy(deviceJDSRowPerm, hostJDSRowPerm, sizeof(int) * dim,
                     cudaMemcpyHostToDevice));
  wbCheck(cudaMemcpy(deviceJDSRows, hostJDSRows, sizeof(int) * dim,
                     cudaMemcpyHostToDevice));
  wbCheck(cudaMemcpy(deviceJDSData, hostJDSData, sizeof(float) * ndata,
                     cudaMemcpyHostToDevice));
  wbCheck(cudaMemcpy(deviceVector, hostVector, sizeof(float) * dim,
                     cudaMemcpyHostToDevice));

  // Perform CUDA computation
  spmvJDS(deviceOutput, deviceJDSColStart, deviceJDSCols, deviceJDSRowPerm,
          deviceJDSRows, deviceJDSData, deviceVector, dim);
  // A kernel launch does not return an error itself: pick up launch-time
  // failures first, then execution failures at the synchronization point.
  wbCheck(cudaGetLastError());
  wbCheck(cudaDeviceSynchronize());

  // Copy output memory to the CPU
  wbCheck(cudaMemcpy(hostOutput, deviceOutput, sizeof(float) * dim,
                     cudaMemcpyDeviceToHost));

  // Free GPU Memory
  wbCheck(cudaFree(deviceVector));
  wbCheck(cudaFree(deviceOutput));
  wbCheck(cudaFree(deviceJDSColStart));
  wbCheck(cudaFree(deviceJDSCols));
  wbCheck(cudaFree(deviceJDSRowPerm));
  wbCheck(cudaFree(deviceJDSRows));
  wbCheck(cudaFree(deviceJDSData));

  wbSolution(args, hostOutput, dim);

  free(hostCSRCols);
  free(hostCSRRows);
  free(hostCSRData);
  free(hostVector);
  free(hostOutput);
  free(hostJDSColStart);
  free(hostJDSCols);
  free(hostJDSRowPerm);
  free(hostJDSRows);
  free(hostJDSData);

  return 0;
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement