; main.ss
#lang scheme

(require (prefix-in ffi: "ffi.ss"))
(require (prefix-in convert: "converters/engine.ss"))

(require (only-in (planet vyzo/crypto) sha256))
(require net/base64)

;; Derive a deterministic prepared-statement name from SQL text:
;; "prep" followed by the base64 encoding of the SHA-256 digest of the
;; text's UTF-8 bytes.  Identical SQL always maps to the same name.
;; NOTE(review): net/base64's encoder appends a line break by default, so
;; the name presumably ends in "\r\n" -- harmless for a protocol-level
;; statement name, but confirm before using it inside SQL text.
(define (prepare-hash sql)
  (let* ([digest (sha256 (string->bytes/utf-8 sql))]
         [encoded (base64-encode digest)])
    (string-append "prep" (bytes->string/utf-8 encoded))))

;; Return a fresh cursor identifier "cursor0", "cursor1", ... escaped by the
;; given connection's escape-identifier method.  The counter is shared by
;; every caller and advances only after escape-identifier returns.
(define make-cursor-name
  (let ([counter 0])
    (lambda (conn)
      (let ([escaped (send conn escape-identifier (format "cursor~a" counter))])
        (set! counter (add1 counter))
        escaped))))

;; Apply PROC to every non-list leaf of the nested list tree I and rebuild
;; the same nesting around the results.  A non-list I (including a dotted
;; pair) is itself treated as a single leaf.
(define (map-leaves proc i)
  (cond
    [(not (list? i)) (proc i)]
    [else (map (lambda (sub) (map-leaves proc sub)) i)]))

;; Register VALUE with RESULT through ffi:protect-with! and return whatever
;; that call returns.  Despite the name, this is applied to every value;
;; the byte-strings-only variant ("bytes are not copied, so result must
;; not be collected!") is disabled.
;; NOTE(review): presumably keeps RESULT alive while VALUE is reachable --
;; confirm against ffi.ss.
(define (maybe-protect result value)
  (ffi:protect-with! value result))

;; result% -- wraps one libpq result handle.
;; On construction it caches the status and row/column counts, prints a
;; diagnostic for a 'warning status, and raises a user error for
;; 'empty-query, 'bad-response and 'error.  Other statuses pass silently.
;; DECODE is called as (decode oid raw-value) to convert each raw cell.
(define result%
  (class object%
    (inspect #f)
    (init-field result decode)
    (field [status (ffi:result-status result)]
           [n-rows (ffi:result-rows result)]
           [n-columns (ffi:result-columns result)])

    ;; Display WHAT, the status text and the server error message on the
    ;; current output port; return the displayed string.
    (define (report what)
      (let ([text (format "~a ~a~n~a~n"
                          what
                          (ffi:result-status->string status)
                          (ffi:result-error-message result))])
        (display text)
        text))

    ;; Report WHAT, then raise the same text as a user error.
    (define (fail what)
      (raise-user-error (report what)))

    (cond
      [(memq status '(ok tuples-ok)) (void)]
      [(memq status '(empty-query bad-response error))
       (fail "result failed to init")]
      [(eq? status 'warning) (report "init warning")])

    (super-new)

    ;; Call ROW-HANDLER once per row, with one argument per column; each
    ;; cell is decoded via DECODE and then passed through maybe-protect.
    (define/public (for-each row-handler)
      (define (on-row row)
        (let* ([cells (dict-map row decode)]
               [args (map (λ (cell) (maybe-protect result cell)) cells)])
          (apply row-handler args)))
      (ffi:result-for-each-row on-row result n-rows n-columns))

    ;; Return every row as a nested list of decoded (and protected) values.
    ;; The leaves of ffi:result-matrix are used as (oid . raw-value) pairs.
    (define/public (get-matrix)
      (let ([raw (ffi:result-matrix result n-rows n-columns)])
        (map-leaves (λ (cell) (maybe-protect result (decode (car cell) (cdr cell))))
                    raw)))

    ;; For describing prepared statements: a vector of ffi:param-type for
    ;; each parameter position, in order.
    (define/public (param-info)
      (let ([n-params (ffi:n-params result)])
        (build-vector n-params (λ (index) (ffi:param-type result index)))))

    (field [fields (ffi:result-fields result)])))

;; cursor% -- a server-side cursor over SQL (with PARAMS), declared on
;; construction through the connection's p-exec.
;; NOTE(review): PostgreSQL only accepts DECLARE ... CURSOR (without WITH
;; HOLD) inside a transaction block, so callers presumably wrap cursor use
;; in with-transaction -- confirm.
(define cursor%
  (class object%
    (inspect #f)
    (init-field connection sql params)
    (super-new)
    ;; Already escaped/quoted by the connection (make-cursor-name), so it is
    ;; spliced into SQL text verbatim -- with string-append or ~a, never ~s.
    (field [name (make-cursor-name connection)])
    (send/apply connection p-exec (string-append "DECLARE " name " CURSOR FOR " sql) params)

    ;; FETCH the next row; the returned result% has n-rows 0 at the end.
    (define (inner-fetch-one)
      (send connection p-exec (string-append "FETCH " name)))

    ;; Return a result% holding the next row; raise a user error when the
    ;; FETCH did not return exactly one row.
    (define/public (fetch-one)
      (let ([result (inner-fetch-one)])
        (if (not (= (get-field n-rows result) 1)) (raise-user-error "No more rows to fetch. use for-each dumbass")
            result)))

    ;; FETCH the next N rows (N is spliced into the SQL as-is) as one result%.
    (define/public (fetch-many n)
      ;; ~a, not ~s: ~s would print NAME as a Racket string literal, adding
      ;; a second layer of quotes plus backslashes and breaking the SQL.
      (send connection p-exec (format "FETCH ~a ~a" n name)))

    ;; FETCH every remaining row as one result%.
    (define/public (fetch-rest)
      (send connection p-exec (string-append "FETCH ALL " name)))

    ;; Call ROW-HANDLER on each remaining row, fetching one row per round
    ;; trip so only one row's result is held at a time.
    (define/public (for-each row-handler)
      (let loop ([result (inner-fetch-one)])
        (unless (= (get-field n-rows result) 0)
          (send result for-each row-handler)
          (loop (inner-fetch-one)))))))

;; prepared-statement% -- a named server-side prepared statement.
;; Construction PREPAREs SQL under NAME on CONNECTION and DESCRIBEs it to
;; learn the parameter type OIDs; `exec` validates its arguments against
;; those OIDs before executing.  CAST-ABLE? is called as
;; (cast-able? from-oid to-oid) and must return true when a value of type
;; FROM-OID may be supplied for a TO-OID parameter.
(define prepared-statement%
  (class object%
    (inspect #f)
    (init-field connection name sql cast-able?)
    (super-new)

    (send connection inner-prepare name sql)
    ;; Vector of expected parameter OIDs, one per placeholder.
    (define oids (send (send connection describe name) param-info))

    ;; Raise exn:fail unless PARAMS matches OIDS in count and, position by
    ;; position, in type (identical OID or an allowed cast).
    (define (check-params params)
      (let ([expected (vector-length oids)])
        (let loop ([params params] [column 0])
          (cond
            [(null? params)
             ;; Catch a short argument list here with a clear message
             ;; instead of failing later during parameter encoding.
             (when (< column expected)
               (error "Not enough parameters provided to this statement!"))]
            [(>= column expected)
             ;; Must be >=: column = expected is already past the last slot.
             (error "Too many parameters provided to this statement!")]
            [else
             (let ([param-oid (send connection divine (car params))]
                   [template-oid (vector-ref oids column)])
               ;; eqv? compares numbers by value; eq? is only reliable
               ;; for fixnums.
               (unless (or (eqv? param-oid template-oid)
                           (cast-able? param-oid template-oid))
                 (error (format "The OID for ~a(~a) could not match template OID ~a~n"
                                (car params) param-oid template-oid)))
               (loop (cdr params) (+ column 1)))]))))

    ;; Execute the statement with PARAMS after validating them; returns
    ;; whatever the connection's inner-query-prepared returns.
    (define/public (exec . params)
      (check-params params)
      (send connection inner-query-prepared name oids params))))

;; (this-or a b ...)
;; Evaluates A exactly once; if its value is true, returns that value,
;; otherwise evaluates B ... in order and returns the last one's value.
;; (this-or a) is simply A.
(define-syntax this-or
  (syntax-rules ()
    [(_ a) a]
    [(_ a b ...)
     ;; Bind A so its side effects (and cost) happen once, not twice.
     (let ([value a])
       (if value value (begin b ...)))]))

;; (define-send method type object)
;; Binds METHOD to a procedure that calls the method of the same name on
;; OBJECT.  The method is looked up once, as a generic on the class or
;; interface TYPE.  OBJECT is re-evaluated on every call, and all call
;; arguments are passed through unchanged.
(define-syntax-rule (define-send method type object)
  (define method
    (let ([dispatcher (generic type method)])
      (lambda call-args
        (send-generic object dispatcher . call-args)))))

;; Map PROC pairwise over vector V and list L, returning a fresh vector whose
;; i-th slot is (proc (vector-ref v i) (list-ref l i)); PROC is applied in
;; index order.  Raises exn:fail if V and L differ in length. The lengths
;; are checked up front, so PROC is never called on a mismatched input.
(define (my-vl-map proc v l)
  (unless (= (length l) (vector-length v))
    (error "List must have same number of members as vector."))
  (let ([result (make-vector (vector-length v) 0)])
    (for ([x (in-vector v)]
          [y (in-list l)]
          [index (in-naturals)])
      (vector-set! result index (proc x y)))
    result))
 
;; connection% -- one nonblocking libpq connection.
;; Extends convert:base-engine%, from which it inherits encode, decode,
;; divine, set-vector-info! and the oid-size field (presumably the value
;; conversion machinery -- see converters/engine.ss).  Instances are built
;; by `connect`; `initialize` must then be called to load the cast and type
;; catalogs.  NOTE(review): `connect` does not call initialize itself --
;; confirm callers do.
(define connection%
  (class convert:base-engine%
    (inspect #f)
    ;; handle: libpq connection handle from ffi.ss.
    ;; input/output: ports over the connection's socket; this class only
    ;; synchronizes on them to wait for readiness.
    (init-field handle input output)
    (inherit-field oid-size)
    (inherit set-vector-info!
             encode decode divine)
    
    ;; Tell the base engine whether the server reports integer_datetimes = "on".
    (super-new [integer-time (equal? "on" (ffi:connection-parameter handle "integer_datetimes"))])
    
    ;; cast-table: (list source-oid target-oid) -> castfunc, from pg_cast.
    ;; oid-names:  typname -> oid, from pg_type.
    ;; Both start empty and are filled in by `initialize`.
    (field [cast-table (make-immutable-hash null)]
           [oid-names (make-immutable-hash null)])

    ;; Return the castfunc recorded for FROM -> TO, or #f when none is known.
    (define (cast-able? from to)
;      (display (format "cast ~s to ~s~n" from to))
      (hash-ref cast-table (list from to) (λ () #f)))

    ;; prepare-hash name -> prepared-statement%, so each distinct SQL text is
    ;; prepared at most once per connection.
    (define prepare-cache (make-immutable-hash null))
    
    ;; Force the client encoding to UTF-8; refuse to continue otherwise.
    (when (not (ffi:set-client-encoding! handle "utf-8"))
      (error "We only roll in utf-8 dudes."))
    
    ;; Require frontend/backend protocol version 3 or later.
    (define version (ffi:protocol-version handle))
    (when (< version 3) (error "Please upgrade!"))
 
    ;; Load catalog data: the cast table, the typname -> oid map, oid-size,
    ;; and element-type info for every type with a non-zero typelem.
    (define/public (initialize)
      ; yay, introspection and side effects galore!
      (set! 
       cast-table
       (make-immutable-hash
        (let ([result null])
          (send (exec "SELECT castsource,casttarget,castfunc FROM pg_cast") for-each
                (λ (source target func)
                  ;; cast-able? uses func as its truth value, so a #f func
                  ;; would make an existing cast look absent.
                  ;; NOTE(review): pg_cast rows without a cast function may
                  ;; trip this if their castfunc decodes to #f -- confirm.
                  (when (not func) (error "Er, if not func cast-able? fails. wtf is this?"))
                  (set! result (cons (cons (list source target) func) result))))
          result)))
      
      ;; typname -> oid.
      (send (exec "SELECT oid,typname FROM pg_type") for-each
            (λ (oid name)
              (set! oid-names (hash-set oid-names name oid))))
      
      ;; Store the OID of the type named `oid` in the inherited oid-size field.
      ;; NOTE(review): assumes typname decodes to a symbol; if it decodes to a
      ;; string this lookup fails -- confirm.
      (set! oid-size (hash-ref oid-names 'oid))
      
      ;; For each type with an element type, register the element OID and
      ;; the fixed length (#f when typlen is not positive).
      (send (exec "SELECT oid,typelem,typlen FROM pg_type") for-each
            (λ (oid element-oid length)
              (when (> element-oid 0)
                (let ([length (if (> length 0) length #f)])
                  (set-vector-info! oid element-oid length)))))
      
      ; now we have meaningful OID names, we can handle vectors of types (incl fixed)
      ; and we can tell whether the OID of a parameter would cast to the OID of its
      ; designated type.
     )
    
    ;; Wait (sync on the output port) while ffi:flushing? reports pending
    ;; outgoing data; raises via assert-status if the connection went bad.
    (define (output-wait)
      (assert-status "flushing output" handle)
      (when (ffi:flushing? handle)
        (sync output)
        (output-wait)))
    
    ;; Consume input until the connection is no longer busy, then return the
    ;; next result from ffi:start-result (#f when none remains, as
    ;; get-result relies on).
    ;; NOTE(review): (sync/timeout 0 input) never blocks, so while the server
    ;; is still working this recurses in a tight polling loop.  Blocking on
    ;; `input` between consume-input calls would avoid the spin -- confirm
    ;; that is safe with these ports before changing it.
    (define (input-wait)
      (assert-status "polling input" handle)
      (sync/timeout 0 input)
      (check-send (ffi:consume-input handle))
      (if (ffi:is-busy? handle) (input-wait)
          ; only ever call start-result when not is-busy
          (begin0
            (ffi:start-result handle)
            (assert-status "polling input" handle))))
    
;    (define (get-result)
;      (assert-status "get result" handle)
;      (apply
;       values ; but there will only ever be one value because we don't use lameo PQexec/PQsendQuery.
;       (let loop ([result (input-wait)] [results null])
;         (if result (loop (input-wait) (cons (new result% [result result]) results))
;             (reverse results)))))
     
    ;; Wrap the single pending result in a result% decoded with this engine's
    ;; `decode`, then drain the result stream: a second non-#f result is an error.
    (define (get-result)
      (assert-status "get result" handle)
      (let ([result (input-wait)])
      (begin0
        (new result% [result result] [decode (λ (oid param) (decode oid param))])
        (when (input-wait) (error "Got two results somehow?")))))
   
    ;; Raise via pq-error when an ffi send/consume call reports failure.
    (define (check-send ok)
      (if ok (void)
          (pq-error "We could not send a query!" handle)))
    
    ;; Quote IDENTIFIER for use in SQL text, via the ffi.
    (define/public (escape-identifier identifier)
      (ffi:escape-identifier handle identifier))
    
    ;; Open a server-side cursor% for SQL with PARAMS.
    (define/public (select sql . params)
      (new cursor% [connection this] [sql sql] [params params]))
    
    ;; Run SQL once with PARAMS (no named prepared statement): each parameter
    ;; is typed with `divine` and encoded with `encode`.  Returns a result%.
    (define/public (exec sql . params)
      (let ([oids (map (λ (param) (divine param)) params)])
        (check-send (ffi:send-query-params handle sql oids (map (λ (oid param) (encode oid param)) oids params))))
      (output-wait)
      (get-result))
    
    ;; Like exec, but through a cached prepared statement (see prepare).
    (define/public (p-exec sql . params)
      (let ([prep (prepare sql)])
        (send/apply prep exec params)))
    
    ;; Describe the prepared statement NAME; the result% supports param-info.
    (define/public (describe name)
      (check-send (ffi:send-describe-prepared handle name))
      (output-wait)
      (get-result))
    
    ;; Return the prepared-statement% for SQL, creating and caching it (keyed
    ;; by prepare-hash) on first use.
    (define/public (prepare sql)
      (let ([name (prepare-hash sql)])
        (hash-ref 
         prepare-cache name
         (λ ()
           (let ([prepared (new prepared-statement% 
                                [connection this]
                                [name name]
                                [sql sql]
                                [cast-able? cast-able?])])
                                
             (set! prepare-cache (hash-set prepare-cache name prepared))
             prepared)))))
    
    ;; Send PREPARE for SQL under NAME and return the resulting result%.
    (define/public (inner-prepare name sql)
      (check-send (ffi:send-prepare handle name sql))
      (output-wait)
      (get-result)) ; describe-result?
    
    ;; Execute prepared statement NAME, encoding each param with its OID
    ;; (my-vl-map raises if OIDS and PARAMS differ in length).
    (define/public (inner-query-prepared name oids params)
      (check-send (ffi:send-query-prepared handle name (my-vl-map (λ (oid param) (encode oid param)) oids params)))
      (output-wait)
      (get-result))
    ))

;; (with-transaction connection body ...)
;; Runs BODY between BEGIN and COMMIT on CONNECTION and returns BODY's
;; result(s).  If control leaves BODY any other way -- an exception, a
;; break, or an escape continuation -- the post-thunk sends ROLLBACK
;; instead of COMMIT, so partial work is never committed.
;; The connection expression is evaluated once.
(define-syntax with-transaction
  (syntax-rules ()
    [(_ connection commands ...)
     (let ([conn connection]
           [completed? #f])
       (dynamic-wind
        (λ ()
          (set! completed? #f)
          (send conn p-exec "BEGIN"))
        (λ ()
          (begin0
            (let () commands ...)
            ;; Reached only when BODY returned normally.
            (set! completed? #t)))
        (λ () (send conn p-exec (if completed? "COMMIT" "ROLLBACK")))))]))


; From the libpq documentation on nonblocking connections: after sending any command or data on a nonblocking connection, call PQflush. If it returns 1, wait for the socket to be write-ready and call it again; repeat until it returns 0. Once PQflush returns 0, wait for the socket to be read-ready and then read the response (see libpq's asynchronous command processing section).

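; NOTE: this top-level output-wait duplicates the private output-wait method of connection% (which additionally checks the connection status first).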
;; Wait on the OUTPUT event until ffi:flushing? reports that HANDLE has no
;; more pending outgoing data.  Returns void.
(define (output-wait handle output)
  (let flush-loop ()
    (when (ffi:flushing? handle)
      (sync output)
      (flush-loop))))

;; Raise exn:fail with MESSAGE followed by libpq's error text for HANDLE
;; (written with ~s, i.e. quoted) in brackets and a trailing newline.
(define (pq-error message handle)
  (let ([detail (ffi:error-message handle)])
    (error (format "~a [~s]~n" message detail))))

;; Raise (via pq-error, with MESSAGE) if HANDLE's connection status is
;; 'bad; otherwise return void.
(define (assert-status message handle)
  (case (ffi:connection-status handle)
    [(bad) (pq-error message handle)]
    [else (void)]))

;; (connect #:keyword value ...) -> connection%
;; Starts a nonblocking libpq connection from the keyword arguments (via
;; ffi:make-parameters), switches it to nonblocking mode with verbose
;; errors, and drives ffi:connect-poll to completion, waiting on the socket
;; ports as each poll state requests.  Raises on 'failed or an unknown
;; state.  The returned connection% has not been `initialize`d.
(define connect
  (make-keyword-procedure
   (lambda (kw-names kw-values)
     (let ([handle (ffi:do-connect-start (ffi:make-parameters kw-names kw-values))])
       (assert-status "Connection has failed" handle)
       (ffi:set-nonblocking! handle #t)
       (ffi:set-error-verbosity! handle 'verbose)
       (let-values ([(in-port out-port)
                     (ffi:socket-to-ports (ffi:connection-socket handle) "postgresql")])
         ;; Before the first poll, behave as if it had asked for writing.
         (sync out-port)
         (let poll ()
           (let ([state (ffi:connect-poll handle)])
             (cond
               [(eq? state 'reading) (sync in-port) (poll)]
               [(eq? state 'writing) (sync out-port) (poll)]
               [(eq? state 'failed) (pq-error "Connection failed while polling" handle)]
               [(eq? state 'ok) (new connection% [handle handle] [input in-port] [output out-port])]
               [else (error (format "Uh...huh? ~a~n" state))]))))))))

(provide connect with-transaction)
(provide 
 (rename-out (ffi:with-tracing-to with-tracing-to))
; (rename-out (ffi:escape-identifier escape-identifier))
 )