#lang racket
;; OpenCL 1-D Haar wavelet decomposition: fills a random signal, runs the
;; dwtHaar1D kernel (possibly in several passes), then compares the device
;; result against a host-side reference transform and prints timings.
(require "../../../c.rkt" "../atiUtils/utils.rkt" ffi/unsafe ffi/cvector ffi/unsafe/cvector)

;; Values returned by time-real for setupCL and runCLKernels (-1 = not run).
(define setupTime -1)
(define totalKernelTime -1)

;; OpenCL handles, filled in by setupCL.
(define devices #f)
(define context #f)
(define commandQueue #f)
(define program #f)
(define kernel #f)

;; Number of samples in the signal. getLevels only recognizes powers of two
;; (up to 2^23); any other length yields 0 levels and no kernel launches.
(define signalLength (arithmetic-shift 1 10))

;; Per-launch parameters: set by setupCL/runCLKernels, read by runDwtHaar1DKernel.
(define totalLevels 0)
(define kernelWorkGroupSize 0)
(define curSignalLength 0)
(define groupSize 0)
(define levelsDone 0)

;; Host buffers of signalLength _cl_floats, allocated in setupDwtHaar1d.
(define inData #f)          ; input signal; reused as the next pass's input
(define dOutData #f)        ; coefficients read back from dOutDataBuf
(define dPartialOutData #f) ; data read back from dPartialOutDataBuf
(define hOutData #f)        ; host scratch: staging in runCLKernels,
                            ; reference result in calApproxFinalOnHost

;; Device buffers, created in setupCL.
(define inDataBuf #f)
(define dOutDataBuf #f)
(define dPartialOutDataBuf #f)

;; Largest number of levels handled by one kernel launch
;; (floor(log2 kernelWorkGroupSize) + 1), set in runCLKernels.
(define maxLevelsOnDevice 0)

;; Computes the Haar decomposition of inData on the host (into hOutData) and
;; compares it element-wise with dOutData.
;; Returns #t when every element differs by at most 0.01, #f otherwise
;; (a NaN difference counts as a mismatch).
(define (calApproxFinalOnHost)
  (define size (* signalLength (ctype-sizeof _cl_float)))
  (define tempOutData (malloc size 'raw))
  (define norm (sqrt signalLength))
  (define root2 (sqrt 2.0))
  (memcpy tempOutData inData size)
  ;; Scale the input by 1/sqrt(signalLength) before transforming.
  (for ([i (in-range signalLength)])
    (ptr-set! tempOutData _cl_float i (/ (ptr-ref tempOutData _cl_float i) norm)))
  ;; Each round turns the first len samples into len/2 averages followed by
  ;; len/2 differences, then recurses on the averages.
  (let loop ([len signalLength])
    (when (> len 1)
      (let ([half (quotient len 2)])
        (for ([i (in-range half)])
          (let ([data0 (ptr-ref tempOutData _cl_float (* 2 i))]
                [data1 (ptr-ref tempOutData _cl_float (add1 (* 2 i)))])
            (ptr-set! hOutData _cl_float i (/ (+ data0 data1) root2))
            (ptr-set! hOutData _cl_float (+ half i) (/ (- data0 data1) root2))))
        (memcpy tempOutData hOutData size)
        (loop half))))
  ;; `<=` (rather than `not >`) so that a NaN from the device fails.
  (define matches?
    (for/and ([i (in-range signalLength)])
      (<= (abs (- (ptr-ref dOutData _cl_float i) (ptr-ref hOutData _cl_float i)))
          0.01)))
  (free tempOutData)
  matches?)

;; Returns i such that len = 2^i for 0 <= i < 24, or 0 if len is not such a
;; power of two. `levels` is unused; it is kept so existing callers still work.
(define (getLevels len levels)
  (or (for/first ([i (in-range 24)]
                  #:when (= len (arithmetic-shift 1 i)))
        i)
      0))

;; Allocates the host buffers: inData gets random integer-valued floats in
;; [0, 10); the output/scratch buffers are zero-filled.
(define (setupDwtHaar1d)
  (define size (* signalLength (ctype-sizeof _cl_float)))
  (define (alloc-zeroed)
    (define p (malloc size 'raw))
    (memset p 0 size)
    p)
  (set! inData (malloc size 'raw))
  (for ([i (in-range signalLength)])
    (ptr-set! inData _cl_float i (exact->inexact (random 10))))
  (set! dOutData (alloc-zeroed))
  (set! dPartialOutData (alloc-zeroed))
  (set! hOutData (alloc-zeroed)))

;; Initializes OpenCL (profiling-enabled queue), creates the three device
;; buffers and the dwtHaar1D kernel, and records the work-group size chosen
;; by optimum-threads (with 256 passed as the requested size).
(define (setupCL)
  (define size (* signalLength (ctype-sizeof _cl_float)))
  (set!-values (devices context commandQueue program)
               (init-cl "DwtHaar1D_Kernels.cl" #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! inDataBuf (clCreateBuffer context '(CL_MEM_READ_ONLY CL_MEM_USE_HOST_PTR) size inData))
  (set! dOutDataBuf (clCreateBuffer context 'CL_MEM_WRITE_ONLY size #f))
  (set! dPartialOutDataBuf (clCreateBuffer context 'CL_MEM_WRITE_ONLY size #f))
  (set! kernel (clCreateKernel program #"dwtHaar1D"))
  (set! kernelWorkGroupSize (optimum-threads kernel (cvector-ref devices 0) 256)))

;; Runs one pass of the dwtHaar1D kernel with the current global parameters:
;; uploads the first curSignalLength samples of inData, launches
;; curSignalLength/2 work-items in groups of groupSize, waits for completion,
;; then performs blocking reads of signalLength floats into dOutData and
;; dPartialOutData.
;; NOTE(review): assumes, like clEnqueueNDRangeKernel here, that the c.rkt
;; write/read bindings return the command's event; those events are released
;; so they are not leaked across passes.
(define (runDwtHaar1DKernel)
  (define float-size (ctype-sizeof _cl_float))
  (define no-wait (make-vector 0))
  (define globalThreads (vector (arithmetic-shift curSignalLength -1)))
  (define localThreads (vector groupSize))
  (clReleaseEvent
   (clEnqueueWriteBuffer commandQueue inDataBuf 'CL_TRUE 0
                         (* curSignalLength float-size) inData no-wait))
  (clSetKernelArg:_cl_mem kernel 0 inDataBuf)
  (clSetKernelArg:_cl_mem kernel 1 dOutDataBuf)
  (clSetKernelArg:_cl_mem kernel 2 dPartialOutDataBuf)
  ;; Local scratch of two floats per work-item in the group.
  (clSetKernelArg:local kernel 3 (* (vector-ref localThreads 0) 2 float-size))
  (clSetKernelArg:_cl_uint kernel 4 totalLevels)
  (clSetKernelArg:_cl_uint kernel 5 curSignalLength)
  (clSetKernelArg:_cl_uint kernel 6 levelsDone)
  (clSetKernelArg:_cl_uint kernel 7 maxLevelsOnDevice)
  (define event
    (clEnqueueNDRangeKernel commandQueue kernel 1 globalThreads localThreads no-wait))
  (clWaitForEvents (vector event))
  (clReleaseEvent event)
  (clReleaseEvent
   (clEnqueueReadBuffer commandQueue dOutDataBuf 'CL_TRUE 0
                        (* signalLength float-size) dOutData no-wait))
  (clReleaseEvent
   (clEnqueueReadBuffer commandQueue dPartialOutDataBuf 'CL_TRUE 0
                        (* signalLength float-size) dPartialOutData no-wait)))

;; Drives the full decomposition. Each launch handles at most
;; maxLevelsOnDevice levels; when more remain, the partial output becomes the
;; next pass's input (copied into inData) and the loop continues on the
;; shorter signal. The final coefficients end up in dOutData. inData is
;; restored to the original signal afterwards so the host check can use it.
(define (runCLKernels)
  (define actualLevels (getLevels signalLength 0))
  ;; Levels still to be processed; shrinks after each intermediate pass.
  (define levels actualLevels)
  (define curLevels 0)
  (unless (>= kernelWorkGroupSize 1)
    (error 'runCLKernels "kernel work-group size must be at least 1, got ~a"
           kernelWorkGroupSize))
  ;; floor(log2 kernelWorkGroupSize) + 1, computed exactly (no float log).
  (set! maxLevelsOnDevice (integer-length (exact-floor kernelWorkGroupSize)))
  (define temp (malloc (* signalLength (ctype-sizeof _cl_float)) 'raw))
  (memcpy temp inData signalLength _cl_float)
  (set! levelsDone 0)
  (let loop ()
    (when (< levelsDone actualLevels)
      (set! curLevels (min levels maxLevelsOnDevice))
      (set! curSignalLength (if (= levelsDone 0) signalLength (arithmetic-shift 1 levels)))
      (set! groupSize (/ (arithmetic-shift 1 curLevels) 2))
      (set! totalLevels levels)
      (runDwtHaar1DKernel)
      (cond
        [(<= levels maxLevelsOnDevice)
         ;; Final pass: the first 2^curLevels entries come from this pass
         ;; (entry 0 from the partial buffer); the rest are the coefficients
         ;; staged in hOutData by the earlier passes.
         (let ([head (arithmetic-shift 1 curLevels)])
           (ptr-set! dOutData _cl_float 0 (ptr-ref dPartialOutData _cl_float 0))
           (memcpy hOutData dOutData head _cl_float)
           ;; Offsets and count are in _cl_float elements; without the type
           ;; argument memcpy would treat the offsets as bytes.
           (memcpy dOutData head hOutData head (- signalLength head) _cl_float))]
        [else
         (set! levels (- levels maxLevelsOnDevice))
         (memcpy hOutData dOutData curSignalLength _cl_float)
         (memcpy inData dPartialOutData (arithmetic-shift 1 levels) _cl_float)
         (set! levelsDone (+ levelsDone maxLevelsOnDevice))
         (loop)])))
  (memcpy inData temp signalLength _cl_float)
  (free temp))

;; Allocates host data, then initializes OpenCL while timing it.
(define (setup)
  (setupDwtHaar1d)
  (set! setupTime (time-real setupCL)))

;; Runs the kernels while timing them.
(define (run)
  (set! totalKernelTime (time-real runCLKernels)))

;; Prints "Passed" or "Failed" according to the host comparison.
(define (verify-results)
  (define verified (calApproxFinalOnHost))
  (printf "~n~a~n" (if verified "Passed" "Failed")))

;; Prints the signal length and the recorded timings to 3 decimal places.
(define (print-stats)
  (printf "~nSignalLength: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          signalLength
          (real->decimal-string setupTime 3)
          (real->decimal-string totalKernelTime 3)
          (real->decimal-string (+ setupTime totalKernelTime) 3)))

;; Releases every OpenCL object and frees the raw host buffers.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (clReleaseMemObject inDataBuf)
  (clReleaseMemObject dOutDataBuf)
  (clReleaseMemObject dPartialOutDataBuf)
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  (free inData)
  (free dOutData)
  (free dPartialOutData)
  (free hOutData))

(setup)
(run)
(verify-results)
(cleanup)
(print-stats)