examples/pps/scan.ss
#lang scheme
(require (planet jaymccarthy/opencl/wrap)
         scheme/foreign
         scheme/runtime-path)
(unsafe!)

(define (cvector->vector cv)
  (list->vector (cvector->list cv)))

(define (power-of-two? x)
  (define n (inexact->exact x))
  (zero? (bitwise-and n (sub1 n))))

(define (pow-two x)
  (floor (/ (exp x) (exp 2))))

;;;;;
(define GROUP-SIZE 256)
(define NUM_BANKS 16)
(define MAX_ERROR 1e-7)

(define iterations 1000)
(define count (* 1024 1024))

(define float-data (malloc _float count 'raw))
(for ([i (in-range count)])
  (ptr-set! float-data _float i (* 10 (random))))

(define devices (platform-devices #f 'CL_DEVICE_TYPE_GPU))
(define device-id (cvector-ref devices 0))

(define max-workgroup-size (device-info device-id 'CL_DEVICE_MAX_WORK_GROUP_SIZE))

(set! GROUP-SIZE (min GROUP-SIZE max-workgroup-size))

(define vendor-name (device-info device-id 'CL_DEVICE_VENDOR))
(define device-name (device-info device-id 'CL_DEVICE_NAME))

(printf "Connecting to ~a ~a~n" vendor-name device-name)

(printf "Loading program~n")
(define-runtime-path program-source-path "scan_kernel.cl")
(define program-source (file->bytes program-source-path))

(define context (devices->context (vector device-id)))
(define queue (make-command-queue context device-id empty))
(define program (make-program/source context (vector program-source)))
(program-build! program (vector device-id) #"")

(define kernel-names 
  (list #"PreScanKernel"
        #"PreScanStoreSumKernel"
        #"PreScanStoreSumNonPowerOfTwoKernel"
        #"PreScanNonPowerOfTwoKernel"
        #"UniformAddKernel"))
(define kernels
  (for/hash ([kn (in-list kernel-names)])
    (define k (program-kernel program kn))
    (set! GROUP-SIZE 
          (min GROUP-SIZE (kernel-work-group-info k device-id 'CL_KERNEL_WORK_GROUP_SIZE)))
    (values kn k)))

(printf "Setting up buffers~n")
(define buffer-size (* (ctype-sizeof _float) count))
(define input (make-buffer context 'CL_MEM_READ_WRITE buffer-size #f))

(define inputw-evt (enqueue-write-buffer! queue input 'CL_TRUE 0 buffer-size float-data (vector)))
(event-release! inputw-evt)

(define output (make-buffer context 'CL_MEM_READ_WRITE buffer-size #f))

(define result (malloc _float count 'raw))
(memset result 0 0 count _float)

(define outputw-evt (enqueue-write-buffer! queue output 'CL_TRUE 0 buffer-size result (vector)))
(event-release! outputw-evt)

(define scan-partial-sums #f)
(define elements-allocated 0)
(define levels-allocated 0)

(define-syntax-rule (while test body)
  (let loop ()
    (when test
      body
      (loop))))
(define-syntax-rule (do-while body test)
  (begin body (while test body)))
(define-syntax-rule (++ id)
  (begin0 id
          (set! id (add1 id))))
(define (create-partial-sum-buffers count)
  (set! elements-allocated count)
  (local [(define group-size GROUP-SIZE)
          (define element-count count)
          (define level 0)]
    (do-while
     (local [(define group-count (max 1 (ceiling (/ element-count (* 2 group-size)))))]
       (when (> group-count 1)
         (++ level))
       (set! element-count group-count))
     (> element-count 1))
    (set! scan-partial-sums (make-vector level #f))
    (set! levels-allocated level)
    (set! element-count count)
    (set! level 0)
    (do-while
     (local [(define group-count (max 1 (ceiling (/ element-count (* 2 group-size)))))]
       (when (> group-count 1)
         (local [(define bufer-size (* group-count (ctype-sizeof _float)))]
           (vector-set! scan-partial-sums
                        (++ level)
                        (make-buffer context 'CL_MEM_READ_WRITE buffer-size #f))))
       (set! element-count group-count))
     (> element-count 1))))

(printf "Creating partial sums~n")
(create-partial-sum-buffers count)

(define (pre-scan global local-work shared output-data input-data n group-index base-index)
  (define k (hash-ref kernels #"PreScanKernel"))
  (set-kernel-arg:_cl_mem! k 0 output-data)
  (set-kernel-arg:_cl_mem! k 1 input-data)
  (set-kernel-arg:local! k 2 shared)
  (set-kernel-arg:_cl_int! k 3 group-index)
  (set-kernel-arg:_cl_int! k 4 base-index)
  (set-kernel-arg:_cl_int! k 5 n)
  (enqueue-nd-range-kernel! queue k 1 global local-work (vector)))

(define (pre-scan-store-sum global local-work shared output-data input-data partial-sums n group-index base-index)
  (define k (hash-ref kernels #"PreScanStoreSumKernel"))
  (set-kernel-arg:_cl_mem! k 0 output-data)
  (set-kernel-arg:_cl_mem! k 1 input-data)
  (set-kernel-arg:_cl_mem! k 2 partial-sums)
  (set-kernel-arg:local! k 3 shared)
  (set-kernel-arg:_cl_int! k 4 group-index)
  (set-kernel-arg:_cl_int! k 5 base-index)
  (set-kernel-arg:_cl_int! k 6 n)
  (enqueue-nd-range-kernel! queue k 1 global local-work (vector)))

(define (pre-scan-store-sum-non-power-of-two global local-work shared output-data input-data partial-sums n group-index base-index)
  (define k (hash-ref kernels #"PreScanStoreSumNonPowerOfTwoKernel"))
  (set-kernel-arg:_cl_mem! k 0 output-data)
  (set-kernel-arg:_cl_mem! k 1 input-data)
  (set-kernel-arg:_cl_mem! k 2 partial-sums)
  (set-kernel-arg:local! k 3 shared)
  (set-kernel-arg:_cl_int! k 4 group-index)
  (set-kernel-arg:_cl_int! k 5 base-index)
  (set-kernel-arg:_cl_int! k 6 n)
  (enqueue-nd-range-kernel! queue k 1 global local-work (vector)))

(define (pre-scan-non-power-of-two global local-work shared output-data input-data n group-index base-index)
  (define k (hash-ref kernels #"PreScanNonPowerOfTwoKernel"))
  (set-kernel-arg:_cl_mem! k 0 output-data)
  (set-kernel-arg:_cl_mem! k 1 input-data)
  (set-kernel-arg:local! k 2 shared)
  (set-kernel-arg:_cl_int! k 3 group-index)
  (set-kernel-arg:_cl_int! k 4 base-index)
  (set-kernel-arg:_cl_int! k 5 n)
  (enqueue-nd-range-kernel! queue k 1 global local-work (vector)))

(define (uniform-add global local-work output-data partial-sums n group-offset base-index)
  (define k (hash-ref kernels #"UniformAddKernel"))
  (set-kernel-arg:_cl_mem! k 0 output-data)
  (set-kernel-arg:_cl_mem! k 1 partial-sums)
  (set-kernel-arg:local! k 2 (ctype-sizeof _float))
  (set-kernel-arg:_cl_int! k 3 group-offset)
  (set-kernel-arg:_cl_int! k 4 base-index)
  (set-kernel-arg:_cl_int! k 5 n)
  (enqueue-nd-range-kernel! queue k 1 global local-work (vector)))

(define (pre-scan-buffer-rec output-data input-data max-group-size max-work-item-count element-count level)
  (define group-size max-group-size)
  (define group-count (max 1.0 (ceiling (/ element-count (* 2.0 group-size)))))
  (define work-item-count
    (min max-work-item-count
         (cond [(> group-size 1)
                group-size]
               [(power-of-two? element-count)
                (/ element-count 2)]
               [else
                (pow-two element-count)])))
  (define element-count-per-group (* work-item-count 2))
  (define last-group-element-count 
    (- element-count (* (sub1 group-count) element-count-per-group)))
  (define remaining-work-item-count
    (min max-work-item-count (max 1.0 (/ last-group-element-count 2))))
  (define remainder 0)
  (define last-shared 0)
  (unless (= last-group-element-count 
             element-count-per-group)
    (set! remainder 1)
    (unless (power-of-two? last-group-element-count)
      (set! remaining-work-item-count
            (min max-work-item-count
                 (pow-two last-group-element-count))))
    (set! last-shared 
          (* (ctype-sizeof _float)
             2
             (+ remaining-work-item-count 
                (/ (* 2 remaining-work-item-count)
                   NUM_BANKS)))))
  
  (local [(define global (vector (inexact->exact (* (max 1 (- group-count remainder)) work-item-count)) 1))
          (define local-work (vector (inexact->exact work-item-count) 1))
          (define shared
            (* (ctype-sizeof _float)
               (+ element-count-per-group
                  (/ element-count-per-group NUM_BANKS))))]
    (cond
      [(> group-count 1)
       (local [(define partial-sums (vector-ref scan-partial-sums level))]
         (pre-scan-store-sum global local-work shared
                             output-data input-data
                             partial-sums 
                             (inexact->exact (* work-item-count 2))
                             0 0)
         (unless (zero? remainder)
           (local [(define last-global (vector (* 1 remaining-work-item-count) 1))
                   (define last-local (vector remaining-work-item-count 1))]
             (pre-scan-store-sum-non-power-of-two
              last-global last-local last-shared
              output-data input-data partial-sums
              (inexact->exact last-group-element-count)
              (inexact->exact (sub1 group-count))
              (inexact->exact (- element-count last-group-element-count)))))
         (pre-scan-buffer-rec partial-sums partial-sums 
                              max-group-size max-work-item-count
                              group-count (add1 level))
         (uniform-add global local-work output-data partial-sums
                      (inexact->exact (- element-count last-group-element-count))
                      0 0)
         (unless (zero? remainder)
           (local [(define last-global (vector (* 1 remaining-work-item-count) 1))
                   (define last-local (vector remaining-work-item-count 1))]
             (uniform-add last-global last-local
                          output-data partial-sums
                          (inexact->exact last-group-element-count)
                          (inexact->exact (sub1 group-count))
                          (inexact->exact (- element-count last-group-element-count))))))]
      [(power-of-two? element-count)
       (pre-scan global local-work shared output-data input-data
                 (inexact->exact (* 2 work-item-count)) 0 0)]
      [else
       (pre-scan-non-power-of-two global local-work shared output-data input-data
                                  (inexact->exact element-count) 0 0)])))

(define (pre-scan-buffer output-data input-data max-group-size max-work-item-count element-count)
  (pre-scan-buffer-rec output-data input-data max-group-size max-work-item-count element-count 0))

(printf "Prescanning~n")
(pre-scan-buffer output input GROUP-SIZE GROUP-SIZE count)

(printf "Starting timing run of ~a iterations~n" iterations)
(define t0 (current-inexact-milliseconds))
(for ([i (in-range iterations)])
  (pre-scan-buffer output input GROUP-SIZE GROUP-SIZE count))

(command-queue-finish! queue)
(define t1 (current-inexact-milliseconds))

(define t (- t1 t0))
(printf "Exec Time: ~a ms~n" (/ t iterations))
(printf "Throughput: ~a GB/sec~n" (/ (* 1e-9 buffer-size iterations) t))

(define outputr-evt (enqueue-read-buffer! queue output 'CL_TRUE 0 buffer-size result (vector)))
(event-release! outputr-evt)

(define reference (malloc _float count 'raw))

(define (scan-reference reference input count)
  (define total-sum 0)
  (ptr-set! reference _float 0 0.0)
  (for ([i (in-range count)])
    (define last-i (ptr-ref input _float (sub1 i)))
    (set! total-sum (+ total-sum last-i))
    (ptr-set! reference _float i (+ last-i (ptr-ref reference _float (sub1 i)))))
  (unless (= total-sum (ptr-ref reference _float (sub1 count)))
    (fprintf (current-error-port) "Warning: Exceeding single-precision accuracy. Scan will be inaccurate~n")))

(scan-reference reference float-data count)

(define error -inf.0)
(for ([i (in-range count)])
  (define diff 
    (abs (- (ptr-ref reference _float i)
            (ptr-ref result _float i))))
  (set! error (max error diff)))

(printf "Maximum error: ~a~n" error)

(define (release-partial-sums)
  (for ([i (in-range levels-allocated)])
    (memobj-release! (vector-ref scan-partial-sums i)))
  (set! elements-allocated 0)
  (set! levels-allocated 0))

(release-partial-sums)
(free result)
(memobj-release! output)
(memobj-release! input)
(for ([k (in-hash-values kernels)])
  (kernel-release! k))
(program-release! program)
(command-queue-release! queue)
(context-release! context)
(free float-data)