samples/atiSamples/MatrixTranspose/MatrixTranspose.rkt
#lang racket
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; Timing results, filled in by `setup` and `run` from `time-real`;
;; -1 is the "not measured yet" sentinel.
(define-values (setupTime totalKernelTime) (values -1 -1))

;; OpenCL handles; #f until `setupCL` creates them.
(define-values (devices context commandQueue program kernel)
  (values #f #f #f #f #f))

;; Matrix dimensions (in elements) and the work-group tile edge.
;; NOTE: `blockSize` is mutated by `runCLKernels` when the device cannot
;; accommodate a blockSize x blockSize work-group.
(define-values (width height blockSize) (values 64 64 16))

;; Host-side matrices (raw malloc'd pointers), allocated by `setupMatrixTranspose`.
(define-values (input output verificationOutput) (values #f #f #f))

;; Device buffers wrapping `input` / `output`, created by `setupCL`.
(define-values (inputBuffer outputBuffer) (values #f #f))

;; Allocates the three width x height host matrices and fills `input` with
;; random values via `fill-random:_cl_float` (the 255 argument is passed
;; through unchanged; presumably an upper bound — confirm in atiUtils/utils.rkt).
;; 'raw allocations are not moved by the GC, which matters because setupCL
;; passes `input`/`output` to clCreateBuffer with CL_MEM_USE_HOST_PTR.
(define (setupMatrixTranspose)
  (define elementCount (* width height))
  (define matrixBytes (* elementCount (ctype-sizeof _cl_float)))
  (define (allocate-matrix) (malloc matrixBytes 'raw))
  (set! input (allocate-matrix))
  (fill-random:_cl_float input elementCount 255)
  (set! output (allocate-matrix))
  (set! verificationOutput (allocate-matrix)))

;; Initializes OpenCL from "MatrixTranspose_Kernels.cl" (with a profiling-enabled
;; command queue), wraps the host matrices in device buffers, and creates the
;; `matrixTranspose` kernel. Stores all handles in the module-level variables.
(define (setupCL)
  (define bufferBytes (* width height (ctype-sizeof _cl_float)))
  (define-values (devs ctx queue prog)
    (init-cl "MatrixTranspose_Kernels.cl"
             #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! devices devs)
  (set! context ctx)
  (set! commandQueue queue)
  (set! program prog)
  ;; CL_MEM_USE_HOST_PTR: the buffers use the host allocations as their storage.
  (set! inputBuffer
        (clCreateBuffer context '(CL_MEM_READ_ONLY CL_MEM_USE_HOST_PTR) bufferBytes input))
  (set! outputBuffer
        (clCreateBuffer context '(CL_MEM_WRITE_ONLY CL_MEM_USE_HOST_PTR) bufferBytes output))
  (set! kernel (clCreateKernel program #"matrixTranspose")))

;; Runs the `matrixTranspose` kernel over a width x height global range with
;; blockSize x blockSize work-groups, waits for it, then reads the result back
;; into `output` with a blocking read.
;; Side effect: sets the module-level `blockSize` to 4 when optimum-threads
;; reports fewer work-items than blockSize^2.
;; Every cl_event obtained here (kernel and read) is released to avoid leaks.
(define (runCLKernels)
  (define globalThreads (vector width height))
  (define threadNum
    (optimum-threads kernel (cvector-ref devices 0) (* blockSize blockSize)))
  ;; Fall back to 4x4 tiles if the preferred tile is too large for the device.
  (when (> (* blockSize blockSize) threadNum)
    (set! blockSize 4))
  (define localThreads (vector blockSize blockSize))
  (clSetKernelArg:_cl_mem kernel 0 outputBuffer)
  (clSetKernelArg:_cl_mem kernel 1 inputBuffer)
  ;; __local scratch tile: one float per work-item in the group.
  (clSetKernelArg:local kernel 2 (* blockSize blockSize (ctype-sizeof _cl_float)))
  (clSetKernelArg:_cl_int kernel 3 width)
  (clSetKernelArg:_cl_int kernel 4 height)
  (clSetKernelArg:_cl_int kernel 5 blockSize)
  (define kernelEvent
    (clEnqueueNDRangeKernel commandQueue kernel 2 globalThreads localThreads (make-vector 0)))
  (clWaitForEvents (vector kernelEvent))
  (clReleaseEvent kernelEvent)
  ;; The read is blocking ('CL_TRUE), so its event is already complete here,
  ;; but the binding still hands it back and it must be released.
  (define readEvent
    (clEnqueueReadBuffer commandQueue outputBuffer 'CL_TRUE 0
                         (* width height (ctype-sizeof _cl_float)) output (make-vector 0)))
  (clReleaseEvent readEvent))

;; Computes the transpose of `input` on the CPU into `verificationOutput` and
;; returns the result of `compare` on `output` vs. that reference.
;; Layout: `input` is row-major with `height` rows of `width` elements, so
;; element (row r, col c) is at r*width + c. The transpose has `width` rows of
;; `height` elements, so that element lands at c*height + r.
;; (The previous bounds swapped width/height and were only correct for square
;; matrices; for non-square ones they indexed out of bounds.)
(define (matrixTransposeCPUReference)
  (for* ([row (in-range height)]
         [col (in-range width)])
    (ptr-set! verificationOutput _cl_float (+ row (* col height))
              (ptr-ref input _cl_float (+ col (* row width)))))
  (compare output verificationOutput (* width height)))

;; Allocates host data, then initializes OpenCL; only the OpenCL part is timed
;; (result stored in `setupTime`). Host data must exist first because setupCL
;; wraps `input`/`output` in device buffers.
(define (setup)
  (setupMatrixTranspose)
  (define elapsed (time-real setupCL))
  (set! setupTime elapsed))

;; Times one full kernel run (including the result read-back) into `totalKernelTime`.
(define (run)
  (define elapsed (time-real runCLKernels))
  (set! totalKernelTime elapsed))

;; Checks the device result against the CPU reference and prints Passed/Failed.
(define (verify-results)
  (printf "~n~a~n"
          (cond [(matrixTransposeCPUReference) "Passed"]
                [else "Failed"])))

;; Prints matrix dimensions and timings (3 decimal places).
;; The WxH label is derived from `width`/`height` rather than hard-coded, so it
;; stays correct if the dimensions change.
(define (print-stats)
  (printf "~nWxH: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          (format "~ax~a" width height)
          (real->decimal-string setupTime 3)
          (real->decimal-string totalKernelTime 3)
          (real->decimal-string (+ setupTime totalKernelTime) 3)))

;; Releases OpenCL objects (same order as before: kernel, program, buffers,
;; queue, context), then frees the raw host matrices.
(define (cleanup)
  (for ([release (in-list (list clReleaseKernel
                                clReleaseProgram
                                clReleaseMemObject
                                clReleaseMemObject
                                clReleaseCommandQueue
                                clReleaseContext))]
        [handle (in-list (list kernel program inputBuffer outputBuffer
                               commandQueue context))])
    (release handle))
  (for-each free (list input output verificationOutput)))

;; Driver: each step runs in sequence; stats are printed after cleanup since
;; they only read the recorded timings.
(for ([step (in-list (list setup run verify-results cleanup print-stats))])
  (step))