#lang racket
;; OpenCL matrix-multiplication sample: multiplies two randomly filled
;; single-precision (_cl_float) matrices on an OpenCL device with the
;; `mmmKernel` kernel, verifies the device result against a CPU reference,
;; and prints kernel / setup timing statistics.

(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         racket/runtime-path
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; Elapsed times as returned by `time-real` (project helper from utils.rkt;
;; units not visible here). -1 means "not measured yet".
(define setupTime -1)
(define totalKernelTime -1)

;; OpenCL handles, populated by setupCL and released by cleanup.
(define devices #f)
(define context #f)
(define commandQueue #f)
(define program #f)
(define kernel #f)

;; Matrix dimensions and work-group edge length.
;; NOTE: runCLKernels may lower blockSize to 4 if the device cannot run
;; blockSize*blockSize work-items per group for this kernel.
(define width 64)
(define height 64)
(define blockSize 8)

;; Host buffers, allocated with (malloc ... 'raw) and freed in cleanup.
(define input0 #f)
(define input1 #f)
(define output #f)
(define verificationOutput #f)

;; Device buffers, created in setupCL and released in cleanup.
(define inputBuffer0 #f)
(define inputBuffer1 #f)
(define outputBuffer #f)

;; Allocates the host matrices, fills both inputs via fill-random:_cl_float
;; (project helper), and zeroes the CPU-reference output buffer.
(define (setupMatrixMultiplication)
  (define inputSizeBytes (* width height (ctype-sizeof _cl_float)))
  (set! input0 (malloc inputSizeBytes 'raw))
  (set! input1 (malloc inputSizeBytes 'raw))
  (fill-random:_cl_float input0 (* width height))
  (fill-random:_cl_float input1 (* width height))
  (set! output (malloc inputSizeBytes 'raw))
  (set! verificationOutput (malloc inputSizeBytes 'raw))
  (memset verificationOutput 0 inputSizeBytes))

(define-runtime-path kernel-path "MatrixMultiplication_Kernels.cl")

;; Initializes OpenCL through init-cl (project helper) with a profiling-enabled
;; queue, uploads both input matrices with blocking writes, creates the output
;; buffer, and looks up the `mmmKernel` kernel.
(define (setupCL)
  (define size (* width height (ctype-sizeof _cl_float)))
  (set!-values (devices context commandQueue program)
               (init-cl kernel-path #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! inputBuffer0 (clCreateBuffer context 'CL_MEM_READ_ONLY size #f))
  (clEnqueueWriteBuffer commandQueue inputBuffer0 'CL_TRUE 0 size input0 (make-vector 0))
  (set! inputBuffer1 (clCreateBuffer context 'CL_MEM_READ_ONLY size #f))
  (clEnqueueWriteBuffer commandQueue inputBuffer1 'CL_TRUE 0 size input1 (make-vector 0))
  (set! outputBuffer (clCreateBuffer context 'CL_MEM_WRITE_ONLY size #f))
  (set! kernel (clCreateKernel program #"mmmKernel")))

;; Launches mmmKernel over a 2-D range, waits for it, prints the profiled
;; kernel time and achieved GFlops, then reads the result into `output`.
(define (runCLKernels)
  ;; One work-item per 4x4 block of the result -- presumably the kernel works
  ;; on float4 tiles; confirm against MatrixMultiplication_Kernels.cl.
  (define globalThreads (vector (/ width 4) (/ height 4)))
  (define localThreads (vector blockSize blockSize))
  ;; Fall back to 4x4 work-groups when optimum-threads (project helper)
  ;; reports fewer usable work-items than blockSize*blockSize.
  (define threadNum
    (optimum-threads kernel (cvector-ref devices 0) (* blockSize blockSize)))
  (when (> (* blockSize blockSize) threadNum)
    (set! blockSize 4)
    (set! localThreads (vector blockSize blockSize)))
  (clSetKernelArg:_cl_mem kernel 0 inputBuffer0)
  (clSetKernelArg:_cl_mem kernel 1 inputBuffer1)
  (clSetKernelArg:_cl_mem kernel 2 outputBuffer)
  ;; Args 3 and 4 are both `width`; presumably the widths of A and B, which
  ;; coincide for these square matrices -- TODO confirm kernel signature.
  (clSetKernelArg:_cl_int kernel 3 width)
  (clSetKernelArg:_cl_int kernel 4 width)
  (define event
    (clEnqueueNDRangeKernel commandQueue kernel 2 globalThreads localThreads (make-vector 0)))
  (clWaitForEvents (vector event))
  (clFlush commandQueue)
  ;; OpenCL profiling counters are in nanoseconds.
  (define startTime (clGetEventProfilingInfo:generic event 'CL_PROFILING_COMMAND_START))
  (define endTime (clGetEventProfilingInfo:generic event 'CL_PROFILING_COMMAND_END))
  (clReleaseEvent event)
  (define sec (* 1e-9 (- endTime startTime)))
  (printf "KernelTime (ms) : ~a~n" (* sec 1000))
  ;; Total work is 2 * width * width * height floating-point operations.
  (define flops (* 2 width width))
  (define perf (* (/ flops sec) height 1e-9))
  (printf "GFlops achieved : ~a~n~n" perf)
  (clEnqueueReadBuffer commandQueue outputBuffer 'CL_TRUE 0
                       (* width height (ctype-sizeof _cl_float))
                       output (make-vector 0)))

;; Computes verificationOutput = input0 x input1 on the host (row-major,
;; A is height x width, B is width x width) and returns the result of
;; comparing it against the device `output` via `compare` (project helper).
;; Each cell is summed in a local accumulator and stored once, so the result
;; does not depend on the buffer's previous contents.
(define (matrixMultiplicationCPUReference)
  (for* ([i (in-range height)]
         [j (in-range width)])
    (define cell
      (for/fold ([acc 0.0])
                ([k (in-range width)])
        (+ acc (* (ptr-ref input0 _cl_float (+ k (* i width)))
                  (ptr-ref input1 _cl_float (+ j (* k width)))))))
    (ptr-set! verificationOutput _cl_float (+ j (* i width)) cell))
  (compare output verificationOutput (* width height)))

;; Host-side setup, then timed OpenCL setup.
(define (setup)
  (setupMatrixMultiplication)
  (set! setupTime (time-real setupCL)))

;; Timed kernel execution.
(define (run)
  (set! totalKernelTime (time-real runCLKernels)))

;; Prints "Passed" or "Failed" depending on the CPU-reference comparison.
(define (verify-results)
  (define verified (matrixMultiplicationCPUReference))
  (printf "~n~a~n" (if verified "Passed" "Failed")))

;; Prints matrix dimensions and the recorded setup / kernel / total times.
(define (print-stats)
  (printf "~nMatrixA: ~a, MatrixB: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          (format "~ax~a" width height)
          (format "~ax~a" width width)
          (real->decimal-string setupTime 3)
          (real->decimal-string totalKernelTime 3)
          (real->decimal-string (+ setupTime totalKernelTime) 3)))

;; Releases every OpenCL object and frees the raw host buffers.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (clReleaseMemObject inputBuffer0)
  (clReleaseMemObject inputBuffer1)
  (clReleaseMemObject outputBuffer)
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  (free input0)
  (free input1)
  (free output)
  (free verificationOutput))

(setup)
(run)
(verify-results)
(cleanup)
(print-stats)