#lang scheme
(require (prefix-in ffi: "ffi.ss"))
(require (prefix-in convert: "converters/engine.ss"))
(require (only-in (planet vyzo/crypto) sha256))
(require net/base64)
;; prepare-hash : string -> string
;; Derives a deterministic prepared-statement name from SQL text: the prefix
;; "prep" followed by the base64 encoding of the SHA-256 digest of the SQL's
;; UTF-8 bytes.  Identical SQL text always yields the same name.
;; NOTE(review): net/base64's base64-encode ends its output with a CRLF by
;; default, so the generated names end in "\r\n" -- harmless for libpq, but
;; visible in server-side statement listings.
(define (prepare-hash sql)
  (let* ([sql-bytes (string->bytes/utf-8 sql)]
         [digest (sha256 sql-bytes)]
         [encoded (base64-encode digest)])
    (string-append "prep" (bytes->string/utf-8 encoded))))
;; make-cursor-name : connection -> string
;; Returns a fresh cursor identifier "cursor<N>", escaped through the given
;; connection's escape-identifier method.  N comes from a single counter
;; shared by every caller in this module; it is bumped only after the escape
;; call returns, so a failing escape does not consume a number.
(define make-cursor-name
  (let ([counter 0])
    (λ (conn)
      (let ([escaped (send conn escape-identifier (format "cursor~a" counter))])
        (set! counter (add1 counter))
        escaped))))
;; map-leaves : (any -> any) any -> any
;; Structure-preserving map over a tree of nested lists: every non-list value
;; is replaced by (proc value); lists (including '()) are rebuilt with the
;; same shape.  A non-list argument is treated as a single leaf.
(define (map-leaves proc tree)
  (cond
    [(list? tree)
     (for/list ([branch (in-list tree)])
       (map-leaves proc branch))]
    [else (proc tree)]))
;; maybe-protect : result value -> value
;; Byte-string values are passed through ffi:protect-with! together with the
;; owning result handle (and whatever that returns is returned); every other
;; value is returned untouched.
;; NOTE(review): presumably this ties the bytes' lifetime to the result's
;; memory -- confirm against ffi.ss.
(define (maybe-protect result value)
  (cond
    [(not (bytes? value)) value]
    [else (ffi:protect-with! value result)]))
;; result% -- wraps one raw query result handle produced by the ffi layer.
;; Init fields:
;;   result -- the raw ffi result handle
;;   decode -- procedure (oid raw-value -> value) applied to each column value
;; Public fields: status, n-rows, n-columns, fields (all read from the ffi).
;; On construction the status is checked: ok / tuples-ok pass silently;
;; empty-query / bad-response / error print a diagnostic and raise a user
;; error; warning prints a diagnostic and continues.
(define result%
  (class object%
    (inspect #f)
    (init-field result decode)
    (field [status (ffi:result-status result)]
           [n-rows (ffi:result-rows result)]
           [n-columns (ffi:result-columns result)])
    ;; Display "<message> <status name>\n<server error text>\n" and return it.
    (define (warn message)
      (let ([report (format "~a ~a~n~a~n"
                            message
                            (ffi:result-status->string status)
                            (ffi:result-error-message result))])
        (display report)
        report))
    ;; Display the diagnostic (via warn) and raise it as a user error.
    (define (die message)
      (let ([report (warn message)])
        (raise-user-error report)))
    (case status
      [(ok tuples-ok) (void)]
      [(empty-query bad-response error) (die "result failed to init")]
      [(warning) (warn "init warning")])
    (super-new)
    ;; Call row-handler once per row, with that row's decoded (and, for byte
    ;; strings, protected) column values as positional arguments.
    (define/public (for-each row-handler)
      (define (handle-row row)
        (let* ([decoded (dict-map row decode)]
               [protected (map (λ (v) (maybe-protect result v)) decoded)])
          (apply row-handler protected)))
      (ffi:result-for-each-row handle-row result n-rows n-columns))
    ;; Return the whole result as nested lists; each leaf cell is an
    ;; (oid . raw-value) pair that gets decoded and protected.
    (define/public (get-matrix)
      (let ([convert (λ (cell) (maybe-protect result (decode (car cell) (cdr cell))))]
            [cells (ffi:result-matrix result n-rows n-columns)])
        (map-leaves convert cells)))
    (field [fields (ffi:result-fields result)])))
;; cursor% -- a server-side cursor over `sql`, declared when the object is
;; constructed.
;; Init fields:
;;   connection -- the connection% that owns the cursor
;;   sql        -- query text the cursor iterates over
;;   params     -- list of parameter values passed along with the DECLARE
;; Public field: name -- the escaped cursor identifier.
;; NOTE(review): PostgreSQL only allows DECLARE ... CURSOR (without WITH HOLD)
;; inside a transaction block; presumably callers wrap cursor use in
;; with-transaction -- confirm.
(define cursor%
  (class object%
    (inspect #f)
    (init-field connection sql params)
    (super-new)
    (field [name (make-cursor-name connection)])
    (send/apply connection p-exec (string-append "DECLARE " name " CURSOR FOR " sql) params)
    ;; Fetch the next single row; returns a result% with 0 or 1 rows.
    (define (inner-fetch-one)
      (send connection p-exec (string-append "FETCH " name)))
    ;; Return a result% holding exactly one row; raises a user error when the
    ;; cursor is exhausted.
    (define/public (fetch-one)
      (let ([result (inner-fetch-one)])
        (if (not (= (get-field n-rows result) 1)) (raise-user-error "No more rows to fetch. use for-each dumbass")
            result)))
    ;; Fetch up to n rows as one result%.  `name` is already an escaped
    ;; identifier, so it is spliced with ~a exactly like the other statements
    ;; in this class (~s would wrap it in Scheme string quotes).
    (define/public (fetch-many n)
      (send connection p-exec (format "FETCH ~a ~a" n name)))
    ;; Fetch every remaining row as one result%.
    (define/public (fetch-rest)
      (send connection p-exec (string-append "FETCH ALL " name)))
    ;; Call row-handler on each remaining row, fetching one row per round trip
    ;; until a fetch returns no rows.
    (define/public (for-each row-handler)
      (let loop ([result (inner-fetch-one)])
        (if (= (get-field n-rows result) 0) (void)
            (begin
              (send result for-each row-handler)
              (loop (inner-fetch-one))))))
    ;; Release the server-side cursor.  The object must not be used afterwards.
    (define/public (close)
      (send connection p-exec (string-append "CLOSE " name)))))
;; prepared-statement% -- a server-side prepared statement plus the parameter
;; OID template it was prepared with.
;; Init fields:
;;   connection -- object providing inner-prepare, divine, inner-query-prepared
;;   name       -- statement name used on the server
;;   sql        -- statement text
;;   template   -- list of parameter OIDs the statement was prepared with
;;   cast-able? -- (from-oid to-oid -> any); non-#f when from casts to to
;; The statement is prepared on the server when the object is constructed.
(define prepared-statement%
  (class object%
    (inspect #f)
    (init-field connection name sql template cast-able?)
    (super-new)
    ;; Kept for its side effect: prepares the statement on the server.
    (define info (send connection inner-prepare name sql template))
    ;; Verify that params has exactly as many values as the template has OIDs,
    ;; and that each value's OID (per the connection's divine) equals the
    ;; template OID or is cast-able to it.  Raises exn:fail on any mismatch.
    (define (check-params params)
      (let loop ([params params] [oids template])
        (cond
          [(and (null? params) (null? oids)) (void)]
          [(null? params) (error "Not enough parameters provided to this statement!")]
          [(null? oids) (error "Too many parameters provided to this statement!")]
          [else
           (let ([param-oid (send connection divine (car params))]
                 [wanted-oid (car oids)])
             ;; OIDs are numbers, so compare with eqv? rather than eq?.
             (unless (or (eqv? param-oid wanted-oid)
                         (cast-able? param-oid wanted-oid))
               (error (format "The OID for ~a(~a) could not match template OID ~a~n" (car params) param-oid wanted-oid)))
             (loop (cdr params) (cdr oids)))])))
    ;; Execute the prepared statement with the given positional parameters
    ;; after validating them; returns whatever inner-query-prepared returns.
    (define/public (exec . params)
      (check-params params)
      (send connection inner-query-prepared name template params))))
;; (this-or a b ...) -- evaluates `a` exactly once; if it is non-#f its value
;; is the result, otherwise the `b` expressions are evaluated in order and the
;; last one's value is the result.  Using `or` avoids the double evaluation
;; of `a` that an (if a a ...) expansion would cause.
(define-syntax this-or
  (syntax-rules ()
    [(_ a b ...) (or a (begin b ...))]))
;; (define-send method type object) -- defines `method` as a procedure that
;; forwards all of its arguments to the method of the same name on `object`.
;; The generic for (type, method) is built once, at definition time; the
;; `object` expression is re-evaluated on every call.
(define-syntax define-send
  (syntax-rules ()
    [(_ method type object)
     (define method
       (let ([method-generic (generic type method)])
         (lambda arguments
           (send-generic object method-generic . arguments))))]))
;; connection% -- one libpq connection driven in nonblocking mode.
;; Init fields:
;;   handle -- the ffi connection handle
;;   input  -- port derived from the connection socket; only ever synced on
;;   output -- port derived from the connection socket; only ever synced on
;; Construction sets the client encoding to utf-8 (raising if that fails) and
;; raises if the server protocol version is below 3.  Call `initialize` to load
;; the cast table, type-name table and array element info from the catalogs.
(define connection%
  (class convert:base-engine%
    (inspect #f)
    (init-field handle input output)
    (inherit-field oid-size)
    (inherit set-vector-info!
             encode decode divine)
    (super-new [integer-time (equal? "on" (ffi:connection-parameter handle "integer_datetimes"))])
    ;; cast-table : (cons source-oid target-oid) -> castfunc, filled by initialize
    ;; oid-names  : typname -> oid, filled by initialize
    (field [cast-table (make-immutable-hash null)]
           [oid-names (make-immutable-hash null)])
    ;; Returns the recorded castfunc (non-#f) when pg_cast lists a cast from
    ;; `from` to `to`, otherwise #f.
    (define (cast-able? from to)
      (hash-ref cast-table (cons from to) (λ () #f)))
    ;; statement name -> prepared-statement%
    (define prepare-cache (make-immutable-hash null))
    (when (not (ffi:set-client-encoding! handle "utf-8"))
      (error "We only roll in utf-8 dudes."))
    (define version (ffi:protocol-version handle))
    (when (< version 3) (error "Please upgrade!"))
    ;; Load catalog information: cast pairs from pg_cast, type names from
    ;; pg_type (also setting oid-size to the OID of the type named 'oid), and
    ;; element OID / length for every pg_type row with a typelem.
    ;; NOTE(review): assumes typname decodes to a symbol, given the 'oid key.
    (define/public (initialize)
      (send (exec "SELECT castsource,casttarget,castfunc FROM pg_cast") for-each
            (λ (source target func)
              (when (not func) (error "Er, if not func cast-able? fails. wtf is this?"))
              (set! cast-table (hash-set cast-table (cons source target) func))))
      (send (exec "SELECT oid,typname FROM pg_type") for-each
            (λ (oid name)
              (set! oid-names (hash-set oid-names name oid))))
      (set! oid-size (hash-ref oid-names 'oid))
      (send (exec "SELECT oid,typelem,typlen FROM pg_type") for-each
            (λ (oid element-oid length)
              (when (> element-oid 0)
                ;; Variable-length types report a non-positive typlen; store #f.
                (let ([length (if (> length 0) length #f)])
                  (set-vector-info! oid element-oid length))))))
    ;; Block until libpq has flushed all queued output to the socket.
    (define (output-wait)
      (assert-status "flushing output" handle)
      (when (ffi:flushing? handle)
        (sync output)
        (output-wait)))
    ;; Return the next result from ffi:start-result (#f once the command's
    ;; results are exhausted).  Pending socket data is consumed first; only
    ;; while libpq still reports busy do we block on the input port, so the
    ;; loop sleeps instead of spinning, and never blocks when libpq has
    ;; already buffered everything it needs.
    (define (input-wait)
      (assert-status "polling input" handle)
      (check-send (ffi:consume-input handle))
      (cond
        [(ffi:is-busy? handle)
         (sync input)
         (input-wait)]
        [else
         (begin0
           (ffi:start-result handle)
           (assert-status "polling input" handle))]))
    ;; Read the single result of the command just sent and wrap it in result%.
    ;; The trailing end-of-results marker is drained BEFORE the wrapper is
    ;; built: result% raises on error statuses, and raising before the drain
    ;; would leave the connection stuck in the middle of a command.
    (define (get-result)
      (assert-status "get result" handle)
      (let ([raw (input-wait)])
        (when (input-wait) (error "Got two results somehow?"))
        (new result% [result raw] [decode (λ (oid param) (decode oid param))])))
    ;; Raise a pq-error unless the ffi call reported success.
    (define (check-send ok)
      (unless ok
        (pq-error "We could not send a query!" handle)))
    (define/public (escape-identifier identifier)
      (ffi:escape-identifier handle identifier))
    ;; Open a server-side cursor over sql with the given parameters.
    (define/public (select sql . params)
      (new cursor% [connection this] [sql sql] [params params]))
    ;; Run sql once, unprepared.  Each parameter's OID comes from divine and its
    ;; wire value from encode.  Returns a result%.
    (define/public (exec sql . params)
      (let ([oids (map (λ (param) (divine param)) params)])
        (check-send (ffi:send-query-params handle sql oids (map (λ (oid param) (encode oid param)) oids params))))
      (output-wait)
      (get-result))
    ;; Prepare (or reuse from the cache) a statement for sql and execute it
    ;; with params.  Returns a result%.
    (define/public (p-exec sql . params)
      (let ([prep (prepare sql params)])
        (send/apply prep exec params)))
    ;; Return the cached prepared-statement% for sql, creating it (which
    ;; prepares it on the server) on first use.  The template OIDs come from
    ;; divining the values in parameter-template.
    ;; NOTE(review): the cache key and server name depend only on the SQL
    ;; text, so the template fixed by the first call is reused later.
    (define/public (prepare sql parameter-template)
      (let ([name (prepare-hash sql)])
        (hash-ref
         prepare-cache name
         (λ ()
           (let ([prepared (new prepared-statement%
                                [connection this]
                                [name name]
                                [sql sql]
                                [cast-able? cast-able?]
                                [template (map (λ (param) (divine param)) parameter-template)])])
             (set! prepare-cache (hash-set prepare-cache name prepared))
             prepared)))))
    ;; Prepare sql under name with the given parameter OIDs; returns a result%.
    (define/public (inner-prepare name sql oids)
      (check-send (ffi:send-prepare handle name sql oids))
      (output-wait)
      (get-result))
    ;; Execute the prepared statement name with encoded params; returns a result%.
    (define/public (inner-query-prepared name oids params)
      (check-send (ffi:send-query-prepared handle name oids (map (λ (oid param) (encode oid param)) oids params)))
      (output-wait)
      (get-result))
    ))
;; (with-transaction connection body ...) -- runs body inside a transaction on
;; connection and returns the body's value.
;; BEGIN is sent on entry.  COMMIT is sent only when the body returns
;; normally.  ROLLBACK is sent when control leaves the body any other way
;; (a raised exception or a continuation jump), so partial work is never
;; committed.  The connection expression is evaluated once.
(define-syntax with-transaction
  (syntax-rules ()
    [(_ connection commands ...)
     (let ([conn connection]
           [finished? #f])
       (dynamic-wind
        (λ () (send conn p-exec "BEGIN"))
        (λ ()
          (begin0
            (let () commands ...)
            (set! finished? #t)))
        (λ () (send conn p-exec (if finished? "COMMIT" "ROLLBACK")))))]))
;; output-wait : handle port -> void
;; Repeatedly blocks on the output port while ffi:flushing? reports that the
;; connection still has data queued to send.
(define (output-wait handle output)
  (let drain ()
    (when (ffi:flushing? handle)
      (sync output)
      (drain))))
;; pq-error : string handle -> (raises)
;; Raises exn:fail whose message is the given text followed by the
;; connection's current error message, written in brackets.
(define (pq-error message handle)
  (let ([detail (ffi:error-message handle)])
    (error (format "~a [~s]~n" message detail))))
;; assert-status : string handle -> void
;; Raises via pq-error (with the given context message) when the connection's
;; status is 'bad; otherwise does nothing.
(define (assert-status message handle)
  (let ([status (ffi:connection-status handle)])
    (when (eq? status 'bad)
      (pq-error message handle))))
;; connect : #:keyword value ... -> connection%
;; Opens a connection asynchronously.  Keyword names and values are handed to
;; ffi:make-parameters; positional arguments are not accepted.  The handle is
;; switched to nonblocking mode with verbose errors.  Socket readiness events
;; then drive ffi:connect-poll until it reports 'ok, at which point a
;; connection% is returned.  Raises when the status is bad or polling fails.
;; NOTE(review): `initialize` is not called here on the new connection%;
;; presumably callers do so -- confirm.
(define connect
  (make-keyword-procedure
   (λ (keywords kw-values)
     (let ([handle (ffi:do-connect-start (ffi:make-parameters keywords kw-values))])
       (assert-status "Connection has failed" handle)
       (ffi:set-nonblocking! handle #t)
       (ffi:set-error-verbosity! handle 'verbose)
       (call-with-values
        (λ () (ffi:socket-to-ports (ffi:connection-socket handle) "postgresql"))
        (λ (input output)
          ;; Wait for writability before the first poll.
          (sync output)
          (let poll ()
            (let ([state (ffi:connect-poll handle)])
              (cond
                [(eq? state 'reading) (sync input) (poll)]
                [(eq? state 'writing) (sync output) (poll)]
                [(eq? state 'failed) (pq-error "Connection failed while polling" handle)]
                [(eq? state 'ok) (new connection% [handle handle] [input input] [output output])]
                [else (error (format "Uh...huh? ~a~n" state))])))))))))
;; Public API: `connect` opens a connection (its keyword arguments are passed
;; to ffi:make-parameters); `with-transaction` runs a body inside a transaction.
(provide connect with-transaction)
;; Re-export the ffi layer's with-tracing-to under its unprefixed name.
(provide
(rename-out (ffi:with-tracing-to with-tracing-to))
)