samples/atiSamples/FastWalshTransform/FastWalshTransform.rkt
#lang racket
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; ---------------------------------------------------------------------------
;; Module-level state shared by the setup / run / verify / cleanup phases.
;; Everything below is mutated with `set!` by the functions that follow.
;; ---------------------------------------------------------------------------

;; Timings as returned by `time-real` (see `setup` and `run`); -1 is the
;; initial placeholder before a measurement has been stored.
(define setupTime -1)
(define totalKernelTime -1)

;; OpenCL objects. `devices`, `context`, `commandQueue` and `program` are the
;; four values returned by `init-cl` in `setupCL`; `kernel` is created there
;; too. All are #f until `setupCL` has run.
(define devices #f)
(define context #f)
(define commandQueue #f)
(define program #f)
(define kernel #f)

;; Number of _cl_float elements in each buffer.
;; NOTE: this definition shadows Racket's built-in `length` inside this module.
;; NOTE(review): the butterfly loops step by doubling, so this presumably has
;; to be a power of two -- confirm before changing it.
(define length 1024)

;; Host-side buffers, allocated with (malloc ... 'raw) in
;; `setupFastWalshTransform` and released with `free` in `cleanup`.
(define input #f)             ; random input data, uploaded to the device
(define output #f)            ; device result read back by `runCLKernels`
(define verificationInput #f) ; copy of `input`, transformed in place on the CPU

;; Device buffer created by `clCreateBuffer` in `setupCL`; `runCLKernels`
;; writes `input` into it and later reads the result back out of it.
(define inputBuffer #f)

;; setupFastWalshTransform : -> void
;; Allocates the three host buffers (`input`, `output`, `verificationInput`),
;; each holding `length` _cl_float values, fills `input` via
;; `fill-random:_cl_float`, and copies it into `verificationInput` so the CPU
;; reference starts from the same data.
;; The buffers are allocated in 'raw mode, i.e. outside the garbage collector,
;; so they must be released explicitly with `free` (done in `cleanup`).
(define (setupFastWalshTransform)
  (define byte-count (* length (ctype-sizeof _cl_float)))
  ;; Every host buffer has the same size and allocation mode.
  (define (alloc-raw) (malloc byte-count 'raw))
  (set! input (alloc-raw))
  (set! output (alloc-raw))
  ;; NOTE(review): 255 is presumably the upper bound of the random values --
  ;; confirm against fill-random:_cl_float in atiUtils/utils.rkt.
  (fill-random:_cl_float input length 255)
  (set! verificationInput (alloc-raw))
  (memcpy verificationInput input byte-count))

;; setupCL : -> void
;; Initialises the OpenCL side of the sample:
;;   1. `init-cl` on "FastWalshTransform_Kernels.cl" (queue created with
;;      'CL_QUEUE_PROFILING_ENABLE) yields devices, context, command queue
;;      and program, which are stored in the module-level variables;
;;   2. a read/write device buffer big enough for `length` _cl_float values
;;      is created;
;;   3. the "fastWalshTransform" kernel object is created from the program.
(define (setupCL)
  (let-values ([(devs ctx queue prog)
                (init-cl "FastWalshTransform_Kernels.cl"
                         #:queueProperties 'CL_QUEUE_PROFILING_ENABLE)])
    (set! devices devs)
    (set! context ctx)
    (set! commandQueue queue)
    (set! program prog))
  (set! inputBuffer
        (clCreateBuffer context
                        'CL_MEM_READ_WRITE
                        (* length (ctype-sizeof _cl_float))
                        #f))
  (set! kernel (clCreateKernel program #"fastWalshTransform")))

;; runCLKernels : -> result of the final clEnqueueReadBuffer call
;; Uploads `input` into `inputBuffer`, launches the kernel once per butterfly
;; stage (step = 1, 2, 4, ... while step < length), waiting for each launch to
;; finish before enqueuing the next, and finally reads `inputBuffer` back
;; into `output`.
;; NOTE(review): 'CL_TRUE is presumably the blocking flag, as in the OpenCL C
;; API -- confirm against the binding in c.rkt. The values returned by the
;; write/read enqueue calls are not released here; if the binding returns
;; event objects, verify whether they need clReleaseEvent.
(define (runCLKernels)
  (define byte-count (* length (ctype-sizeof _cl_float)))
  (clEnqueueWriteBuffer commandQueue inputBuffer 'CL_TRUE 0 byte-count input
                        (make-vector 0))
  ;; Global work size is length/2 (one work-item per pair of elements).
  (define global-size (/ length 2))
  (define local-size (optimum-threads kernel (cvector-ref devices 0) 256))
  ;; Argument 0 (the data buffer) is constant across stages; only argument 1
  ;; (the current step) changes per launch.
  (clSetKernelArg:_cl_mem kernel 0 inputBuffer)
  (do ([step 1 (arithmetic-shift step 1)])
      ((>= step length))
    (clSetKernelArg:_cl_int kernel 1 step)
    (let ([launched (clEnqueueNDRangeKernel commandQueue
                                            kernel
                                            1
                                            (vector global-size)
                                            (vector local-size)
                                            (make-vector 0))])
      ;; Each stage reads what the previous one wrote, so wait for the launch
      ;; to complete before moving on, then drop the event.
      (clWaitForEvents (vector launched))
      (clReleaseEvent launched)))
  (clEnqueueReadBuffer commandQueue inputBuffer 'CL_TRUE 0 byte-count output
                       (make-vector 0)))

;; fastWalshTransformCPUReference : -> result of `compare`
;; Runs the reference transform on the CPU, in place on `verificationInput`,
;; then compares the device result in `output` against it over `length`
;; elements. The return value is whatever `compare` returns; `verify-results`
;; treats a true value as a pass.
;;
;; For each stage (step = 1, 2, 4, ... while step < length) every pair of
;; elements (i, i+step) is replaced by (a+b, a-b), where i ranges over
;; group, group+2*step, group+4*step, ... for each group in [0, step).
(define (fastWalshTransformCPUReference)
  ;; Replace the elements at indices lo and hi by their sum and difference.
  (define (butterfly! lo hi)
    (let ([a (ptr-ref verificationInput _cl_float lo)]
          [b (ptr-ref verificationInput _cl_float hi)])
      (ptr-set! verificationInput _cl_float lo (+ a b))
      (ptr-set! verificationInput _cl_float hi (- a b))))
  (do ([step 1 (arithmetic-shift step 1)])
      ((>= step length))
    (let ([stride (arithmetic-shift step 1)])
      (for* ([group (in-range step)]
             [lo (in-range group length stride)])
        (butterfly! lo (+ lo step)))))
  ;; Runs once, after all stages have completed.
  (compare output verificationInput length))

;; setup : -> void
;; Prepares host data, then initialises OpenCL. Only the OpenCL part
;; (`setupCL`) is timed; its `time-real` result is stored in `setupTime`.
(define (setup)
  (setupFastWalshTransform)
  (let ([elapsed (time-real setupCL)])
    (set! setupTime elapsed)))

;; run : -> void
;; Executes `runCLKernels` under `time-real` and records the measurement in
;; `totalKernelTime`.
(define (run)
  (let ([elapsed (time-real runCLKernels)])
    (set! totalKernelTime elapsed)))

;; verify-results : -> void
;; Runs the CPU reference comparison and prints "Passed" when it yields a
;; true value, "Failed" otherwise (surrounded by newlines).
(define (verify-results)
  (printf "~n~a~n"
          (cond
            [(fastWalshTransformCPUReference) "Passed"]
            [else "Failed"])))

;; print-stats : -> void
;; Prints the element count plus setup, kernel and combined timings, each
;; timing rendered with three decimal places.
(define (print-stats)
  ;; Render a timing value with a fixed three-digit fraction.
  (define (fmt t) (real->decimal-string t 3))
  (printf "~nLength: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          length
          (fmt setupTime)
          (fmt totalKernelTime)
          (fmt (+ setupTime totalKernelTime))))

;; cleanup : -> void
;; Releases the OpenCL objects (kernel, program, device buffer, command
;; queue, context -- in that order) and then frees the three raw host
;; buffers allocated by `setupFastWalshTransform`.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (clReleaseMemObject inputBuffer)
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  ;; Host buffers were malloc'd in 'raw mode, so they need an explicit free.
  (for-each free (list input output verificationInput)))

;; Driver: set up, run the kernels, verify against the CPU reference, release
;; resources, then report timings.
;;
;; `run` and `verify-results` execute inside `dynamic-wind` so that `cleanup`
;; still runs if either of them raises; otherwise the OpenCL objects and the
;; raw (non-GC) host buffers would never be released. `setup` stays outside
;; the guard because `cleanup` expects everything to have been created.
;; `print-stats` comes after `cleanup` and only reads `length`, `setupTime`
;; and `totalKernelTime`, none of which are touched by `cleanup`.
(setup)
(dynamic-wind
 void
 (lambda ()
   (run)
   (verify-results))
 cleanup)
(print-stats)