samples/atiSamples/MatrixMultiplication/MatrixMultiplication.rkt
#lang racket
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         racket/runtime-path
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; ---------------------------------------------------------------------------
;; Module-level state shared by the setup / run / verify / cleanup phases.
;; Each binding is filled in later with set! by the phase that owns it.
;; ---------------------------------------------------------------------------

;; Timings recorded through time-real by setup and run; -1 until measured.
(define-values (setupTime totalKernelTime)
  (values -1 -1))

;; OpenCL handles; #f until setupCL initializes them.
(define-values (devices context commandQueue program kernel)
  (values #f #f #f #f #f))

;; Matrix dimensions and work-group edge length. runCLKernels lowers
;; blockSize to 4 when optimum-threads reports fewer than blockSize*blockSize
;; usable threads.
(define-values (width height blockSize)
  (values 64 64 8))

;; Host-side raw allocations made by setupMatrixMultiplication; #f until then.
(define-values (input0 input1 output verificationOutput)
  (values #f #f #f #f))

;; Device buffers created by setupCL; #f until then.
(define-values (inputBuffer0 inputBuffer1 outputBuffer)
  (values #f #f #f))

;; Allocate the four width*height host matrices of _cl_float with raw malloc.
;; The two inputs are filled via fill-random:_cl_float. `output` is not
;; initialized here because runCLKernels reads the device result into it.
;; `verificationOutput` is zeroed because the CPU reference accumulates into it.
(define (setupMatrixMultiplication)
  (define elementCount (* width height))
  (define byteCount (* elementCount (ctype-sizeof _cl_float)))
  (define (alloc-matrix) (malloc byteCount 'raw))
  (set! input0 (alloc-matrix))
  (set! input1 (alloc-matrix))
  (fill-random:_cl_float input0 elementCount)
  (fill-random:_cl_float input1 elementCount)
  (set! output (alloc-matrix))
  (set! verificationOutput (alloc-matrix))
  (memset verificationOutput 0 byteCount))

(define-runtime-path kernel-path "MatrixMultiplication_Kernels.cl")

;; Initialize OpenCL and stage the input data on the device:
;;  - init-cl (atiUtils) is given kernel-path and asked for a profiling-enabled
;;    queue. Profiling is required because runCLKernels reads
;;    CL_PROFILING_COMMAND_START/END from the kernel event.
;;  - input0/input1 are uploaded into read-only device buffers.
;;  - a write-only output buffer is created.
;;  - the "mmmKernel" kernel object is created from the program.
;; Must run after setupMatrixMultiplication, which allocates input0/input1.
(define (setupCL)
  (define size (* width height (ctype-sizeof _cl_float)))
  (set!-values (devices context commandQueue program)
               (init-cl kernel-path #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  ;; Writes are blocking (CL_TRUE) with an empty wait list, so the host data
  ;; has been copied by the time each call returns.
  (set! inputBuffer0 (clCreateBuffer context 'CL_MEM_READ_ONLY size #f))
  (clEnqueueWriteBuffer commandQueue inputBuffer0 'CL_TRUE 0 size input0 (make-vector 0))
  (set! inputBuffer1 (clCreateBuffer context 'CL_MEM_READ_ONLY size #f))
  (clEnqueueWriteBuffer commandQueue inputBuffer1 'CL_TRUE 0 size input1 (make-vector 0))
  (set! outputBuffer (clCreateBuffer context 'CL_MEM_WRITE_ONLY size #f))
  (set! kernel (clCreateKernel program #"mmmKernel")))

;; Launch mmmKernel once over a (width/4) x (height/4) NDRange.
;; Prints the profiled kernel time in ms and the achieved GFlops, then does a
;; blocking read of the device result into `output`. Returns whatever
;; clEnqueueReadBuffer returns.
;;
;; Side effect: if optimum-threads reports fewer usable threads than
;; blockSize*blockSize, the module-level blockSize is lowered to 4 and a 4x4
;; work-group is used.
(define (runCLKernels)
  (define globalThreads (vector (/ width 4) (/ height 4)))
  (define requested (* blockSize blockSize))
  (define threadNum (optimum-threads kernel (cvector-ref devices 0) requested))
  (when (> requested threadNum)
    (set! blockSize 4))
  (define localThreads (vector blockSize blockSize))
  ;; Arguments 0..2 are the three device buffers, in order.
  (for ([buffer (in-list (list inputBuffer0 inputBuffer1 outputBuffer))]
        [slot (in-naturals)])
    (clSetKernelArg:_cl_mem kernel slot buffer))
  ;; Arguments 3 and 4 both receive width. They are presumably the widths of
  ;; A and B, which are equal for these square matrices. Confirm against
  ;; MatrixMultiplication_Kernels.cl.
  (clSetKernelArg:_cl_int kernel 3 width)
  (clSetKernelArg:_cl_int kernel 4 width)
  (define event
    (clEnqueueNDRangeKernel commandQueue kernel 2 globalThreads localThreads (make-vector 0)))
  (clWaitForEvents (vector event))
  (clFlush commandQueue)
  ;; Profiling timestamps (nanoseconds per the OpenCL spec); read START then END.
  (define-values (startTime endTime)
    (values (clGetEventProfilingInfo:generic event 'CL_PROFILING_COMMAND_START)
            (clGetEventProfilingInfo:generic event 'CL_PROFILING_COMMAND_END)))
  (clReleaseEvent event)
  (define sec (* 1e-9 (- endTime startTime)))
  (printf "KernelTime (ms) : ~a~n" (* sec 1000))
  ;; 2*width*width flops per output row, times height rows, scaled to giga.
  (printf "GFlops achieved : ~a~n~n" (* (/ (* 2 width width) sec) height 1e-9))
  (clEnqueueReadBuffer commandQueue outputBuffer 'CL_TRUE 0
                       (* width height (ctype-sizeof _cl_float))
                       output (make-vector 0)))

;; CPU reference for the width x width product input0 * input1 (row-major).
;; The result is accumulated into verificationOutput, which
;; setupMatrixMultiplication zeroes. Each partial sum is written back with
;; ptr-set! as _cl_float before the next term is added, so intermediate sums
;; are rounded to that C type.
;; Returns the result of `compare` against the device output. verify-results
;; uses it as a truthy/falsy value.
(define (matrixMultiplicationCPUReference)
  (define (elem mat idx) (ptr-ref mat _cl_float idx))
  (for* ([row (in-range width)]
         [col (in-range width)]
         [k (in-range width)])
    (define dst (+ col (* row width)))
    (ptr-set! verificationOutput _cl_float dst
              (+ (elem verificationOutput dst)
                 (* (elem input0 (+ k (* row width)))
                    (elem input1 (+ col (* k width)))))))
  (compare output verificationOutput (* width height)))

;; Prepare host data, then initialize OpenCL. Only the OpenCL initialization
;; is timed; its time-real result is stored in setupTime.
(define (setup)
  (setupMatrixMultiplication)
  (define elapsed (time-real setupCL))
  (set! setupTime elapsed))

;; Time a single runCLKernels call and record the result in totalKernelTime.
(define run
  (lambda ()
    (set! totalKernelTime (time-real runCLKernels))))

;; Run the CPU reference comparison and print "Passed" or "Failed" on its own line.
(define (verify-results)
  (define status
    (if (matrixMultiplicationCPUReference) "Passed" "Failed"))
  (printf "~n~a~n" status))

;; Print a one-line summary with the matrix dimensions and the setup, kernel
;; and total times recorded by setup/run. Times are printed with 3 decimals,
;; in whatever unit time-real returns.
;; Dimensions are derived from height/width rather than hard-coded, so the
;; line stays correct if the sizes change.
(define (print-stats)
  (define dims (format "~ax~a" height width))
  (printf "~nMatrixA: ~a, MatrixB: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          dims
          dims
          (real->decimal-string setupTime 3)
          (real->decimal-string totalKernelTime 3)
          (real->decimal-string (+ setupTime totalKernelTime) 3)))

;; Release the OpenCL objects (kernel, program, the three device buffers,
;; queue, context) in that order. Then free the four raw host allocations
;; made by setupMatrixMultiplication.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (for ([mem (in-list (list inputBuffer0 inputBuffer1 outputBuffer))])
    (clReleaseMemObject mem))
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  (for ([host-ptr (in-list (list input0 input1 output verificationOutput))])
    (free host-ptr)))

;; Run the sample end to end, one phase after another. print-stats runs after
;; cleanup; it only reads the recorded timings.
(for ([phase (in-list (list setup run verify-results cleanup print-stats))])
  (phase))