samples/atiSamples/BitonicSort/BitonicSort.rkt
#lang racket
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; Module-level mutable state shared by the setup / run / verify / cleanup steps.

;; Timings as returned by `time-real`; -1 until the corresponding step has run.
(define setupTime -1)
(define totalKernelTime -1)
;; OpenCL handles; all four are filled in together from `init-cl` in `setupCL`.
(define devices #f)
(define context #f)
(define commandQueue #f)
(define program #f)
;; Number of _cl_uint elements to sort.
;; NOTE: this binding shadows Racket's built-in `length` for the rest of the module.
(define length 1024)
;; Raw (malloc 'raw) host array of `length` _cl_uint values; released in `cleanup`.
(define input #f)
;; Device buffer created over `input` in `setupCL`.
(define inputBuffer #f)
;; Handle for the "bitonicSort" kernel created in `setupCL`.
(define kernel #f)
;; Sort direction flag passed to the kernel as argument 4;
;; the CPU reference treats 0 as descending and any other value as ascending.
(define sortAscending 1)

;; Allocates the host array `input` (`length` _cl_uint elements, raw malloc),
;; zeroes its first element, hands it to `fill-random:_cl_uint` and prints it.
;; The raw allocation is not garbage-collected; `cleanup` frees it.
(define (setupBitonicSort)
  (let ([byteCount (* length (ctype-sizeof _cl_uint))])
    (set! input (malloc byteCount 'raw)))
  ;; Element 0 is explicitly set to 0 before the array is filled.
  (ptr-set! input _cl_uint 0 0)
  ;; presumably fills the array with random values -- helper lives in utils.rkt
  (fill-random:_cl_uint input length)
  (print-array "Input" input length))

;; Initializes OpenCL: obtains devices/context/queue/program from `init-cl`
;; (queue created with profiling enabled), wraps the host array `input` in a
;; read-write device buffer that uses the host pointer, and creates the
;; "bitonicSort" kernel. Expects `input` to have been allocated already.
(define (setupCL)
  (define-values (devs ctx queue prog)
    (init-cl "BitonicSort_Kernels.cl" #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! devices devs)
  (set! context ctx)
  (set! commandQueue queue)
  (set! program prog)
  (let ([bufferBytes (* (ctype-sizeof _cl_uint) length)])
    (set! inputBuffer
          (clCreateBuffer context
                          '(CL_MEM_READ_WRITE CL_MEM_USE_HOST_PTR)
                          bufferBytes
                          input)))
  (set! kernel (clCreateKernel program #"bitonicSort")))

;; Runs the bitonic sort kernel over `inputBuffer` and reads the result back
;; into the host array `input`.
;;
;; The kernel is launched once per (stage, pass) pair: for each of the
;; `numStages` stages, passes 0..stage are enqueued, and each launch is waited
;; on before the next one is issued. Kernel arguments:
;;   0 = data buffer, 1 = stage, 2 = pass, 3 = element count, 4 = direction flag.
(define (runCLKernels)
  (define device (cvector-ref devices 0))
  ;; One work-item per pair of elements. `quotient` keeps this an exact integer
  ;; (plain `/` would yield a fraction for an odd `length`).
  (define globalThreads (quotient length 2))
  (define localThreads (optimum-threads kernel device 256))
  ;; Number of integer halvings needed to bring `length` down to 1.
  (define numStages
    (let loop ([remaining length] [stages 0])
      (if (> remaining 1)
          (loop (quotient remaining 2) (add1 stages))
          stages)))
  ;; Arguments that stay constant across all launches.
  (clSetKernelArg:_cl_mem kernel 0 inputBuffer)
  (clSetKernelArg:_cl_uint kernel 3 length)
  (clSetKernelArg:_cl_uint kernel 4 sortAscending)
  (for ([stage (in-range numStages)])
    (clSetKernelArg:_cl_uint kernel 1 stage)
    (for ([pass (in-range (add1 stage))])
      (clSetKernelArg:_cl_uint kernel 2 pass)
      (define event (clEnqueueNDRangeKernel commandQueue kernel 1 (vector globalThreads) (vector localThreads) (make-vector 0)))
      ;; Each pass depends on the previous one, so wait before continuing.
      (clWaitForEvents (vector event))
      (clReleaseEvent event)))
  ;; Blocking read ('CL_TRUE): the data is in `input` once the call returns.
  ;; The event handed back by the enqueue call still has to be released,
  ;; just like the kernel events above, or it is leaked.
  (let ([readEvent (clEnqueueReadBuffer commandQueue inputBuffer 'CL_TRUE 0 (* length (ctype-sizeof _cl_uint)) input (make-vector 0))])
    (clReleaseEvent readEvent)))

;; Compare-and-swap on the _cl_uint array `in`: if the element at index `i`
;; is greater than the element at index `j`, the two are exchanged in place;
;; otherwise the array is left untouched.
(define (swapIfFirstIsGreater in i j)
  (let ([first (ptr-ref in _cl_uint i)]
        [second (ptr-ref in _cl_uint j)])
    (unless (<= first second)
      (ptr-set! in _cl_uint i second)
      (ptr-set! in _cl_uint j first))))

;; CPU reference implementation of the bitonic sort.
;;
;; Sorts `verificationInput` (a _cl_uint array of `length` elements) in place,
;; then compares it element by element with the global `input` array.
;; Returns #t when every element matches, #f otherwise.
;;
;; Outer loop `i`: size of the sub-sequences being merged (2, 4, ..., length).
;; Inner loop `j`: compare distance doubles as block size (i, i/2, ..., 2).
;; For intermediate stages (i < length) the sort direction alternates between
;; blocks so that neighbouring blocks form bitonic sequences for the next stage.
(define (bitonicSortCPUReference verificationInput)
  (define halfLength (quotient length 2))
  (define ascending? (not (= sortAscending 0)))
  (let outer ([i 2])
    (when (<= i length)
      (let inner ([j i])
        (when (> j 1)
          (let ([increasing ascending?]
                [halfJ (quotient j 2)])
            (for ([k (in-range 0 length j)])
              ;; Flip direction at k = i, and at every other multiple of i
              ;; except the midpoint of the array. The `k = i` case must win
              ;; even when i equals length/2; otherwise both halves end up
              ;; sorted in the same direction and the final merge does not
              ;; receive a bitonic sequence.
              (when (and (< i length)
                         (or (= k i)
                             (and (= (remainder k i) 0)
                                  (not (= k halfLength)))))
                (set! increasing (not increasing)))
              (for ([l (in-range k (+ k halfJ))])
                (if increasing
                    (swapIfFirstIsGreater verificationInput l (+ l halfJ))
                    (swapIfFirstIsGreater verificationInput (+ l halfJ) l)))))
          (inner (quotient j 2))))
      (outer (* i 2))))
  ;; Compare the CPU result with the array produced by the device run.
  (for/and ([idx (in-range length)])
    (= (ptr-ref input _cl_uint idx)
       (ptr-ref verificationInput _cl_uint idx))))


;; Copy of `input` taken in `setup`, before the kernel overwrites `input` with
;; the sorted result. Consumed (and freed) by `verify-results`.
(define unsortedInputCopy #f)

;; Generates the input data, snapshots it for later verification, and
;; initializes OpenCL, recording the time `setupCL` took in `setupTime`.
(define (setup)
  (setupBitonicSort)
  (let ([sizeBytes (* length (ctype-sizeof _cl_uint))])
    ;; Release a snapshot left over from an earlier `setup` call, if any.
    (when unsortedInputCopy (free unsortedInputCopy))
    (set! unsortedInputCopy (malloc sizeBytes 'raw))
    (memcpy unsortedInputCopy input sizeBytes))
  (set! setupTime (time-real setupCL)))

;; Runs the kernels, recording the elapsed time in `totalKernelTime`, then
;; prints the (now sorted) contents of `input`.
(define (run)
  (set! totalKernelTime (time-real runCLKernels))
  ;; NOTE(review): called with one more argument than the "Input" print in
  ;; setupBitonicSort -- confirm against print-array's signature in utils.rkt.
  (print-array "Output" input length length))

;; Sorts the snapshot of the original input on the CPU, compares it with the
;; device result in `input`, and prints "Passed" or "Failed".
;; If no snapshot exists (`setup` was not used), falls back to copying the
;; current contents of `input`, which then only checks that the data is sorted.
(define (verify-results)
  (define sizeBytes (* length (ctype-sizeof _cl_uint)))
  (define verificationInput
    (or unsortedInputCopy
        (let ([copy (malloc sizeBytes 'raw)])
          (memcpy copy input sizeBytes)
          copy)))
  (define verified (bitonicSortCPUReference verificationInput))
  ;; The snapshot has been sorted in place and is about to be freed.
  (set! unsortedInputCopy #f)
  (printf "~n~a~n" (if verified "Passed" "Failed"))
  (free verificationInput))

;; Releases every OpenCL object created during setup, in the order
;; kernel, program, memory buffer, command queue, context, and then frees
;; the raw host array `input`.
(define (cleanup)
  (for-each (lambda (release! handle) (release! handle))
            (list clReleaseKernel
                  clReleaseProgram
                  clReleaseMemObject
                  clReleaseCommandQueue
                  clReleaseContext)
            (list kernel
                  program
                  inputBuffer
                  commandQueue
                  context))
  (free input))

;; Prints a one-line summary: element count, setup time, kernel time and
;; their sum, each time formatted with three decimal places.
(define (print-stats)
  (define (three-decimals t) (real->decimal-string t 3))
  (let ([setupText (three-decimals setupTime)]
        [kernelText (three-decimals totalKernelTime)]
        [totalText (three-decimals (+ setupTime totalKernelTime))])
    (printf "~nLength: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
            length
            setupText
            kernelText
            totalText)))

;; Script entry point: the steps run once, in this order, when the module is
;; instantiated.
(setup)           ; generate input data and initialize OpenCL
(run)             ; execute the kernels and print the output array
(verify-results)  ; compare against the CPU reference and print Passed/Failed
(cleanup)         ; release OpenCL objects and the host array
(print-stats)     ; report the timings gathered above