main.ss
#lang scheme

;(error "do not import me")

(require (prefix-in ffi: "ffi.ss")
         (prefix-in convert: "converters/engine.ss")
         
;         (prefix-in log: (planet synx/log))
         (only-in (planet vyzo/crypto) sha256)
         net/base64)

;; Derive a deterministic server-side prepared-statement name from SQL text:
;; "prep" followed by the base64 of the SHA-256 digest of the UTF-8 bytes.
;; Identical SQL therefore always maps to the same statement name.
;; (Per the net/base64 docs, base64-encode terminates its output with a line
;; break, and that break ends up in the name as-is.)
(define (prepare-hash sql)
  (let* ([digest (sha256 (string->bytes/utf-8 sql))]
         [encoded (base64-encode digest)])
    (string-append "prep" (bytes->string/utf-8 encoded))))

;; Produce a fresh cursor identifier "cursor<N>" for connection C, escaped via
;; C's escape-identifier method. N is a process-wide counter shared by all
;; connections. It only advances once escaping has succeeded.
(define make-cursor-name
  (let ([counter 0])
    (λ (c)
      (let ([identifier (send c escape-identifier (format "cursor~a" counter))])
        (set! counter (add1 counter))
        identifier))))

;; Apply PROC to every non-list leaf of the (possibly nested) list I,
;; rebuilding the same nesting structure. A non-list I is itself a leaf.
(define (map-leaves proc i)
  (cond
    [(not (list? i)) (proc i)]
    [else (for/list ([child (in-list i)])
            (map-leaves proc child))]))

;; Tie the lifetime of a decoded VALUE to the raw RESULT it came from, via
;; ffi:protect-with!, and return whatever that call returns.
;; An earlier version only did this for byte strings (bytes are not copied
;; out of the result, so the result must stay alive while they are in use).
;; It now protects unconditionally despite the "maybe" in the name.
(define maybe-protect
  (λ (result value)
    (ffi:protect-with! value result)))

;; Object wrapper around a raw ffi query result.
;; Init fields:
;;   sql    - the SQL text that produced the result (used in diagnostics)
;;   result - the raw ffi result handle
;;   decode - (oid value) -> decoded Scheme value, applied to every cell
;; Construction validates the status: error statuses raise a user error,
;; a 'warning status prints a diagnostic, and any other status is accepted.
(define result%
  (class object%
    (inspect #f)
    (init-field sql result decode)
    ;; Read once, eagerly, from the raw result.
    (field [status (ffi:result-status result)]
           [n-rows (ffi:result-rows result)]
           [n-columns (ffi:result-columns result)])

    ;; Print a diagnostic (SQL, MESSAGE, status text, server error text)
    ;; and return the printed text.
    (define (warn message)
      (define text
        (format "~a~n~a ~a~n~a~n"
                sql
                message
                (ffi:result-status->string status)
                (ffi:result-error-message result)))
      (display text)
      text)

    ;; Print the diagnostic, then raise it as a user error.
    (define (die message)
      (raise-user-error (warn message)))

    (case status
      [(ok tuples-ok) (void)]
      [(empty-query bad-response error) (die "result failed to init")]
      [(warning) (warn "init warning")])

    (super-new)

    ;; Decode one row (a dict of oid -> raw value) into a list of values,
    ;; each kept alive by the raw result.
    (define (decode-row row)
      (for/list ([value (in-list (dict-map row decode))])
        (maybe-protect result value)))

    ;; Call ROW-HANDLER once per row, with the row's decoded values as arguments.
    (define/public (for-each row-handler)
      (ffi:result-for-each-row
       (λ (row) (apply row-handler (decode-row row)))
       result n-rows n-columns))

    ;; Like for-each, but returns whatever ffi:result-map builds from the handler results.
    (define/public (map-row row-handler)
      (ffi:result-map
       (λ (row) (apply row-handler (decode-row row)))
       result n-rows n-columns))

    ;; Fold PROC over rows: (proc accumulator value ...) for each row, starting from INIT.
    (define/public (fold proc init)
      (ffi:result-fold
       (λ (acc row) (apply proc acc (decode-row row)))
       init result n-rows n-columns))

    ;; Whole result as nested lists; each leaf is an (oid . raw) pair, decoded and protected.
    (define/public (get-matrix)
      (map-leaves
       (λ (cell) (maybe-protect result (decode (car cell) (cdr cell))))
       (ffi:result-matrix result n-rows n-columns)))

    ;; Vector of parameter type OIDs; meaningful for describe-prepared results.
    (define/public (param-info)
      (let ([param-count (ffi:n-params result)])
        (build-vector param-count (λ (index) (ffi:param-type result index)))))

    (field [fields (ffi:result-fields result)])))

;; Server-side cursor over SQL, declared at construction time.
;; Init fields:
;;   connection - a connection% used for every statement
;;   sql        - the query the cursor runs
;;   params     - parameter values for SQL (passed to p-exec)
;;   scroll?    - when true, declare a SCROLL cursor (needed for fetch-prior)
;; The cursor's SQL identifier comes from make-cursor-name, i.e. it has
;; already been escaped by the connection, so it is spliced into every
;; statement verbatim.
(define cursor%
  (class object%
    (inspect #f)
    (init-field connection sql params [scroll? #f])
    (super-new)
    (field [name (make-cursor-name connection)])
    (send/apply connection p-exec (string-append "DECLARE " (if scroll? "SCROLL " "") name " CURSOR FOR " sql) params)

    ;; FETCH <type> FROM <name>; returns the result% (possibly 0 rows).
    (define (inner-fetch-one [type "NEXT"])
      (send connection p-exec (string-join (list "FETCH" type "FROM" name) " ")))

    ;; Fetch exactly one row in direction TYPE; raises a user error if no row came back.
    (define/public (fetch-one type)
      (let ([result (inner-fetch-one type)])
        (if (not (= (get-field n-rows result) 1))
            (raise-user-error "No more rows to fetch. use for-each dumbass")
            result)))

    (define/public (fetch-next)
      (fetch-one "NEXT"))
    (define/public (fetch-prior)
      (fetch-one "PRIOR"))

    ;; Fetch up to N rows as one result%.
    ;; The name must be spliced with ~a like every other statement here. The
    ;; old ~s wrapped the (already escaped) identifier in quotes and
    ;; backslash-escaped it, which produced invalid SQL.
    (define/public (fetch-many n)
      (send connection p-exec (format "FETCH ~a FROM ~a" n name)))

    ;; Fetch all remaining rows as one result%.
    (define/public (fetch-rest)
      (send connection p-exec (string-append "FETCH ALL " name)))

    ; this is probably the sanest one to use:
    ; memory-kind, and logically separated by rows rather than by results.
    ;; Call ROW-HANDLER on each remaining row, fetching one row at a time.
    (define/public (for-each row-handler)
      (let loop ([result (inner-fetch-one)])
        (if (= (get-field n-rows result) 0)
            (void)
            (begin
              (send result for-each row-handler)
              (loop (inner-fetch-one))))))

    ;; Fold PROC over remaining rows, fetching one row at a time.
    (define/public (fold proc init)
      (let loop ([acc init])
        (let ([current (inner-fetch-one)])
          (if (= (get-field n-rows current) 0)
              acc
              (loop (send current fold proc acc))))))))

;; A server-side prepared statement.
;; Init fields:
;;   connection - the connection% that owns it
;;   name       - server-side statement name
;;   sql        - the statement text
;;   cast-able? - (from-oid to-oid) -> true if a FROM value may be sent as TO
;; Construction prepares the statement on the server, then asks the server
;; (describe) for the parameter type OIDs.
(define prepared-statement%
  (class object%
    (inspect #f)
    (init-field connection name sql cast-able?)
    (super-new)

    (send connection inner-prepare name sql)
    ;; Vector of parameter type OIDs reported by the server.
    (define oids (send (send connection describe name) param-info))

    ;; Check that PARAMS has exactly one value per placeholder, and that
    ;; each value's OID (as divined by the connection) matches or can be
    ;; cast to the placeholder's OID. Raises an error otherwise.
    (define (check-params params)
      (let ([expected (vector-length oids)])
        (let loop ([params params] [column 0])
          (cond
            [(null? params)
             (when (< column expected)
               (error "Not enough parameters provided to this statement!"))]
            ;; >= (not >): column == expected is already one past the end.
            [(>= column expected)
             (error "Too many parameters provided to this statement!")]
            [else
             (let ([param-oid (send connection divine (car params))]
                   [template-oid (vector-ref oids column)])
               ;; eqv? rather than eq?: OIDs are unsigned 32-bit and need
               ;; not be fixnums.
               (unless (or (eqv? param-oid template-oid)
                           (cast-able? param-oid template-oid))
                 (error (format "The OID for ~a(~a) could not match template OID ~a~n"
                                (car params) param-oid template-oid)))
               (loop (cdr params) (+ column 1)))]))))

    ;; Open a cursor over this statement's SQL with PARAMS.
    (define/public (select . params)
      (new cursor% [connection connection] [sql sql] [params params]))

    ;; Validate PARAMS, then execute the prepared statement; returns a result%.
    (define/public (exec . params)
      (check-params params)
      (send connection inner-query-prepared name oids params))))

;; (this-or a b ...) evaluates A once. If A is true, its value is the result.
;; Otherwise B ... are evaluated in sequence and the last one is the result.
;; A is bound to a temporary so that its side effects happen exactly once.
;; (The old (if a a ...) expansion evaluated A twice.)
(define-syntax this-or
  (syntax-rules ()
    [(_ a) a]
    [(_ a b ...) (let ([t a]) (if t t (begin b ...)))]))

;; (define-send method type object) defines METHOD as a procedure that
;; forwards all of its arguments to OBJECT's METHOD method. The generic
;; for METHOD on TYPE is resolved once, at definition time. The OBJECT
;; expression is re-evaluated on every call.
(define-syntax define-send
  (syntax-rules ()
    [(_ method type object)
     (define method
       (let ([g (generic type method)])
         (lambda arguments
           (send-generic object g . arguments))))]))

;; Map PROC pairwise over vector V and list L, returning a new vector whose
;; i-th element is (proc (vector-ref v i) (list-ref l i)).
;; Raises an error unless L has exactly as many elements as V. Previously,
;; extra list elements were silently ignored, contradicting the message.
(define (my-vl-map proc v l)
  (let* ([n (vector-length v)]
         [out (make-vector n 0)])
    (let loop ([i 0] [remaining l])
      (cond
        [(= i n)
         (unless (null? remaining)
           (error "List must have same number of members as vector."))
         out]
        [(null? remaining) (error "List must have same number of members as vector.")]
        [else
         (vector-set! out i (proc (vector-ref v i) (car remaining)))
         (loop (add1 i) (cdr remaining))]))))
 
;; A non-blocking libpq connection plus the type-conversion engine it inherits.
;; Init fields:
;;   handle - ffi connection handle (already connected)
;;   input  - evt/port that becomes ready when the socket is readable
;;   output - evt/port that becomes ready when the socket is writable
(define connection%
  (class convert:base-engine%
    (inspect #f)
    (init-field handle input output)
    (inherit-field oid-size)
    (inherit set-vector-info!
             encode decode divine
             element-oid-for)

    ;; Serializes each send/receive round trip on the shared handle.
    ;; Racket semaphores are NOT reentrant. Anything that issues another
    ;; query (such as a ROLLBACK from an exception handler) must therefore
    ;; run only after this has been released. check-send guarantees that.
    (define query-serializer (make-semaphore 1))

    (super-new [integer-time (equal? "on" (ffi:connection-parameter handle "integer_datetimes"))])

    (field [cast-table
            (make-immutable-hash
             ; some sane defaults though initialize really is necessary to be sure they aren't wrong!
             '(((21 23) . #t)
               ((23 20) . #t)
               ((20 1700) . #t)
               ((21 1700) . #t)
               ((23 1700) . #t)
               ((20 700) . #t)
               ((21 700) . #t)
               ((23 700) . #t)
               ((21 26) . #t)
               ((23 26) . #t)
               ((19 25) . #t)
               ((17 25) . #t)
               ((25 17) . #t)
               ((25 1043) . #t)
               ((701 700) . #t)))]
           [oid-names
            (make-immutable-hash
             '((17   . "bytea")
               (19   . "name")
               (20   . "int8")
               (21   . "int2")
               (23   . "int4")
               (26   . "oid")
               (1700 . "numeric")))]
           [oid-sizes
            (make-immutable-hash
             '((20 . 8)
               (21 . 2)
               (23 . 4)
               (26 . 4)))])

    ;; True if a value of type FROM may be sent where type TO is expected.
    ;; Looks up (from to) in cast-table first. Failing that, it retries with
    ;; the OIDs that element-oid-for reports for both (presumably the element
    ;; types of array types).
    (define (cast-able? from to)
      (hash-ref cast-table (list from to)
                (λ () (let ((from (element-oid-for from (λ () #f)))
                            (to (element-oid-for to (λ () #f))))
                        (if (and from to)
                            (cast-able? from to)
                            #f)))))

    ;; statement-name -> prepared-statement%
    (define prepare-cache (make-immutable-hash null))

    (when (not (ffi:set-client-encoding! handle "utf-8"))
      (error "We only roll in utf-8 dudes."))

    (define version (ffi:protocol-version handle))
    (when (< version 3) (error "Please upgrade!"))

    (define/public (trace name)
      (ffi:do-trace handle name))

    ;; Replace the built-in defaults with the server's pg_cast and pg_type
    ;; catalogs, and register vector (element) info for types with an element type.
    ;; NOTE(review): the defaults key oid-names/oid-sizes by OID, but the
    ;; pg_type loop below keys them by type name, and oid-size is read with
    ;; the symbol 'oid. That only works if the "name" decoder yields symbols.
    ;; Confirm against converters/engine.ss.
    (define/public (initialize)
      (set!
       cast-table
       (make-immutable-hash
        (let ([result null])
          (send (exec "SELECT castsource,casttarget,castfunc FROM pg_cast") for-each
                (λ (source target func)
                  (when (not func) (error "Er, if not func cast-able? fails. wtf is this?"))
                  (set! result (cons (cons (list source target) func) result))))
          result)))

      (send (exec "SELECT oid,typname,typelem,typlen FROM pg_type") for-each
            (λ (oid name element-oid size)
              (set! oid-names (hash-set oid-names name oid))
              (set! oid-sizes (hash-set oid-sizes name size))
              (when (> element-oid 0)
                (let ([size (if (> size 0) size #f)])
                  (set-vector-info! oid element-oid size)))))

      (set! oid-size (hash-ref oid-sizes 'oid)))

    ;; Block until libpq has flushed all pending output to the socket.
    (define (output-wait)
      (assert-status "flushing output" handle)
      (when (ffi:flushing? handle)
        (sync output)
        (output-wait)))

    ;; Consume input until libpq is no longer busy, then return the next
    ;; result (#f once the current command's results are exhausted).
    ;; NOTE(review): sync/timeout with 0 never blocks, so this loop spins
    ;; while the server is busy.
    (define (input-wait)
      (assert-status "polling input" handle)
      (sync/timeout 0 input)
      (unless (ffi:consume-input handle)
        ;; pq-error takes (message handle); these were previously swapped.
        (pq-error "Could not consume input" handle))
      (if (ffi:is-busy? handle)
          (input-wait)
          ; only ever call start-result when not is-busy
          (begin0
            (ffi:start-result handle)
            (assert-status "polling input" handle))))

    (define last-sql #f)

    ;; Read the single raw result of the command just sent, then drain the
    ;; terminating #f that follows it. This leaves the handle idle and ready
    ;; for the next command before any result% is built.
    (define (get-raw-result)
      (assert-status "get result" handle)
      (let ([raw (input-wait)])
        (when (input-wait) (error "Got two results somehow?"))
        raw))

    ;; Wrap RAW in a result%. If construction raises (e.g. the server
    ;; reported an error), issue ROLLBACK, then let the exception continue
    ;; to propagate (the handler returns it). Must be called WITHOUT holding
    ;; query-serializer, because exec re-acquires it.
    (define (make-result sql raw)
      (call-with-exception-handler
       (λ (e) (exec "ROLLBACK") e)
       (λ () (new result% [sql sql] [result raw] [decode (λ (oid param) (decode oid param))]))))

    ;; Send a command with (apply PROC REST), wait for its result under
    ;; query-serializer, and return it as a result%.
    ;; Anything raised inside the critical section is caught and re-raised
    ;; only after the semaphore is released. Exception handlers that issue
    ;; follow-up queries therefore never deadlock on it.
    (define (check-send proc . rest)
      (let ([next-step
             (dynamic-wind
              (λ () (semaphore-wait query-serializer))
              (λ ()
                (with-handlers ([(λ (e) #t) (λ (e) (λ () (raise e)))])
                  (unless (apply proc rest)
                    (pq-error "We could not send a query!" handle))
                  (output-wait)
                  (let* ([raw (get-raw-result)]
                         [sql last-sql])
                    (λ () (make-result sql raw)))))
              (λ () (semaphore-post query-serializer)))])
        (next-step)))

    (define/public (escape-identifier identifier)
      (ffi:escape-identifier handle identifier))

    (define/public (select sql . params)
      (new cursor% [connection this] [sql sql] [params params]))

    ;; Execute SQL directly (unprepared) with PARAMS, whose type OIDs are divined.
    (define/public (exec sql . params)
      (set! last-sql sql)
      (let ([oids (map (λ (param) (divine param)) params)])
        (check-send ffi:send-query-params handle sql oids (map (λ (oid param) (encode oid param)) oids params))))

    ;; Execute SQL through a cached prepared statement.
    (define/public (p-exec sql . params)
      (let ([prep (prepare sql)])
        (send/apply prep exec params)))

    (define/public (describe name)
      (check-send ffi:send-describe-prepared handle name))

    ;; Return the prepared-statement% for SQL, creating and caching it on first use.
    (define/public (prepare sql)
      (let ([name (prepare-hash sql)])
        (hash-ref
         prepare-cache name
         (λ ()
           (let ([prepared (new prepared-statement%
                                [connection this]
                                [name name]
                                [sql sql]
                                [cast-able? cast-able?])])
             (set! prepare-cache (hash-set prepare-cache name prepared))
             prepared)))))

    (define/public (inner-prepare name sql)
      (set! last-sql sql)
      (check-send ffi:send-prepare handle name sql))

    (define transaction-level 0)
    ;; Run BODY (a thunk) inside a transaction. Nested calls share the
    ;; outermost transaction: only the outermost call issues BEGIN and the
    ;; final statement. That final statement is ROLLBACK if an exception
    ;; escaped BODY, and COMMIT on any other exit.
    (define/public (with-transaction body)
      (when (= transaction-level 0)
        (p-exec "BEGIN"))
      (set! transaction-level (+ transaction-level 1))
      (let ([failed? #f])
        (dynamic-wind
         void
         (λ ()
           ;; Note the failure, then return e so it keeps propagating unchanged.
           (call-with-exception-handler
            (λ (e) (set! failed? #t) e)
            body))
         (λ ()
           ;; Decrement first so the level stays consistent even if
           ;; COMMIT/ROLLBACK itself raises.
           (set! transaction-level (- transaction-level 1))
           (when (= transaction-level 0)
             (p-exec (if failed? "ROLLBACK" "COMMIT")))))))

    ;; Execute prepared statement NAME, encoding PARAMS against the OID vector OIDS.
    (define/public (inner-query-prepared name oids params)
      (check-send ffi:send-query-prepared handle name (my-vl-map (λ (oid param) (encode oid param)) oids params)))))

;; (with-transaction conn body ...) runs BODY ... as a thunk through
;; CONN's with-transaction method and returns what that method returns.
(define-syntax with-transaction
  (syntax-rules ()
    [(_ conn body ...)
     (send conn with-transaction (lambda () body ...))]))

;; Run THUNK between BEGIN and COMMIT on CONNECTION, returning THUNK's
;; value(s). If anything raises after BEGIN was sent, ROLLBACK is issued
;; from the exception handler. The handler then returns the exception, so
;; it keeps propagating to outer handlers.
(define (with-transaction-p connection thunk)
  (define begun? #f)
  (define (rollback-if-begun e)
    (when begun?
      (send connection p-exec "ROLLBACK"))
    e)
  (call-with-exception-handler
   rollback-if-begun
   (λ ()
     (send connection p-exec "BEGIN")
     (set! begun? #t)
     (begin0
       (thunk)
       (send connection p-exec "COMMIT")))))

; After sending any command or data on a nonblocking connection, call PQflush. If it returns 1, wait for the socket to be write-ready and call it again; repeat until it returns 0. Once PQflush returns 0, wait for the socket to be read-ready and then read the response as described above.

; NOTE: redundant -- connection% defines its own private output-wait method, and that is the one the class uses.
;; Block until libpq reports no more pending output on HANDLE, syncing on
;; OUTPUT (ready when the socket is writable) between checks.
(define (output-wait handle output)
  (let loop ()
    (when (ffi:flushing? handle)
      (sync output)
      (loop))))

;; Raise an error combining MESSAGE with HANDLE's current libpq error text.
;; When SQL is given, it is printed on its own line first. When it is
;; omitted, no SQL line is printed. (Previously a literal "#f" line appeared.)
(define (pq-error message handle [sql #f])
  (error (format "~a~a [~s]~n"
                 (if sql (format "~a~n" sql) "")
                 message
                 (ffi:error-message handle))))

;; Raise via pq-error when HANDLE's connection status is 'bad.
;; The optional SQL is now forwarded to pq-error instead of being dropped.
(define (assert-status message handle [sql #f])
  (when (eq? (ffi:connection-status handle) 'bad)
    (pq-error message handle sql)))

;; Open a non-blocking connection. Keyword arguments become libpq
;; connection parameters (via ffi:make-parameters). The handshake is driven
;; by polling until libpq reports 'ok, and the result is a connection%.
(define connect
  (make-keyword-procedure
   (λ (kws kw-args)
     ;; Drive ffi:connect-poll, waiting on the socket as it requests.
     (define (poll-until-ready handle input output)
       (let poll ()
         (define state (ffi:connect-poll handle))
         (case state
           [(reading) (sync input) (poll)]
           [(writing) (sync output) (poll)]
           [(failed) (pq-error "Connection failed while polling" handle)]
           [(ok) (new connection% [handle handle] [input input] [output output])]
           [else (error (format "Uh...huh? ~a~n" state))])))
     (let ([handle (ffi:do-connect-start (ffi:make-parameters kws kw-args))])
       (assert-status "Connection has failed" handle)
       (ffi:set-nonblocking! handle #t)
       (ffi:set-error-verbosity! handle 'verbose)
       (let-values ([(input output)
                     (ffi:socket-to-ports (ffi:connection-socket handle) "postgresql")])
         ;; Wait for writability before the first poll.
         (sync output)
         (poll-until-ready handle input output))))))

(provide connect with-transaction)