samples/atiSamples/DwtHaar1D/DwtHaar1D.rkt
#lang racket
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; ---------------------------------------------------------------------------
;; Module-level state. Everything here is filled in by setup / run and then
;; mutated with set! by the later phases.
;; ---------------------------------------------------------------------------

;; Phase timings as returned by time-real; -1 means "not measured yet".
(define-values (setupTime totalKernelTime) (values -1 -1))

;; OpenCL handles, assigned by setupCL.
(define-values (devices context commandQueue program kernel)
  (values #f #f #f #f #f))

;; Transform parameters.
(define signalLength (expt 2 10))       ; 1024 samples
(define-values (totalLevels kernelWorkGroupSize curSignalLength groupSize levelsDone)
  (values 0 0 0 0 0))

;; Host-side raw buffers (malloc'd in setupDwtHaar1d).
(define-values (inData dOutData dPartialOutData hOutData)
  (values #f #f #f #f))

;; Device buffers (created in setupCL).
(define-values (inDataBuf dOutDataBuf dPartialOutDataBuf)
  (values #f #f #f))

;; Maximum number of levels one kernel launch handles (set in runCLKernels).
(define maxLevelsOnDevice 0)

;; Compute the reference Haar DWT of inData on the host and compare it with
;; the device result.
;;
;; The input is first scaled by 1/sqrt(signalLength). Each pass over the
;; active prefix of length L then writes (a+b)/sqrt(2) into hOutData[0, L/2)
;; and (a-b)/sqrt(2) into hOutData[L/2, L), and halves L until it reaches 1.
;; hOutData is overwritten with the host result.
;;
;; Returns #t when every element of dOutData is within 0.01 of hOutData,
;; #f otherwise. A NaN difference counts as a mismatch.
(define (calApproxFinalOnHost)
  (define size (* signalLength (ctype-sizeof _cl_float)))
  (define norm (sqrt signalLength))   ; hoisted: loop-invariant
  (define sqrt2 (sqrt 2.0))
  (define tempOutData (malloc size 'raw))
  (dynamic-wind
   void
   (lambda ()
     (memcpy tempOutData inData size)
     (for ([i (in-range signalLength)])
       (ptr-set! tempOutData _cl_float i
                 (/ (ptr-ref tempOutData _cl_float i) norm)))
     (let loop ([length signalLength])
       (when (> length 1)
         (let ([half (quotient length 2)])
           (for ([i (in-range half)])
             (let ([data0 (ptr-ref tempOutData _cl_float (* 2 i))]
                   [data1 (ptr-ref tempOutData _cl_float (add1 (* 2 i)))])
               (ptr-set! hOutData _cl_float i (/ (+ data0 data1) sqrt2))
               (ptr-set! hOutData _cl_float (+ i half) (/ (- data0 data1) sqrt2))))
           ;; The next pass reads this pass's output (including the untouched
           ;; detail coefficients beyond `length`).
           (memcpy tempOutData hOutData size)
           (loop half))))
     ;; Written as "<=" so that a NaN difference fails verification instead of
     ;; slipping through a "> 0.01" check.
     (for/and ([i (in-range signalLength)])
       (<= (abs (- (ptr-ref dOutData _cl_float i)
                   (ptr-ref hOutData _cl_float i)))
           0.01)))
   ;; Always release the scratch buffer, even if the comparison raises.
   (lambda () (free tempOutData))))

;; Return k such that length = 2^k.
;;
;; length : exact positive integer that must be a power of two.
;; levels : unused. It is kept so existing two-argument callers keep working.
;;
;; Raises exn:fail:contract if length is not an exact positive power of two.
;; The previous version silently returned 0 in that case, which made the
;; caller skip the whole transform. It also capped k at 23; that cap is gone.
(define (getLevels length levels)
  (unless (and (exact-positive-integer? length)
               (zero? (bitwise-and length (sub1 length))))
    (raise-argument-error 'getLevels "exact positive power of two" length))
  ;; For 2^k, integer-length is k + 1.
  (sub1 (integer-length length)))

;; Allocate the host buffers.
;; inData gets signalLength random values drawn from (random 10) and stored
;; as floats. dOutData, dPartialOutData and hOutData are zero-filled.
(define (setupDwtHaar1d)
  (define nbytes (* signalLength (ctype-sizeof _cl_float)))
  (define (zeroed-block)
    (let ([p (malloc nbytes 'raw)])
      (memset p 0 nbytes)
      p))
  (set! inData (malloc nbytes 'raw))
  (for ([idx signalLength])
    (ptr-set! inData _cl_float idx (exact->inexact (random 10))))
  (set! dOutData (zeroed-block))
  (set! dPartialOutData (zeroed-block))
  (set! hOutData (zeroed-block)))

;; Initialise OpenCL (with a profiling-enabled queue) from DwtHaar1D_Kernels.cl.
;; Then create the device buffers and the dwtHaar1D kernel, and pick the
;; kernel work-group size via optimum-threads (with 256 passed as the
;; preferred size).
;; inDataBuf wraps the host inData pointer (CL_MEM_USE_HOST_PTR).
;; Both output buffers are write-only and signalLength floats long.
(define (setupCL)
  (define nbytes (* signalLength (ctype-sizeof _cl_float)))
  (define (make-output-buffer)
    (clCreateBuffer context 'CL_MEM_WRITE_ONLY nbytes #f))
  (set!-values (devices context commandQueue program)
               (init-cl "DwtHaar1D_Kernels.cl"
                        #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! inDataBuf
        (clCreateBuffer context '(CL_MEM_READ_ONLY CL_MEM_USE_HOST_PTR) nbytes inData))
  (set! dOutDataBuf (make-output-buffer))
  (set! dPartialOutDataBuf (make-output-buffer))
  (set! kernel (clCreateKernel program #"dwtHaar1D"))
  (set! kernelWorkGroupSize
        (optimum-threads kernel (cvector-ref devices 0) 256)))

;; Perform one dwtHaar1D launch.
;; 1. Upload the first curSignalLength floats of inData (blocking write).
;; 2. Bind the eight kernel arguments.
;; 3. Run curSignalLength/2 work-items in groups of groupSize and wait for them.
;; 4. Read both output buffers back in full (signalLength floats each).
(define (runDwtHaar1DKernel)
  (define floatBytes (ctype-sizeof _cl_float))
  (define (no-wait) (make-vector 0))
  (define (download! buf host)
    (clEnqueueReadBuffer commandQueue buf 'CL_TRUE 0
                         (* signalLength floatBytes) host (no-wait)))
  (define globalThreads (vector (arithmetic-shift curSignalLength -1)))
  (define localThreads (vector groupSize))
  (clEnqueueWriteBuffer commandQueue inDataBuf 'CL_TRUE 0
                        (* curSignalLength floatBytes) inData (no-wait))
  ;; Args 0-2: input, output and partial-output buffers.
  (for ([idx (in-naturals)]
        [buf (in-list (list inDataBuf dOutDataBuf dPartialOutDataBuf))])
    (clSetKernelArg:_cl_mem kernel idx buf))
  ;; Arg 3: local scratch sized for two floats per work-item in a group.
  (clSetKernelArg:local kernel 3 (* groupSize 2 floatBytes))
  ;; Args 4-7: scalar level / length parameters.
  (for ([idx (in-naturals 4)]
        [val (in-list (list totalLevels curSignalLength levelsDone maxLevelsOnDevice))])
    (clSetKernelArg:_cl_uint kernel idx val))
  (let ([evt (clEnqueueNDRangeKernel commandQueue kernel 1
                                     globalThreads localThreads (no-wait))])
    (clWaitForEvents (vector evt))
    (clReleaseEvent evt))
  (download! dOutDataBuf dOutData)
  (download! dPartialOutDataBuf dPartialOutData))

;; Run the full multi-level Haar DWT of inData on the device.
;;
;; How the levels are split across launches:
;; - One launch handles at most maxLevelsOnDevice levels.
;; - While more levels remain, the launch's partial output (dPartialOutData)
;;   is copied into inData and becomes the input of the next, shorter launch.
;; - Meanwhile that launch's dOutData is saved in hOutData.
;; - After the last launch, its leading 2^curLevels coefficients are combined
;;   with the coefficients saved in hOutData, leaving the complete result in
;;   dOutData.
;;
;; inData is restored to its original contents afterwards, even if a launch
;; raises.
(define (runCLKernels)
  (define floatBytes (ctype-sizeof _cl_float))
  (define signalBytes (* signalLength floatBytes))
  (define actualLevels (getLevels signalLength 0))
  (unless (>= kernelWorkGroupSize 1)
    (error 'runCLKernels "kernel work-group size must be at least 1, got ~a"
           kernelWorkGroupSize))
  ;; floor(log2 kernelWorkGroupSize) + 1, computed exactly on integers rather
  ;; than through a floating-point log ratio that can round below an integer.
  (set! maxLevelsOnDevice (integer-length (exact-floor kernelWorkGroupSize)))
  (define savedInput (malloc signalBytes 'raw))
  (memcpy savedInput inData signalBytes)
  (set! levelsDone 0)
  (dynamic-wind
   void
   (lambda ()
     ;; Invariant: levels + levelsDone = actualLevels.
     (let loop ([levels actualLevels])
       (when (< levelsDone actualLevels)
         (let* ([curLevels (min levels maxLevelsOnDevice)]
                [head (arithmetic-shift 1 curLevels)])
           ;; The first launch covers the whole signal; later ones cover 2^levels.
           (set! curSignalLength
                 (if (zero? levelsDone) signalLength (arithmetic-shift 1 levels)))
           (set! groupSize (arithmetic-shift head -1))
           (set! totalLevels levels)
           (runDwtHaar1DKernel)
           (if (<= levels maxLevelsOnDevice)
               ;; Final launch: assemble the complete result in dOutData.
               (begin
                 (ptr-set! dOutData _cl_float 0 (ptr-ref dPartialOutData _cl_float 0))
                 (memcpy hOutData dOutData (* head floatBytes))
                 ;; With the _cl_float type, both offsets and the count are in
                 ;; float units. Without it, memcpy treats them as bytes and
                 ;; misaligns the copy.
                 (memcpy dOutData head hOutData head (- signalLength head) _cl_float))
               ;; More levels remain: stash this launch's output and feed the
               ;; partial result back in as the next input.
               (let ([remaining (- levels maxLevelsOnDevice)])
                 (memcpy hOutData dOutData (* curSignalLength floatBytes))
                 (memcpy inData dPartialOutData
                         (* (arithmetic-shift 1 remaining) floatBytes))
                 (set! levelsDone (+ levelsDone maxLevelsOnDevice))
                 (loop remaining)))))))
   (lambda ()
     (memcpy inData savedInput signalBytes)
     (free savedInput))))

;; Allocate host buffers first, because setupCL wraps inData in a device
;; buffer. Then run the timed OpenCL setup and record its duration.
(define (setup)
  (setupDwtHaar1d)
  (let ([elapsed (time-real setupCL)])
    (set! setupTime elapsed)))

;; Time the complete device transform and record the duration.
(define (run)
  (let ([elapsed (time-real runCLKernels)])
    (set! totalKernelTime elapsed)))

;; Compare the device output with the host reference and print the verdict.
(define (verify-results)
  (printf "~n~a~n"
          (cond [(calApproxFinalOnHost) "Passed"]
                [else "Failed"])))

;; Print the signal length, the setup and kernel timings, and their sum.
;; Timings are shown with three digits after the decimal point.
(define (print-stats)
  (define (fmt3 t) (real->decimal-string t 3))
  (printf "~nSignalLength: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          signalLength
          (fmt3 setupTime)
          (fmt3 totalKernelTime)
          (fmt3 (+ setupTime totalKernelTime))))

;; Release the OpenCL objects created in setupCL, then free the raw host
;; buffers allocated in setupDwtHaar1d.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (for-each clReleaseMemObject (list inDataBuf dOutDataBuf dPartialOutDataBuf))
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  (for-each free (list inData dOutData dPartialOutData hOutData)))

;; Script entry point: run the sample's phases in order.
;; Stats are printed last, after resources have been released.
(for ([phase (in-list (list setup run verify-results cleanup print-stats))])
  (phase))