#lang racket/base
(require racket/class
racket/match
racket/vector
file/md5
openssl/mzssl
"../generic/interfaces.rkt"
"../generic/sql-data.rkt"
"../generic/prepared.rkt"
"../generic/exceptions.rkt"
"msg.rkt"
"dbsystem.rkt")
(provide connection%)
;; Debug switches. When DEBUG-RESPONSES is true, every message received from
;; the backend is printed to (current-error-port). When DEBUG-SENT-MESSAGES is
;; true, every message sent to the backend, and each disconnect, is printed there.
(define DEBUG-RESPONSES #f)
(define DEBUG-SENT-MESSAGES #f)
;; Core PostgreSQL connection: speaks the frontend/backend protocol over a
;; pair of ports supplied through attach-to-ports. All port traffic is
;; serialized by a semaphore lock.
(define connection-base%
  (class* object% (connection<%> connector<%>)
    ;; Init arguments:
    ;;  notice-handler            - called, after the lock is released, with the
    ;;                              'code and 'message properties of each notice
    ;;  notification-handler      - called, after the lock is released, with the
    ;;                              condition of each notification
    ;;  allow-cleartext-password? - when false, a cleartext-password request
    ;;                              from the backend is refused with an error
    (init-private notice-handler
                  notification-handler
                  allow-cleartext-password?)

    ;; Backend ports. outport is #f whenever the connection is closed.
    (define inport #f)
    (define outport #f)
    ;; Values taken from the backend's BackendKeyData message
    ;; (see connect:expect-ready-for-query). process-id is also embedded in
    ;; generated statement names.
    (define process-id #f)
    (define secret-key #f)

    (super-new)

    ;; Evaluates body. On any exn:fail it tears the connection down without
    ;; sending Terminate, then re-raises the exception.
    (define-syntax-rule (with-disconnect-on-error . body)
      (with-handlers ([exn:fail? (lambda (e) (disconnect* #f) (raise e))])
        . body))

    ;; ---- Locking ----

    ;; wlock serializes use of the ports. While it is held, handler calls
    ;; triggered by asynchronous backend messages are queued in
    ;; delayed-handler-calls (newest first). They run only after the lock
    ;; has been released.
    (define wlock (make-semaphore 1))
    (define delayed-handler-calls null)

    ;; Acquires wlock. If require-connected? is true and the connection is
    ;; closed, releases the lock again and raises an error attributed to who.
    (define/private (lock who require-connected?)
      (semaphore-wait wlock)
      (when (and require-connected? (not outport))
        (semaphore-post wlock)
        (error who "not connected")))

    ;; Releases wlock, then runs the queued handler calls in the order their
    ;; messages arrived. The queue is built with cons, hence the reverse.
    (define/private (unlock)
      (let ([handler-calls delayed-handler-calls])
        (set! delayed-handler-calls null)
        (semaphore-post wlock)
        (for-each (lambda (p) (p)) (reverse handler-calls))))

    ;; Runs (proc) while holding wlock and returns its result(s).
    ;; The lock is released exactly once, whether proc returns or raises.
    (define/private (call-with-lock who proc
                                    #:require-connected? [require-connected? #t])
      (lock who require-connected?)
      (call-with-values
       (lambda ()
         ;; Only proc is guarded. unlock runs outside this handler, so an
         ;; exception from a queued handler call cannot cause a second unlock
         ;; (which would post the semaphore twice).
         (with-handlers ([(lambda (e) #t)
                          (lambda (e) (unlock) (raise e))])
           (proc)))
       (lambda results
         (unlock)
         (apply values results))))

    ;; ---- Low-level message I/O ----

    ;; Reads one message from the backend; disconnects if reading fails.
    (define/private (raw-recv)
      (with-disconnect-on-error
       (let ([r (parse-server-message inport)])
         (when DEBUG-RESPONSES
           (fprintf (current-error-port) " << ~s\n" r))
         r)))

    ;; Returns the next non-asynchronous message.
    ;; - Notices, notifications and parameter-status messages are handled and
    ;;   skipped.
    ;; - On ErrorResponse it first consumes ReadyForQuery (or eof), then raises
    ;;   the backend error attributed to fsym.
    (define/private (recv-message fsym)
      (let ([r (raw-recv)])
        (cond [(ErrorResponse? r)
               (check-ready-for-query fsym #t)
               (raise-backend-error fsym r)]
              [(or (NoticeResponse? r)
                   (NotificationResponse? r)
                   (ParameterStatus? r))
               (handle-async-message fsym r)
               (recv-message fsym)]
              [else r])))

    ;; Writes msg and flushes the output port.
    (define/private (send-message msg)
      (buffer-message msg)
      (flush-message-buffer))

    ;; Writes msg to the output port without flushing.
    (define/private (buffer-message msg)
      (when DEBUG-SENT-MESSAGES
        (fprintf (current-error-port) " >> ~s\n" msg))
      (with-disconnect-on-error
       (write-message msg outport)))

    ;; Flushes everything buffered so far to the backend.
    (define/private (flush-message-buffer)
      (with-disconnect-on-error
       (flush-output outport)))

    ;; Expects the next message to be ReadyForQuery. If or-eof? is true, eof
    ;; is also accepted. Anything else raises an internal error.
    (define/private (check-ready-for-query fsym or-eof?)
      (let ([r (recv-message fsym)])
        (cond [(ReadyForQuery? r) (void)]
              [(and or-eof? (eof-object? r)) (void)]
              [else
               (error fsym "internal error: backend sent unexpected message")])))

    ;; Handles a message the backend may send at any time.
    ;; - Notices and notifications queue a handler call, which runs after unlock.
    ;; - A client_encoding change away from UTF8 disconnects and raises.
    ;; - Other parameter-status messages are ignored.
    (define/private (handle-async-message fsym msg)
      (match msg
        [(struct NoticeResponse (properties))
         (set! delayed-handler-calls
               (cons (lambda ()
                       (notice-handler (cdr (assq 'code properties))
                                       (cdr (assq 'message properties))))
                     delayed-handler-calls))]
        [(struct NotificationResponse (pid condition info))
         (set! delayed-handler-calls
               (cons (lambda ()
                       (notification-handler condition))
                     delayed-handler-calls))]
        [(struct ParameterStatus (name value))
         (cond [(equal? name "client_encoding")
                (unless (equal? value "UTF8")
                  (disconnect* #f)
                  (error fsym
                         (string-append
                          "backend attempted to change the client character encoding "
                          "from UTF8 to ~a, disconnecting")
                         value))]
               [else (void)])]))

    ;; ---- Connection lifecycle ----

    ;; Politely disconnects: takes the lock and sends Terminate before closing.
    (define/public (disconnect)
      (disconnect* #t))

    ;; Closes both ports and resets them to #f.
    ;; - When no-lock-held? is true, it takes the lock and sends Terminate first.
    ;; - When false (called from error paths, possibly while the lock is held),
    ;;   it just closes the ports.
    (define/private (disconnect* no-lock-held?)
      (define politely? no-lock-held?)
      (define (go)
        (when DEBUG-SENT-MESSAGES
          (fprintf (current-error-port) " ** Disconnecting\n"))
        (when outport
          (when politely?
            (send-message (make-Terminate)))
          (close-output-port outport)
          (set! outport #f))
        (when inport
          (close-input-port inport)
          (set! inport #f)))
      (cond [politely?
             (call-with-lock 'disconnect go
                             #:require-connected? #f)]
            [else (go)]))

    ;; True when an output port is attached and has not been closed.
    (define/public (connected?)
      (let ([outport outport])
        (and outport (not (port-closed? outport)))))

    ;; Returns dbsystem, which is not defined in this file
    ;; (presumably provided by "dbsystem.rkt").
    (define/public (get-dbsystem)
      dbsystem)

    ;; Installs the ports used to talk to the backend.
    (define/public (attach-to-ports in out)
      (set! inport in)
      (set! outport out))

    ;; Sends the StartupMessage (requesting UTF8 client encoding and ISO, MDY
    ;; date style), then runs authentication until ReadyForQuery.
    ;; Disconnects on failure.
    (define/public (start-connection-protocol dbname username password)
      (with-disconnect-on-error
       (call-with-lock 'postgresql-connect
         (lambda ()
           (send-message
            (make-StartupMessage
             (list (cons "user" username)
                   (cons "database" dbname)
                   (cons "client_encoding" "UTF8")
                   (cons "DateStyle" "ISO, MDY"))))
           (connect:expect-auth username password)))))

    ;; Answers authentication requests until AuthenticationOk arrives.
    ;; Supported: cleartext (if allowed) and MD5. Everything else raises.
    (define/private (connect:expect-auth username password)
      (let ([r (recv-message 'postgresql-connect)])
        (match r
          [(struct AuthenticationOk ())
           (connect:expect-ready-for-query)]
          [(struct AuthenticationCleartextPassword ())
           (unless (string? password)
             (error 'postgresql-connect "password needed but not supplied"))
           (unless allow-cleartext-password?
             (error 'postgresql-connect (nosupport "cleartext password")))
           (send-message (make-PasswordMessage password))
           (connect:expect-auth username password)]
          [(struct AuthenticationCryptPassword (_))
           ;; crypt() passwords are not supported; crypt-password is unimplemented.
           (error 'postgresql-connect (nosupport "crypt()-encrypted password"))]
          [(struct AuthenticationMD5Password (salt))
           (unless (string? password)
             (error 'postgresql-connect "password needed but not supplied"))
           (send-message (make-PasswordMessage (md5-password username password salt)))
           (connect:expect-auth username password)]
          [(struct AuthenticationKerberosV5 ())
           (error 'postgresql-connect (nosupport "KerberosV5 authentication"))]
          [(struct AuthenticationSCMCredential ())
           (error 'postgresql-connect (nosupport "SCM authentication"))]
          [_
           (error 'postgresql-connect
                  "internal error: unknown message during authentication")])))

    ;; After authentication: records BackendKeyData and stops at ReadyForQuery.
    (define/private (connect:expect-ready-for-query)
      (let ([r (recv-message 'postgresql-connect)])
        (match r
          [(struct ReadyForQuery (status))
           (void)]
          [(struct BackendKeyData (pid secret))
           (set! process-id pid)
           (set! secret-key secret)
           (connect:expect-ready-for-query)]
          [_
           (error 'postgresql-connect
                  "internal error: unknown message after authentication")])))

    ;; ---- Queries ----

    ;; Executes stmt (a SQL string or a statement-binding) using the unnamed
    ;; portal. Returns either a recordset built through collector or a
    ;; simple-result for commands.
    (define/public (query fsym stmt collector)
      (check-statement fsym stmt)
      (let ([result
             (call-with-lock fsym
               (lambda ()
                 (query1:enqueue stmt)
                 (send-message (make-Sync))
                 (begin0 (query1:collect fsym stmt)
                   (check-ready-for-query fsym #f))))])
        (statement:after-exec stmt)
        (query1:process-result fsym collector result)))

    ;; Buffers the extended-query messages for stmt:
    ;; [Parse,] Bind, Describe, Execute, Close.
    (define/private (query1:enqueue stmt)
      (if (string? stmt)
          (begin (buffer-message (make-Parse "" stmt null))
                 (buffer-message (make-Bind "" "" null null null)))
          (let* ([pst (statement-binding-pst stmt)]
                 [pst-name (send pst get-handle)]
                 [params (statement-binding-params stmt)])
            (buffer-message (make-Bind "" pst-name null params null))))
      (buffer-message (make-Describe 'portal ""))
      (buffer-message (make-Execute "" 0))
      (buffer-message (make-Close 'portal "")))

    ;; Reads the replies to query1:enqueue. Returns
    ;; (vector 'recordset field-dvecs rows) or (vector 'command info).
    (define/private (query1:collect fsym stmt)
      (when (string? stmt)
        (match (recv-message fsym)
          [(struct ParseComplete ()) (void)]
          [other-r (query1:error fsym other-r)]))
      (match (recv-message fsym)
        [(struct BindComplete ()) (void)]
        [other-r (query1:error fsym other-r)])
      (match (recv-message fsym)
        [(struct RowDescription (field-dvecs))
         (let* ([rows (query1:data-loop fsym)])
           (query1:expect-close-complete fsym)
           (vector 'recordset field-dvecs rows))]
        [(struct NoData ())
         (let* ([command (query1:expect-completion fsym)])
           (query1:expect-close-complete fsym)
           (vector 'command command))]
        [other-r (query1:error fsym other-r)]))

    ;; Collects DataRow contents (as vectors) up to CommandComplete.
    (define/private (query1:data-loop fsym)
      (match (recv-message fsym)
        [(struct DataRow (row))
         (cons (list->vector row) (query1:data-loop fsym))]
        [(struct CommandComplete (command)) null]
        [other-r (query1:error fsym other-r)]))

    ;; For statements without result rows: an alist with the command tag,
    ;; or '() for an empty query.
    (define/private (query1:expect-completion fsym)
      (match (recv-message fsym)
        [(struct CommandComplete (command)) `((command . ,command))]
        [(struct EmptyQueryResponse ()) '()]
        [other-r (query1:error fsym other-r)]))

    (define/private (query1:expect-close-complete fsym)
      (match (recv-message fsym)
        [(struct CloseComplete ()) (void)]
        [other-r (query1:error fsym other-r)]))

    ;; Raises for an unexpected reply; COPY responses get a specific message.
    (define/private (query1:error fsym r)
      (match r
        [(struct CopyInResponse (format column-formats))
         (error fsym (nosupport "COPY IN statements"))]
        [(struct CopyOutResponse (format column-formats))
         (error fsym (nosupport "COPY OUT statements"))]
        [_ (error fsym "internal error: unexpected message")]))

    ;; Converts raw results. Rows are converted in place through the
    ;; per-column type readers, then folded with collector's
    ;; init/combine/finalize.
    (define/private (query1:process-result fsym collector result)
      (match result
        [(vector 'recordset field-dvecs rows)
         (let-values ([(init combine finalize headers?)
                       (collector (length field-dvecs) #t)])
           (let* ([type-reader-v
                   (list->vector (query1:get-type-readers fsym field-dvecs))]
                  [convert-row
                   (lambda (row)
                     (vector-map! (lambda (value type-reader)
                                    (cond [(sql-null? value) sql-null]
                                          [type-reader (type-reader value)]
                                          [else value]))
                                  row
                                  type-reader-v))])
             (recordset (and headers?
                             (map field-dvec->field-info field-dvecs))
                        (finalize
                         (for/fold ([accum init]) ([row (in-list rows)])
                           (combine accum (convert-row row)))))))]
        [(vector 'command command)
         (simple-result command)]))

    (define/private (query1:get-type-readers fsym field-dvecs)
      (map (lambda (dvec)
             (let ([typeid (field-dvec->typeid dvec)])
               (typeid->type-reader fsym typeid)))
           field-dvecs))

    ;; ---- Prepared statements ----

    ;; Parses stmt under a fresh statement name and returns a
    ;; prepared-statement% describing its parameters and result columns.
    (define/public (prepare fsym stmt close-on-exec?)
      (call-with-lock fsym
        (lambda ()
          (let ([name (generate-name)])
            (prepare1:enqueue name stmt)
            (send-message (make-Sync))
            (begin0 (prepare1:collect fsym name close-on-exec?)
              (check-ready-for-query fsym #f))))))

    (define/private (prepare1:enqueue name stmt)
      (buffer-message (make-Parse name stmt null))
      (buffer-message (make-Describe 'statement name)))

    (define/private (prepare1:collect fsym name close-on-exec?)
      (match (recv-message fsym)
        [(struct ParseComplete ()) (void)]
        [other-r (prepare1:error fsym other-r)])
      (let* ([param-typeids (prepare1:describe-params fsym)]
             [field-dvecs (prepare1:describe-result fsym)])
        (new prepared-statement%
             (handle name)
             (close-on-exec? close-on-exec?)
             (param-typeids param-typeids)
             (result-dvecs field-dvecs)
             (owner this))))

    (define/private (prepare1:describe-params fsym)
      (match (recv-message fsym)
        [(struct ParameterDescription (param-typeids)) param-typeids]
        [other-r (prepare1:error fsym other-r)]))

    ;; Result column descriptions; null if the statement returns no rows.
    (define/private (prepare1:describe-result fsym)
      (match (recv-message fsym)
        [(struct RowDescription (field-dvecs)) field-dvecs]
        [(struct NoData ()) null]
        [other-r (prepare1:error fsym other-r)]))

    (define/private (prepare1:error fsym r)
      (error fsym "internal error: unexpected message in prepare"))

    ;; stmt must be a string, or a statement-binding whose prepared statement
    ;; belongs to this connection.
    (define/private (check-statement fsym stmt)
      (unless (or (string? stmt) (statement-binding? stmt))
        (raise-type-error fsym "string or statement-binding" stmt))
      (when (statement-binding? stmt)
        (let ([pst (statement-binding-pst stmt)])
          (send pst check-owner fsym this stmt))))

    ;; Counter for statement names unique within this connection object.
    (define name-counter 0)
    (define/private (generate-name)
      (let ([n name-counter])
        (set! name-counter (add1 name-counter))
        (format "λmz_~a_~a" process-id n)))

    ;; Closes the backend statement behind pst and clears pst's handle.
    ;; Does nothing if pst has no handle or the connection is closed.
    (define/public (free-statement pst)
      (call-with-lock 'free-statement
        #:require-connected? #f
        (lambda ()
          (let ([name (send pst get-handle)])
            (when (and name outport)
              (send pst set-handle #f)
              (buffer-message (make-Close 'statement name))
              ;; send-message flushes, so Close and Sync actually reach the
              ;; backend before we block waiting for CloseComplete.
              (send-message (make-Sync))
              (let ([r (recv-message 'free-statement)])
                (cond [(CloseComplete? r) (void)]
                      [else (error 'free-statement "internal error: unexpected message")])
                (check-ready-for-query 'free-statement #t)))))))
    ))
;; Adds optional SSL negotiation to a connector's attach-to-ports.
(define ssl-connector-mixin
  (mixin (connector<%>) ()
    (super-new)
    ;; ssl selects the mode:
    ;;   'no       - plain connection
    ;;   'yes      - SSL required
    ;;   'optional - use SSL if the backend accepts it
    ;; For 'yes/'optional an SSLRequest is sent and the backend's one-byte
    ;; reply decides the outcome.
    ;; ssl-encrypt is passed to ports->ssl-ports as #:encrypt.
    ;; NOTE(review): the #f default may not be accepted by ports->ssl-ports;
    ;; confirm callers always supply it when ssl is not 'no.
    ;; On any failure both original ports are closed and the exception is
    ;; re-raised.
    (define/override (attach-to-ports in out [ssl 'no] [ssl-encrypt #f])
      (with-handlers ([(lambda _ #t)
                       (lambda (e)
                         (close-input-port in)
                         (close-output-port out)
                         (raise e))])
        (case ssl
          ((yes optional)
           (write-message (make-SSLRequest) out)
           (flush-output out)
           (let* ([response (peek-byte in)]
                  ;; peek-byte yields eof if the backend closed the
                  ;; connection. Map that to #f so it reaches the else
                  ;; branch below instead of crashing integer->char.
                  [tag (and (byte? response) (integer->char response))])
             (case tag
               ((#\S)
                (void (read-byte in))
                (let-values ([(sin sout)
                              (ports->ssl-ports in out
                                                #:mode 'connect
                                                #:encrypt ssl-encrypt
                                                #:close-original? #t)])
                  (super attach-to-ports sin sout)))
               ((#\N)
                (void (read-byte in))
                (unless (eq? ssl 'optional)
                  (error 'postgresql-connect "backend refused SSL connection"))
                (super attach-to-ports in out))
               ((#\E)
                (let ([r (parse-server-message in)])
                  (raise-backend-error 'postgresql-connect r)))
               (else
                (error 'postgresql-connect
                       "backend returned invalid response to SSL request")))))
          ((no)
           (super attach-to-ports in out))
          (else
           ;; An unknown mode previously attached nothing and left the ports
           ;; open; reject it explicitly instead.
           (raise-argument-error 'postgresql-connect
                                 "(or/c 'yes 'optional 'no)"
                                 ssl)))))))
;; Builds the error-message text for a feature this driver does not support.
(define (nosupport what)
  (let ([prefix "not supported: "])
    (string-append prefix what)))
;; Builds the reply to an AuthenticationMD5Password request:
;;   "md5" ++ md5-hex(md5-hex(password ++ user) ++ salt)
;; user and password are encoded as UTF-8, matching the UTF8 client encoding
;; this driver requests at startup. Latin-1 encoding would raise on any
;; character above U+00FF.
(define (md5-password user password salt)
  (bytes->string/latin-1
   (md5-password/bytes (string->bytes/utf-8 user)
                       (string->bytes/utf-8 password)
                       salt)))

;; Byte-level core of md5-password. salt is the byte string sent by the
;; backend. file/md5's md5 returns the hex digest as a byte string, so the
;; result is "md5" followed by 32 hex characters.
(define (md5-password/bytes user password salt)
  (let* ([s (md5 (bytes-append password user))]
         [t (md5 (bytes-append s salt))])
    (bytes-append #"md5" t)))
;; Placeholder for crypt()-based password hashing.
;; Always raises because crypt() support is not implemented.
(define crypt-password
  (lambda (pw salt)
    (error 'crypt-password "not implemented")))
;; Raises the SQL error carried by the backend ErrorResponse r, attributed to
;; who. The error's code and message come from r's 'code and 'message
;; properties, and the full property list is passed along.
(define (raise-backend-error who r)
  (let* ([props (ErrorResponse-properties r)]
         [prop (lambda (key) (cdr (assq key props)))])
    (raise-sql-error who (prop 'code) (prop 'message) props)))
;; The exported connection class: connection-base% extended by
;; ssl-connector-mixin, whose attach-to-ports can negotiate SSL.
(define connection%
(class (ssl-connector-mixin connection-base%)
(super-new)))