samples/atiSamples/FloydWarshall/FloydWarshall.rkt
#lang racket
(require "../../../c.rkt"
         "../atiUtils/utils.rkt"
         ffi/unsafe
         ffi/cvector
         ffi/unsafe/cvector)

;; ---- Module-level state shared by the phases of this sample ----

;; Timings stored by `setup` and `run`; both hold -1 until those have run.
(define-values (setupTime totalKernelTime) (values -1 -1))

;; OpenCL handles; #f until `setupCL` assigns them.
(define-values (devices context commandQueue program kernel)
  (values #f #f #f #f #f))

;; Problem size. The matrices are square, so both dimensions start out
;; equal to the node count.
(define numNodes 256)
(define-values (height width) (values numNodes numNodes))

;; Host-side matrices (raw pointers once `setupFloydWarshall` has run):
;; the pair handed to the OpenCL buffers, and the pair used by the CPU
;; reference implementation.
(define-values (pathDistanceMatrix pathMatrix) (values #f #f))
(define-values (verificationPathDistanceMatrix verificationPathMatrix)
  (values #f #f))

;; OpenCL buffer objects created in `setupCL`.
(define-values (pathDistanceBuffer pathBuffer) (values #f #f))

;; Allocates and initializes the four host-side matrices.
;;
;; - Resets `height` and `width` to `numNodes`.
;; - Fills the distance matrix through `fill-random:_cl_uint` (called with
;;   the cell count and 200) and then zeroes its diagonal.
;; - Initializes the path matrix so that cell (row, col) with col < row
;;   holds `row`, its mirror cell (col, row) holds `col`, and each diagonal
;;   cell holds its own index.
;; - Copies both matrices into the verification matrices used by the CPU
;;   reference.
;;
;; All four blocks are allocated with (malloc ... 'raw) and are released by
;; `cleanup`.
(define (setupFloydWarshall)
  (set! height numNodes)
  (set! width numNodes)
  (define cellCount (* width height))
  (define byteCount (* cellCount (ctype-sizeof _cl_uint)))

  ;; Distance matrix: random entries, zero on the diagonal.
  (set! pathDistanceMatrix (malloc byteCount 'raw))
  (set! pathMatrix (malloc byteCount 'raw))
  (fill-random:_cl_uint pathDistanceMatrix cellCount 200)
  (for ([node (in-range height)])
    (ptr-set! pathDistanceMatrix _cl_uint (+ (* node width) node) 0))

  ;; Path matrix: diagonal first, then the cells below it together with
  ;; their mirrors above it.
  (for ([row (in-range height)])
    (define rowBase (* row width))
    (ptr-set! pathMatrix _cl_uint (+ rowBase row) row)
    (for ([col (in-range row)])
      (ptr-set! pathMatrix _cl_uint (+ rowBase col) row)
      (ptr-set! pathMatrix _cl_uint (+ row (* col width)) col)))

  ;; Snapshot the initial state for the CPU reference run.
  (set! verificationPathDistanceMatrix (malloc byteCount 'raw))
  (set! verificationPathMatrix (malloc byteCount 'raw))
  (memcpy verificationPathDistanceMatrix pathDistanceMatrix byteCount)
  (memcpy verificationPathMatrix pathMatrix byteCount))

;; Creates the OpenCL objects used by the sample and stores them in the
;; module-level variables: devices, context, command queue and program
;; (all returned by `init-cl`), one buffer per host matrix, and the
;; "floydWarshallPass" kernel.
;;
;; Both buffers are created read/write with CL_MEM_USE_HOST_PTR over the
;; host matrices allocated by `setupFloydWarshall`, so that function must
;; have run first.
(define (setupCL)
  (define bufferBytes (* width height (ctype-sizeof _cl_uint)))
  (define bufferFlags '(CL_MEM_READ_WRITE CL_MEM_USE_HOST_PTR))
  ;; Creates a buffer of `bufferBytes` bytes over the host pointer.
  (define (buffer-over hostPtr)
    (clCreateBuffer context bufferFlags bufferBytes hostPtr))

  (define-values (devs ctx queue prog)
    (init-cl "FloydWarshall_Kernels.cl"
             #:queueProperties 'CL_QUEUE_PROFILING_ENABLE))
  (set! devices devs)
  (set! context ctx)
  (set! commandQueue queue)
  (set! program prog)

  (set! pathDistanceBuffer (buffer-over pathDistanceMatrix))
  (set! pathBuffer (buffer-over pathMatrix))
  (set! kernel (clCreateKernel program #"floydWarshallPass")))

;; Runs the Floyd-Warshall kernel once per pass and reads the results back
;; into the host matrices.
;;
;; The work-group size comes from `optimum-threads` (defined in utils.rkt;
;; presumably it clamps the requested size to what the kernel/device
;; supports -- confirm there). `width`, `height` and `numNodes` are all
;; overwritten with that value, so the rest of the run -- including the CPU
;; reference -- uses it as the matrix dimension.
;;
;; NOTE(review): if clEnqueueReadBuffer returns an event the way
;; clEnqueueNDRangeKernel does, those two events are never released here --
;; confirm against c.rkt.
(define (runCLKernels)
  (define groupSize (optimum-threads kernel (cvector-ref devices 0) width))
  (set!-values (width height numNodes) (values groupSize groupSize groupSize))
  (define totalWorkItems (* height width))
  (define resultBytes (* width height (ctype-sizeof _cl_uint)))
  ;; Blocking read of one device buffer into its host matrix.
  (define (read-back buffer hostPtr)
    (clEnqueueReadBuffer commandQueue buffer 'CL_TRUE 0 resultBytes hostPtr
                         (make-vector 0)))

  ;; Arguments 0-2 stay fixed for every pass; argument 3 is the pass index.
  (clSetKernelArg:_cl_mem kernel 0 pathDistanceBuffer)
  (clSetKernelArg:_cl_mem kernel 1 pathBuffer)
  (clSetKernelArg:_cl_uint kernel 2 width)
  (for ([pass (in-range width)])
    (clSetKernelArg:_cl_uint kernel 3 pass)
    (let ([done (clEnqueueNDRangeKernel commandQueue kernel 1
                                        (vector totalWorkItems)
                                        (vector groupSize)
                                        (make-vector 0))])
      ;; Each pass is waited on before the next one is enqueued.
      (clWaitForEvents (vector done))
      (clReleaseEvent done)))

  (read-back pathBuffer pathMatrix)
  (read-back pathDistanceBuffer pathDistanceMatrix))

;; Runs Floyd-Warshall on the CPU over the verification matrices and
;; compares the resulting distances with the ones read back from the device.
;;
;; For every intermediate node k and every pair (y, x), the distance y -> x
;; is replaced by the distance through k when that is strictly shorter, and
;; k is recorded in the verification path matrix for that cell.
;;
;; Returns #t when the first (* height numNodes) entries of
;; `verificationPathDistanceMatrix` equal those of `pathDistanceMatrix`,
;; #f otherwise. Only distances are compared; the path matrices are not.
(define (floydWarshallCPUReference)
  (define n numNodes)
  ;; Reads one entry of the verification distance matrix.
  (define (distance-at index)
    (ptr-ref verificationPathDistanceMatrix _cl_uint index))

  (for* ([k (in-range n)]
         [y (in-range n)])
    (define rowBase (* y n))
    (define kRowBase (* k n))
    (for ([x (in-range n)])
      (define cell (+ rowBase x))
      (define viaK (+ (distance-at (+ rowBase k))
                      (distance-at (+ kRowBase x))))
      ;; A single conditional store is enough: when the route through k is
      ;; not shorter the cell already holds the minimum, so re-storing
      ;; (min direct viaK) afterwards would never change memory.
      (when (< viaK (distance-at cell))
        (ptr-set! verificationPathDistanceMatrix _cl_uint cell viaK)
        (ptr-set! verificationPathMatrix _cl_uint cell k))))

  ;; Stops at the first mismatch; yields #t when every entry matches.
  (for/and ([i (in-range (* height n))])
    (= (distance-at i)
       (ptr-ref pathDistanceMatrix _cl_uint i))))

;; Prepares the host data, then performs the OpenCL setup and stores the
;; value `time-real` returns for it in `setupTime`. Host-side matrix
;; initialization is not included in that measurement.
(define (setup)
  (setupFloydWarshall)
  (let ([elapsed (time-real setupCL)])
    (set! setupTime elapsed)))

;; Executes the kernel passes and stores the value `time-real` returns for
;; them in `totalKernelTime`.
(define (run)
  (let ([elapsed (time-real runCLKernels)])
    (set! totalKernelTime elapsed)))

;; Runs the CPU reference check and prints "Passed" or "Failed" on its own
;; line, preceded by a blank line.
(define (verify-results)
  (printf "~n~a~n"
          (cond
            [(floydWarshallCPUReference) "Passed"]
            [else "Failed"])))

;; Prints the node count together with the setup time, the kernel time and
;; their sum, each formatted with three decimal places.
(define (print-stats)
  ;; Formats a timing value with three digits after the decimal point.
  (define (three-places value)
    (real->decimal-string value 3))
  (define combinedTime (+ setupTime totalKernelTime))
  (printf "~nNodes: ~a, Setup Time: ~a, Kernel Time: ~a, Total Time: ~a~n"
          numNodes
          (three-places setupTime)
          (three-places totalKernelTime)
          (three-places combinedTime)))

;; Releases every OpenCL object created by `setupCL` and frees the four
;; host matrices allocated by `setupFloydWarshall`.
;;
;; The buffers were created with CL_MEM_USE_HOST_PTR over the host
;; matrices, so the memory objects are released before that host memory is
;; freed.
(define (cleanup)
  (clReleaseKernel kernel)
  (clReleaseProgram program)
  (for-each clReleaseMemObject
            (list pathDistanceBuffer pathBuffer))
  (clReleaseCommandQueue commandQueue)
  (clReleaseContext context)
  (for-each free
            (list pathDistanceMatrix
                  pathMatrix
                  verificationPathDistanceMatrix
                  verificationPathMatrix)))

;; Entry point: run each phase of the sample in order. Statistics are
;; printed last; they only read the node count and the recorded timings,
;; which `cleanup` does not touch.
(for-each (lambda (phase) (phase))
          (list setup
                run
                verify-results
                cleanup
                print-stats))