; ffi.ss
#lang scheme

(require (planet synx/pointer-address))
(require (prefix-in log: (planet synx/log)))

(require scheme/foreign)
(require srfi/43)

; Enable the unsafe bindings exported by scheme/foreign for this module.
(unsafe!)

; (ffi-lib #f) is a handle on the symbols already loaded into the running
; process; it is used below to find scheme_socket_to_ports, fopen and fclose.
(define *scheme* (ffi-lib #f))

; Handle on the PostgreSQL client library; every define-pq binding in this
; file is looked up in it.
(define *lib* (ffi-lib "libpq"))

; Inverted boolean for C calls that report success as 0 and failure as a
; non-zero value (in practice -1): a true Scheme value is sent as 0, a false
; one as -1, and a C result is #t exactly when it is 0.
(define _antibool
  (make-ctype _int
              (λ (ok?) (if ok? 0 -1))
              (λ (code) (zero? code))))

;(define (guard finalizer name address)
;  ; only allow finalization once!
;  (display (format "Guarding ~a ~s~n" name address))
;  (let ([applied #f])
;    (λ (value)
;      (if applied (error (format "Attempted to apply ~s twice" finalizer))
;          (begin
;            (set! applied #t)
;            (when (not (= (get-pointer-address value) address))
;              (error (format "Uh oh the value changed before finalizing? ~s != ~s" address (get-pointer-address value))))
;            (display (format "Finalizing ~a ~s~n" name address))
;            (finalizer value))))))

; Pass-through stand-in for the debugging guard commented out above: the
; name and address arguments are accepted and ignored, and the finalizer is
; handed back untouched.
(define (guard finalizer name address)
  finalizer)

; Wraps an already-open socket descriptor in a pair of Scheme ports by calling
; the runtime's own scheme_socket_to_ports.
;   socket : integer descriptor (e.g. the value of connection-socket)
;   name   : string used to name the resulting ports
; Returns two values: the input port and the output port.
; The `close` flag is always passed as #f, so the ports are not asked to take
; over closing the descriptor.
(define socket-to-ports
  (get-ffi-obj "scheme_socket_to_ports" *scheme*
               (_fun (socket name)
                     ::
                     ; the C prototype declares the descriptor as a
                     ; long/intptr_t, not an int
                     (socket : _long)
                     ; the C prototype declares `const char *name`; the string
                     ; must be marshalled as a C string, not handed over as a
                     ; raw Scheme object pointer (which is what _scheme does)
                     (name : _string)
                     (close : _bool = #f)
                     (in-port : (_ptr o _scheme))
                     (out-port : (_ptr o _scheme))
                     -> _void
                     -> (values in-port out-port))))

; Tagged pointer type standing in for a C FILE*; produced by open-c-file and
; consumed by close-c-file and trace.
(define-cpointer-type _cfile)

; goddamit Freehaven...

; Binding for C fopen, looked up in the running process.
;   name : path of the file to open
;   mode : fopen mode string
; Returns a _cfile pointer.
(define open-c-file
  (get-ffi-obj "fopen" *scheme*
               (_fun _path _string -> _cfile)))

; Binding for C fclose; the C return value is discarded.
; The FILE* is never used from Scheme directly: it is only handed to PQtrace.
(define close-c-file
  (get-ffi-obj "fclose" *scheme*
               (_fun (file : _cfile) -> _void)))

; Tagged pointer types for the opaque libpq handles. Besides the type itself,
; each define-cpointer-type also yields the /null variant (e.g.
; _connection/null, which additionally admits #f) and the predicate
; (connection?, result?, cancel?) used by the contracts at the end of the file.
(define-cpointer-type _connection)
(define-cpointer-type _result)
;(define _result _int64)
;(define _result/null _int64)
;(define result? integer?)
(define-cpointer-type _cancel)

; Translates a dashed Scheme symbol into the matching libpq C name:
; the text is split on "-", the first piece is kept verbatim, each later piece
; is title-cased, and the whole thing is prefixed with "PQ".
;   'connect-start -> "PQconnectStart"
; Title-casing lowercases the tail of every later piece, so C names with inner
; capitals must be given as a single dash-free piece.
(define (pq-name-format symbol)
  (define words (regexp-split #rx"-" (symbol->string symbol)))
  (define capitalized-tail (map string-titlecase (cdr words)))
  (apply string-append (list* "PQ" (car words) capitalized-tail)))

; (define-pq name type)           binds `name` to the libpq function whose C
;                                 name is (pq-name-format 'name).
; (define-pq name uglyname type)  binds `name` to the libpq function whose C
;                                 name is (pq-name-format 'uglyname); used when
;                                 the Scheme name and the C name differ.
; In both cases the object is fetched from *lib* with the given ffi type.
(define-syntax define-pq
  (syntax-rules ()
    [(_ name type) (define-pq name name type)]
    [(_ name uglyname type) (define name (get-ffi-obj (pq-name-format 'uglyname) *lib* type))]))

; Builds a libpq connection string ("key=value key=value ...") from parallel
; lists.
;   names  : list of keywords; each becomes the key (without the #: prefix)
;   values : list of strings or integers, one per name
; Returns the space-separated string, or "" when both lists are empty.
; Raises an error for a value that is neither a string nor an integer.
;
; String values are wrapped in single quotes when they are empty or contain
; whitespace, a single quote or a backslash; inside the quotes every single
; quote and backslash is preceded by a backslash, as the conninfo syntax
; requires. Other strings are emitted unchanged.
(define make-parameters
  (λ (names values)
    ; does this character need a backslash inside a quoted value?
    (define (special? c)
      (or (char=? c #\') (char=? c #\\)))
    ; must this string value be wrapped in single quotes?
    (define (needs-quoting? v)
      (or (string=? v "")
          (regexp-match #px"[\\s]" v)
          (ormap special? (string->list v))))
    ; backslash-escape quotes and backslashes
    (define (escape v)
      (apply string-append
             (map (λ (c) (if (special? c) (string #\\ c) (string c)))
                  (string->list v))))
    ; one "key=value" fragment
    (define (render n v)
      (string-append
       (keyword->string n) "="
       (cond
         [(string? v)
          (if (needs-quoting? v)
              (string-append "'" (escape v) "'")
              v)]
         [(integer? v)
          (number->string v)]
         [else (error (format "Strange connection parameter ~a:~s~n" n v))])))
    (let ([pairs (map render names values)])
      (if (null? pairs)
          ; always hand back a string, even when there is nothing to join
          ""
          (foldl (λ (pair head) (string-append head " " pair))
                 (car pairs)
                 (cdr pairs))))))

; Symbolic views of libpq's C enums. _enum numbers the symbols in list order,
; so each list has to follow the order of the corresponding C enum
; (NOTE(review): presumably ConnStatusType, ExecStatusType,
; PostgresPollingStatusType, PGTransactionStatusType and PGVerbosity --
; confirm against the libpq-fe.h of the installed libpq).
(define _status (_enum '(ok bad started made awaiting-response auth-ok setenv ssl-startup needed)))
(define _result-status (_enum '(empty-query ok tuples-ok copy-out copy-in bad-response warning error)))
(define _polling-status (_enum '(failed reading writing ok active)))
(define _transaction-status (_enum '(idle active in-trans in-error unknown)))
(define _verbosity (_enum '(terse default verbose)))

; Connection lifecycle: PQconnectStart / PQconnectPoll / PQfinish.
; connect-start uses the /null pointer type, so it can yield #f;
; do-connect-start is the checked wrapper around it.
(define-pq connect-start (_fun (parameters : _string) -> _connection/null))
(define-pq connect-poll (_fun _connection -> _polling-status))
(define-pq finish (_fun _connection -> _void))

; PQparameterStatus and PQprotocolVersion.
(define-pq connection-parameter parameter-status (_fun _connection _string -> _string))
(define-pq protocol-version (_fun _connection -> _int))

; PQsetClientEncoding (0-is-success result, hence _antibool) and
; PQsetErrorVerbosity (result has the same enum type as the argument).
(define-pq set-client-encoding! set-client-encoding (_fun _connection _string -> _antibool))
(define-pq set-error-verbosity! set-error-verbosity (_fun _connection _verbosity -> _verbosity))

; Raw binding for PQescapeStringConn.
; Scheme call order is (connection identifier result length), but the values
; are passed to C in the order listed after `::`:
;   connection, result (destination buffer), identifier (source text),
;   length, and an output int for the error flag.
; Returns two values: the size reported by C and the error flag.
; NOTE(review): length and the return value are declared as 32-bit integers
; here; the C prototype presumably uses size_t -- confirm for 64-bit targets.
(define-pq escape-string-conn
  (_fun (connection identifier result length)
        ::
        (connection : _connection)
        (result : _pointer)
        (identifier : _string)
        (length : _int)
        (error : (_ptr o _int))
        -> (size : _uint32)
        -> (values size error)))

; Escapes `identifier` for the given connection via escape-string-conn and
; returns the escaped text as a fresh string.
;   connection : open connection whose encoding rules are applied
;   identifier : string to escape
; When the C call flags an error a warning is printed, and whatever was
; written to the buffer is still returned.
(define (escape-identifier connection identifier)
  ; The C side counts BYTES of the source text, not characters; the two differ
  ; as soon as the identifier holds non-ASCII characters.
  ; (assumes _string marshals as UTF-8, its default -- confirm if
  ; default-_string-type is ever changed)
  (let* ([length (bytes-length (string->bytes/utf-8 identifier))]
         ; worst case every byte is doubled, plus one byte for the
         ; terminating NUL that is always written
         [buffer (malloc (+ (* 2 length) 1))])
    ; Note: (malloc) defaults to using COLLECTABLE memory, so no finalizer
    ; is needed for the buffer.
    (let-values ([(size error) (escape-string-conn connection identifier buffer length)])
      (when (not (= error 0))
        (display (format "warning ~s may be malformed identifier~n" identifier)))
      ; make-sized-byte-string does not copy; bytes->string/utf-8 makes the
      ; only copy
      (bytes->string/utf-8 (make-sized-byte-string buffer size)))))

; PQtrace / PQuntrace: start tracing a connection to a C FILE*, and stop it.
(define-pq trace (_fun _connection _cfile -> _void))
(define-pq untrace (_fun _connection -> _void))

; (with-tracing-to (connection name) commands ...)
; Syntactic sugar for with-tracing-proc: wraps the commands in a thunk and
; runs them with `connection` traced to the file `name`.
(define-syntax with-tracing-to
  (syntax-rules ()
    [(_ (connection name) commands ...) (with-tracing-proc connection name (λ () commands ...))]))
 
; Module-wide trace file, opened on the first call to do-trace and reused by
; every later call.
(define *cfile* #f)

; Starts tracing `connection` to the shared trace file.
; `name` is only used the first time, when the file is opened in "at" mode
; and a finalizer that closes it is registered; later calls ignore `name`.
(define (do-trace connection name)
  (unless *cfile*
    (let ([opened (open-c-file name "at")])
      (set! *cfile* opened)
      (register-finalizer opened
                          (guard close-c-file
                                 (format "file ~s" name)
                                 (get-pointer-address opened)))))
  (trace connection *cfile*))
       
; Runs `thunk` with `connection` traced to the file `name`.
; The file is opened in "at" mode and gets a finalizer that closes it.
; Tracing is switched on whenever control enters the thunk and off whenever
; control leaves it (dynamic-wind). Returns whatever the thunk returns.
(define (with-tracing-proc connection name thunk)
  (define cfile (open-c-file name "at"))
  (define (start-tracing) (trace connection cfile))
  (define (stop-tracing) (untrace connection))
  (register-finalizer cfile
                      (guard close-c-file
                             (format "file ~s" name)
                             (get-pointer-address cfile)))
  (dynamic-wind start-tracing thunk stop-tracing))

; PQgetCancel: obtains a cancel handle for a connection.
(define-pq get-cancel (_fun _connection -> _cancel))
; Raw PQcancel: takes the cancel handle, a byte buffer the C side writes an
; error message into, and that buffer's size. Wrapped by `cancel`.
(define-pq ffi-cancel cancel (_fun _cancel _bytes _int -> _bool))
; Asks the server to cancel the command in progress on the connection the
; cancel handle `c` was obtained from.
; Returns two values: the boolean result of the C call, and the message the C
; call left in its error buffer ("" when nothing was written).
(define (cancel c)
  (let* ([capacity 256]
         ; zero-filled so an untouched buffer reads as an empty message
         [errbuf (make-bytes capacity 0)]
         [result (ffi-cancel c errbuf capacity)]
         ; the C side writes a NUL-terminated string; decode only up to the
         ; terminator instead of dragging the padding NULs into the message
         [end (let scan ([i 0])
                (if (or (= i capacity) (zero? (bytes-ref errbuf i)))
                    i
                    (scan (+ i 1))))])
    ; #\? replaces bytes that are not valid UTF-8 instead of raising
    (values result (bytes->string/utf-8 errbuf #\? 0 end))))
; PQfreeCancel: releases a cancel handle obtained from get-cancel.
(define-pq free-cancel (_fun _cancel -> _void))

; Begins connecting with the given connection string.
; Returns the connection handle, with `finish` registered as its finalizer.
; Raises an error when connect-start yields no handle.
(define (do-connect-start parameters)
  (define handle (connect-start parameters))
  (unless handle
    (error "Could not start connecting"))
  (register-finalizer handle
                      (guard finish "connection" (get-pointer-address handle)))
  handle)

; Result and connection inspection.
; result-clear -> PQclear, connection-status -> PQstatus,
; result-status -> PQresultStatus, result-status->string -> PQresStatus,
; result-error-message -> PQresultErrorMessage.
(define-pq result-clear clear (_fun _result -> _void))
(define-pq connection-status status (_fun _connection -> _status))
(define-pq result-status (_fun _result -> _result-status))
(define-pq result-status->string res-status (_fun _result-status -> _string))
(define-pq result-error-message (_fun _result -> _string))

; PQnparams / PQparamtype: parameter count of a result, and the integer type
; of the parameter at the given index.
(define-pq n-params nparams (_fun _result -> _int))
(define-pq param-type paramtype (_fun _result _int -> _int))

; PQerrorMessage / PQsocket for a connection.
(define-pq error-message (_fun _connection -> _string))
(define-pq connection-socket socket (_fun _connection -> _int))
; this function is f*cking useless
;(define-pq send-query (_fun _connection _string -> _bool))

; PQsendQueryParams: submit a parameterised query.
; Scheme call: (send-query-params connection query oids values)
;   oids   : list of integers, one per parameter
;   values : list of byte strings, one per parameter
; The remaining C arguments are derived here: the parameter count and the
; per-parameter lengths come from `values`, every parameter format is 1, and
; the result format is 1.
; NOTE(review): oids are marshalled as signed _int; the C Oid type is
; presumably unsigned -- confirm before passing large OIDs.
(define-pq send-query-params
  (_fun (connection query oids values)
        ::
        (connection : _connection)
        (query : _string) 
        (nParams : _int = (length values))
        (oids : (_list i _int))
        (values : (_list i _bytes))
        (lengths : (_list i _int) = (map bytes-length values))
        ; we always want binary format (please don't make me write a SQL text format processor T_T)
        (formats : (_list i _int) = (map (λ (v) 1) values))
        (resultF : _int = 1)
        -> _bool))

; we might want to pre-specify oids in the future
; if we're updating a lot of integers to numeric columns for instance
; but they're all small enough to fit in an int2...
; for now just get it working -.-

; PQsendPrepare: submit a request to prepare `query` under `name`.
; Scheme call: (send-prepare connection name query)
; The parameter count is always sent as 0 and the type array as a null
; pointer, i.e. no parameter types are pre-specified.
(define-pq send-prepare 
  (_fun (connection name query) 
        ::
        (connection : _connection)
        (name : _string) 
        (query : _string) 
        (nParams : _int = 0)
        (oids : _pointer = #f)
        -> _bool))

; PQsendQueryPrepared: run the prepared statement `name`.
; Scheme call: (send-query-prepared connection name values)
;   values : vector of byte strings, one per parameter
; The parameter count and per-parameter lengths are derived from `values`;
; every parameter format is 1 and the result format is 1.
; vector-map here is the srfi/43 one, whose procedure receives the index as
; well as the element -- hence the two-argument lambda.
(define-pq send-query-prepared 
  (_fun (connection name values)
        ::
        (connection : _connection)
        (name : _string)
        (nParams : _int = (vector-length values))
        ;(oids : (_list i _int) = (map pick-an-oid items)) boy I WISH
        (values : (_vector i _bytes))
        (lengths : (_vector i _int) = (vector-map (λ (i v) (bytes-length v)) values))
        ; we always want binary format (please don't make me write a SQL text format processor T_T)
        (formats : (_vector i _int) = (build-vector nParams (λ (i) 1)))
        (resultF : _int = 1)
        -> _bool))

; PQsendDescribePrepared: request a description of the prepared statement
; with the given name on the given connection.
; Scheme call: (send-describe-prepared connection name) -> boolean
(define-pq send-describe-prepared
  (_fun _connection _string -> _bool))

; PQgetResult: next result for the connection, or #f (the /null type) when
; there is none. start-result is the wrapper that attaches a finalizer.
(define-pq get-result (_fun _connection -> _result/null))
; PQmakeEmptyPGresult: the connection argument may be #f.
(define-pq make-empty-result makeEmptyPGresult (_fun _connection/null _result-status -> _result))

; PQconsumeInput / PQisBusy.
(define-pq consume-input (_fun _connection -> _bool))
(define-pq is-busy? is-busy (_fun _connection -> _bool))

; libpq usage note: the application's main loop waits for readable data on the file descriptor identified by PQsocket. When the main loop detects input ready, it should call PQconsumeInput to read the input. It can then call PQisBusy, followed by PQgetResult if PQisBusy returns false (0).

; PQsetnonblocking (0-is-success result, hence _antibool) and PQisnonblocking.
(define-pq set-nonblocking! setnonblocking (_fun _connection _bool -> _antibool))
(define-pq is-nonblocking? isnonblocking (_fun _connection -> _bool))
; this one's semantics are still 0 -> #f
; (PQflush goes through plain _bool, so any non-zero C result reads as #t.
; NOTE(review): that presumably lumps an error result together with "still
; flushing" -- confirm against the PQflush documentation.)
(define-pq flushing? flush (_fun _connection -> _bool))

; Cell accessors. Where two integers are taken, build-result-cell passes them
; as row then column; the single-integer forms take a column.
; null-cell? -> PQgetisnull, cell-name -> PQfname, cell-value -> PQgetvalue
; (a raw pointer, not a copy), cell-type -> PQftype, cell-length -> PQgetlength.
(define-pq null-cell? getisnull (_fun _result _int _int -> _bool))
(define-pq cell-name fname (_fun _result _int -> _string))
(define-pq cell-value getvalue (_fun _result _int _int -> _pointer))
(define-pq cell-type ftype (_fun _result _int -> _int))
(define-pq cell-length getlength (_fun _result _int _int -> _int))

; Result dimensions: PQnfields (columns) and PQntuples (rows).
(define-pq result-columns nfields (_fun _result -> _int))
(define-pq result-rows ntuples (_fun _result -> _int))

; note to be careful:
; NEVER use bytes out of this after the result
; object has gone out of scope, or been cleared.
; they are direct pointers into the result object
; to save on copying time.
; Beware! I am not encapsulating the result object, for a reason!

; Fetches the next result for `connection`.
; Returns the result handle, with result-clear registered as its finalizer,
; or #f when get-result yields nothing.
(define (start-result connection)
  (define result (get-result connection))
  (when result
    (register-finalizer result
                        (guard result-clear "result" (get-pointer-address result))))
  result)

; DOOM!

; Builds the Scheme view of one cell of a result.
; Returns (cons #f #f) for a null cell, otherwise (cons type-oid bytes).
; The byte string is created with make-sized-byte-string directly over the
; pointer returned by cell-value, so it aliases memory owned by the result.
(define (build-result-cell result row column)
  (cond
    [(null-cell? result row column) (cons #f #f)]
    [else
     (let* ([data (cell-value result row column)]
            [size (cell-length result row column)]
            [type-oid (cell-type result column)])
       (cons type-oid (make-sized-byte-string data size)))]))

; Returns the list of column names of `result`, in column order.
(define (result-fields result)
  (define n-columns (result-columns result))
  (define (name-of column) (cell-name result column))
  (build-list n-columns name-of))

; Weak table mapping a decoded value to the result object it was decoded from.
; While the value is alive the table keeps the result referenced, so the
; result's finalizer cannot run underneath data that still points into it.
; Keys are compared with eq?: protection is about one particular object, and
; an equal?-keyed table would make two distinct but equal values share a
; single entry, leaving one of the two results unprotected.
(define protected (make-weak-hasheq))
; unfortunately this can only be called /after/ non-ffi decoding
; Records that `value` depends on `result` and returns `value`.
(define (protect-with! value result)
  (hash-set! protected value result)
  value)

; Probably the best traversal to use when possible.
; Calls `row-handler` once per row, in row order, with the list of that row's
; cells (as built by build-result-cell) in column order. Returns void.
(define (result-for-each-row row-handler result n-rows n-columns)
  ; cells of one row, first column first
  (define (row-cells row)
    (let collect ([column 0] [acc null])
      (if (= column n-columns)
          (reverse acc)
          (collect (+ column 1)
                   (cons (build-result-cell result row column) acc)))))
  (let loop ([row 0])
    (unless (= row n-rows)
      (row-handler (row-cells row))
      (loop (+ row 1)))))

; Left fold over the rows of `sql-result`.
; `row-handler` is called as (row-handler accumulator cells) for each row in
; row order, where `cells` is the row's cell list in column order; its return
; value becomes the next accumulator. Returns the final accumulator (`init`
; when there are no rows).
(define (result-fold row-handler init sql-result n-rows n-columns)
  ; cells of one row, first column first
  (define (row-cells row)
    (let collect ([column 0] [acc null])
      (if (= column n-columns)
          (reverse acc)
          (collect (+ column 1)
                   (cons (build-result-cell sql-result row column) acc)))))
  (let loop ([row 0] [acc init])
    (cond
      [(= row n-rows) acc]
      [else (loop (+ row 1) (row-handler acc (row-cells row)))])))

; Maps `row-handler` over the rows of `result`.
; The handler receives one argument, the row's cell list, and the handler
; results are returned as a list in row order.
(define (result-map row-handler result n-rows n-columns)
  ; result-fold accumulates newest-first, hence the final reverse
  (define (collect mapped row)
    (cons (row-handler row) mapped))
  (reverse (result-fold collect null result n-rows n-columns)))

; Materialises the whole result as a list of rows, each row being the list of
; its cells (as built by build-result-cell) in column order.
(define (result-matrix result n-rows n-columns)
  (define (row-cells row)
    (build-list n-columns
                (λ (column) (build-result-cell result row column))))
  (build-list n-rows row-cells))

(require (prefix-in c (only-in scheme/contract -> ->*)))

; Public interface. `c->` is scheme/contract's `->` under the prefix
; introduced by the require just above.
(provide/contract
 (socket-to-ports (integer? string? . c-> . (values input-port? output-port?)))
 (make-parameters ((listof keyword?) (listof any/c) . c-> . string?))

 (get-cancel (connection? . c-> . cancel?))
 (free-cancel (cancel? . c-> . void?))
 ; cancel returns two values: the success flag and the error-buffer text
 (cancel (cancel? . c-> . (values boolean? string?)))

 (do-connect-start (string? . c-> . connection?))
 (connect-poll (connection? . c-> . symbol?))
 (connection-status (connection? . c-> . symbol?))
 (error-message (connection? . c-> . string?))
 (connection-socket (connection? . c-> . integer?))

 (send-query-params (connection? string? (listof integer?) (listof bytes?) . c-> . boolean?))
 (send-prepare (connection? string? string? . c-> . boolean?))
 (send-query-prepared (connection? string? (vectorof bytes?) . c-> . boolean?))

 (send-describe-prepared (connection? string? . c-> . boolean?))
 (consume-input (connection? . c-> . boolean?))
 (is-busy? (connection? . c-> . boolean?))
 (set-nonblocking! (connection? boolean? . c-> . boolean?))
 (is-nonblocking? (connection? . c-> . boolean?))
 (flushing? (connection? . c-> . boolean?))
 (start-result (connection? . c-> . (or/c result? false?)))
 (result-rows (result? . c-> . integer?))
 (result-columns (result? . c-> . integer?))

 (result-for-each-row (procedure? result? integer? integer? . c-> . void?))
 (result-fold (procedure? any/c result? integer? integer? . c-> . any/c))
 (result-map (procedure? result? integer? integer? . c-> . (listof any/c)))
 (result-matrix (result? integer? integer? . c-> . (listof list?)))

 (result-fields (result? . c-> . (listof string?)))
 (set-client-encoding! (connection? string? . c-> . boolean?))
 (set-error-verbosity! (connection? symbol? . c-> . symbol?))
 (result-status (result? . c-> . symbol?))
 (result-status->string (symbol? . c-> . string?))
 (result-error-message (result? . c-> . string?))
 (escape-identifier (connection? string? . c-> . string?))
 ; a _string result turns a C NULL into #f, which the lookup yields for a
 ; parameter the server has not reported
 (connection-parameter (connection? string? . c-> . (or/c string? false?)))
 (protocol-version (connection? . c-> . integer?))

 (n-params (result? . c-> . integer?))
 (param-type (result? integer? . c-> . integer?))

 (protect-with! (any/c result? . c-> . any/c))
 )

(provide with-tracing-to do-trace)