#lang racket
;; Fast Walsh transform sample.
;;
;; Runs the "fastWalshTransform" OpenCL kernel over a buffer of random
;; floats, repeats the same butterfly passes on the CPU, compares the two
;; results and prints "Passed"/"Failed" followed by timing statistics.
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; Timings recorded by `setup` and `run`; both stay at -1 until measured.
(define setupTime -1)
(define totalKernelTime -1)

;; OpenCL handles; all are #f until `setupCL` fills them in.
(define devices #f)
(define context #f)
(define commandQueue #f)
(define program #f)
(define kernel #f)

;; Number of float elements that are transformed.
;; NOTE: at module level this shadows racket's list `length`.
(define length 1024)

;; Host-side buffers (raw `malloc` memory, released in `cleanup`).
(define input #f)             ; data uploaded to the device
(define output #f)            ; result read back from the device
(define verificationInput #f) ; copy of `input`, transformed in place on the CPU

;; Device-side buffer; the kernel passes are run against it and the result
;; is read back from it.
(define inputBuffer #f)

;; Allocates the three host buffers, fills `input` via `fill-random:_cl_float`
;; (presumably random floats bounded by 255 -- confirm in utils.rkt) and
;; copies it into `verificationInput` for the CPU reference run.
(define (setupFastWalshTransform)
  (define byte-count (* length (ctype-sizeof _cl_float)))
  (set! input (malloc byte-count 'raw))
  (set! output (malloc byte-count 'raw))
  (fill-random:_cl_float input length 255)
  (set! verificationInput (malloc byte-count 'raw))
  (memcpy verificationInput input byte-count))

;; Initialises OpenCL through `init-cl` (profiling-enabled queue), then
;; creates the device buffer and the "fastWalshTransform" kernel object.
(define (setupCL)
  (define size (* length (ctype-sizeof _cl_float)))
  (set!-values (devices context commandQueue program)
               (init-cl "FastWalshTransform_Kernels.cl"
                        #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! inputBuffer (clCreateBuffer context 'CL_MEM_READ_WRITE size #f))
  (set! kernel (clCreateKernel program #"fastWalshTransform")))

;; Uploads `input`, launches the kernel once per step (step = 1, 2, 4, ...
;; while step < length) and reads the device buffer back into `output`.
;;
;; Every enqueue call hands back an event object that the caller owns, so
;; each one is released here; previously the events of the write and the
;; read were dropped without being released.
(define (runCLKernels)
  (define size (* length (ctype-sizeof _cl_float)))

  ;; Blocking write ('CL_TRUE): the transfer has finished when this returns,
  ;; so the event can be released straight away.
  (define write-event
    (clEnqueueWriteBuffer commandQueue inputBuffer 'CL_TRUE 0 size input
                          (make-vector 0)))
  (clReleaseEvent write-event)

  ;; One work-item per butterfly, i.e. per pair of elements.
  (define globalThreads (/ length 2))
  ;; `optimum-threads` comes from utils.rkt; presumably it caps the requested
  ;; 256 by what the kernel/device supports -- confirm there.
  (define localThreads (optimum-threads kernel (cvector-ref devices 0) 256))

  (clSetKernelArg:_cl_mem kernel 0 inputBuffer)

  (let loop ([step 1])
    (when (< step length)
      (clSetKernelArg:_cl_int kernel 1 step)
      (define pass-event
        (clEnqueueNDRangeKernel commandQueue kernel 1
                                (vector globalThreads)
                                (vector localThreads)
                                (make-vector 0)))
      ;; Wait for this pass before the next one is launched.
      (clWaitForEvents (vector pass-event))
      (clReleaseEvent pass-event)
      (loop (arithmetic-shift step 1))))

  ;; Blocking read: `output` is fully populated once the call returns.
  (define read-event
    (clEnqueueReadBuffer commandQueue inputBuffer 'CL_TRUE 0 size output
                         (make-vector 0)))
  (clReleaseEvent read-event))

;; Performs the same butterfly passes in place on `verificationInput` and
;; returns the result of `compare` between `output` and the CPU result
;; (used by `verify-results` as a pass/fail flag).
(define (fastWalshTransformCPUReference)
  (let loop ([step 1])
    (when (< step length)
      (define jump (arithmetic-shift step 1))
      (for* ([group (in-range step)]
             [pair (in-range group length jump)])
        ;; Index of the element combined with `pair` (named `partner` so it
        ;; does not shadow racket's `match` form).
        (define partner (+ pair step))
        (define T1 (ptr-ref verificationInput _cl_float pair))
        (define T2 (ptr-ref verificationInput _cl_float partner))
        (ptr-set! verificationInput _cl_float pair (+ T1 T2))
        (ptr-set! verificationInput _cl_float partner (- T1 T2)))
      (loop jump)))
  (compare output verificationInput length))

;; Prepares host data, then stores what `time-real` reports for `setupCL`
;; in `setupTime`.
(define (setup)
  (setupFastWalshTransform)
  (set! setupTime (time-real setupCL)))

;; Stores what `time-real` reports for `runCLKernels` in `totalKernelTime`.
(define (run)
  (set! totalKernelTime (time-real runCLKernels)))

;; Prints "Passed" or "Failed" depending on the CPU reference comparison.
(define (verify-results)
  (define verified (fastWalshTransformCPUReference))
  (printf "~n~a~n" (if verified "Passed" "Failed")))

;; Prints the element count and the recorded timings with three decimals.
(define (print-stats)
  (printf "~nLength: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          length
          (real->decimal-string setupTime 3)
          (real->decimal-string totalKernelTime 3)
          (real->decimal-string (+ setupTime totalKernelTime) 3)))

;; Releases all OpenCL objects and frees the raw host buffers.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (clReleaseMemObject inputBuffer)
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  (free input)
  (free output)
  (free verificationInput))

;; Script entry: the statistics only read the recorded numbers, so they can
;; be printed after the resources have been released.
(setup)
(run)
(verify-results)
(cleanup)
(print-stats)