(module protocol2 mzscheme
(require (lib "etc.ss")
(lib "list.ss")
"bitbang.ss"
"protocol-structures.ss"
"exceptions.ss"
"sql-types.ss")
(provide protocol2:new
protocol2:reset
protocol2:close
protocol2:lock
protocol2:lock/key
protocol2:unlock
protocol2:encode
message-generator:current
message-generator:next
message-generator:current/next
message-generator:done?
protocol:lock:disconnected
protocol:lock:auth-required
protocol:lock:copy-in
protocol:lock:ready)
;; Source of fresh keys for protocol2:lock/key; mutated with set! there.
(define lock-counter 0)

;; The lock states stored in a protocol2's status field.  Each is simply
;; the symbol of the same name, so they compare with eq?.
(define-values (protocol:lock:disconnected
                protocol:lock:auth-required
                protocol:lock:copy-in
                protocol:lock:ready)
  (values 'disconnected 'auth-required 'copy-in 'ready))
;; protocol2: state of one backend connection.
;;   inport, outport   - backend messages are read from inport (parse-response)
;;                       and encoded messages are written to outport
;;                       (protocol2:encode)
;;   last-field-count  - column count from the most recent RowDescription;
;;                       parse-AsciiRow / parse-BinaryRow use it to size rows
;;   message-generator - generator for the current exchange; #f until the
;;                       first protocol2:reset
;;   status, key       - current lock state and lock key (see protocol2:lock,
;;                       protocol2:lock/key, protocol2:unlock); protocol2:new
;;                       starts a connection `disconnected' with key #f
(define-struct protocol2
(inport outport last-field-count message-generator status key))
;; message-generator: one link in a lazy chain of backend messages.
;;   protocol - the protocol2 the messages are read from
;;   promise  - when forced, yields (values message next-generator);
;;              #f for the terminal link
;;   done?    - #t only for the terminal link, which is created after an
;;              end-of-exchange message; #f otherwise
(define-struct message-generator (protocol promise done?))
;; end-of-exchange-message? : message -> boolean
;; True for the messages that terminate a message-generator chain:
;; ReadyForQuery and FatalErrorResponse.
(define (end-of-exchange-message? msg)
  (cond [(ReadyForQuery? msg) #t]
        [else (FatalErrorResponse? msg)]))
;; protocol2:new : input-port output-port -> protocol2
;; Wraps the two ports in a fresh connection record.  No row description
;; has been seen yet (field count 0), there is no message generator, the
;; lock state is `disconnected' and there is no lock key.
(define (protocol2:new inport outport)
  (let ([initial-field-count 0]
        [no-generator #f]
        [no-key #f])
    (make-protocol2 inport
                    outport
                    initial-field-count
                    no-generator
                    protocol:lock:disconnected
                    no-key)))
;; protocol2:reset : protocol2 -> message-generator
;; Discards whatever is left of the current exchange, then installs and
;; returns a fresh message generator for the next one.
;; Draining forces the old generator's remaining links.  Each force parses
;; one message from the backend, until the terminal (done?) link that
;; follows an end-of-exchange message is reached.
(define (protocol2:reset protocol)
  ;; The generator is #f until the first reset, in which case there is
  ;; nothing to drain.
  (let loop ([mg (protocol2-message-generator protocol)])
    (when (and mg (not (message-generator:done? mg)))
      (loop (message-generator:next mg))))
  (let ([new-mg (message-generator:new protocol)])
    (set-protocol2-message-generator! protocol new-mg)
    new-mg))
;; protocol2:close : protocol2 -> void
;; Closes both ports of the connection, output port first.
;; The input port is closed even if closing the output port raises.
;; Closing an output port may flush buffered data, which can fail on a
;; broken connection; previously that failure leaked the input port.
(define (protocol2:close protocol)
  (dynamic-wind
   void
   (lambda () (close-output-port (protocol2-outport protocol)))
   (lambda () (close-input-port (protocol2-inport protocol)))))
;; protocol2:lock : protocol2 symbol -> 0
;; Puts the connection into lock state STATUS under the shared key 0.
;; The two-argument (protocol2:unlock protocol status) check passes
;; against such a lock.  Returns the key, which is always 0.
(define (protocol2:lock protocol status)
  (let ([shared-key 0])
    (set-protocol2-status! protocol status)
    (set-protocol2-key! protocol shared-key)
    shared-key))
;; protocol2:lock/key : protocol2 symbol -> positive-integer
;; Puts the connection into lock state STATUS under a freshly allocated key
;; and returns that key.  Only a protocol2:unlock check that passes this
;; same key (and status) succeeds.
;; The counter is advanced BEFORE it is read, so keys start at 1.  Key 0 is
;; the shared key used by protocol2:lock and by the two-argument form of
;; protocol2:unlock.  Handing out 0 here would let callers that do not hold
;; the key pass the check.
(define (protocol2:lock/key protocol status)
  (set-protocol2-status! protocol status)
  (set! lock-counter (add1 lock-counter))
  (let ([key lock-counter])
    (set-protocol2-key! protocol key)
    key))
;; protocol2:unlock : (or protocol2 #f) symbol [key] -> void
;; Despite the name this does not modify the connection.  It checks that
;; PROTOCOL is in lock state STATUS under KEY (default 0), and otherwise
;; signals an error via raise-sp-user-error naming the actual state.
;; A #f protocol is treated as `disconnected' with key 0.
(define protocol2:unlock
  (case-lambda
    [(protocol status)
     ;; The two-argument form checks against the shared key 0.
     (protocol2:unlock protocol status 0)]
    [(protocol status key)
     (let-values ([(current-status current-key)
                   (if protocol
                       (values (protocol2-status protocol)
                               (protocol2-key protocol))
                       (values protocol:lock:disconnected 0))])
       (if (and (eq? current-status status)
                (= current-key key))
           (void)
           (raise-sp-user-error 'lock "backend link is locked on state ~a"
                                current-status)))]))
;; protocol2:encode : protocol2 message -> any
;; Serializes MESSAGE onto the connection's output port using
;; encode-message, which also flushes the port.
(define (protocol2:encode protocol message)
  (let ([out (protocol2-outport protocol)])
    (encode-message message out)))
;; message-generator:current/next : message-generator -> (values message message-generator)
;; Forces MG's promise and returns both the message it holds and the
;; generator for the message after it.
(define (message-generator:current/next mg)
  (let ([pending (message-generator-promise mg)])
    (force pending)))
;; message-generator:current : message-generator -> message
;; The message held by MG (forcing its promise if that has not happened yet).
(define (message-generator:current mg)
  (let*-values ([(pending) (message-generator-promise mg)]
                [(message successor) (force pending)])
    message))
;; message-generator:next : message-generator -> message-generator
;; The generator that follows MG (forcing MG's promise, and thereby reading
;; MG's message, if that has not happened yet).
(define (message-generator:next mg)
  (let*-values ([(pending) (message-generator-promise mg)]
                [(message successor) (force pending)])
    successor))
;; message-generator:done? : message-generator -> boolean
;; True for the terminal link of a chain, which holds no message.
(define message-generator:done?
  (lambda (generator)
    (message-generator-done? generator)))
;; message-generator:new : protocol2 -> message-generator
;; Builds a lazy generator over the backend's messages.  Nothing is read
;; until the promise is forced.  Forcing parses one message and yields it
;; together with the generator for the following message.  After an
;; end-of-exchange message the following generator is a terminal link (no
;; promise, done? = #t), so the chain stops there.
(define (message-generator:new protocol)
  (define (successor-of message)
    (if (end-of-exchange-message? message)
        (make-message-generator protocol #f #t)
        (message-generator:new protocol)))
  (make-message-generator
   protocol
   (delay
     (let ([message (parse-message protocol)])
       (values message (successor-of message))))
   #f))
;; parse-message : protocol2 -> message
;; Reads one backend message with parse-response.  Any exn:fail raised
;; while reading or decoding it is turned into a FatalErrorResponse, which
;; (being an end-of-exchange message) terminates the message-generator
;; chain instead of propagating out of it.
(define (parse-message protocol)
  (with-handlers ([exn:fail?
                   (lambda (e)
                     (make-FatalErrorResponse
                      "FATAL"
                      ;; Report the exception's message text, not the
                      ;; printed form of the exception structure.
                      (format "Error communicating with backend: ~a"
                              (exn-message e))
                      0))])
    (parse-response protocol)))
;; encode-message : message output-port -> any
;; Writes MSG to OUTPORT in this module's (protocol 2) wire format, then
;; flushes OUTPORT.  A message whose type matches none of the clauses
;; below writes nothing; only the flush happens.
;; Notes:
;;  - AsciiRow field lengths include the 4-byte length word; BinaryRow
;;    lengths do not.  This matches parse-AsciiRow / parse-BinaryRow.
;;  - Fatal errors are written as "TYPE LEVEL: MESSAGE" (e.g.
;;    "FATAL 1: ..."), the shape parse-ErrorResponse recognises.
;;  - FatalErrorResponse and the Authentication* variants are tested
;;    before ErrorResponse / Authentication (presumably their supertypes).
(define (encode-message msg outport)
  (cond
    [(CancelRequest? msg)
     (write-int32 outport 16)
     (write-int32 outport 80877102)
     (write-int32 outport (CancelRequest-process-id msg))
     (write-int32 outport (CancelRequest-secret-key msg))]
    [(PasswordPacket? msg)
     (let [(ep (PasswordPacket-password msg))]
       ;; 4 for the length word itself + the text + 1 terminator.
       (write-int32 outport (+ 5 (string-length ep)))
       (write-tstring outport ep))]
    [(Query? msg)
     (write-char #\Q outport)
     (write-tstring outport (Query-sql msg))]
    [(StartupPacket? msg)
     ;; 296 = 4 (length) + 4 (version) + 64 + 32 + 64 + 64 + 64.
     (write-int32 outport 296)
     (write-int16 outport (car (StartupPacket-ver msg)))
     (write-int16 outport (cdr (StartupPacket-ver msg)))
     (write-limstring outport 64 (StartupPacket-db msg))
     (write-limstring outport 32 (StartupPacket-user msg))
     (write-limstring outport 64 (StartupPacket-cmdline msg))
     (write-limstring outport 64 (StartupPacket-unused msg))
     (write-limstring outport 64 (StartupPacket-tty msg))]
    [(Terminate? msg)
     (write-char #\X outport)]
    [(NotificationResponse? msg)
     (write-char #\A outport)
     (write-int32 outport (NotificationResponse-process-id msg))
     (write-tstring outport (NotificationResponse-condition msg))]
    [(AsciiRow? msg)
     (write-char #\D outport)
     (let [(fields (AsciiRow-fields msg))]
       (encode-NullFields (map sql-null? fields) outport)
       (for-each (lambda (field)
                   ;; AsciiRow lengths count the length word itself.
                   (write-int32 outport (+ 4 (string-length field)))
                   (write-astring outport field))
                 (filter (lambda (f) (not (sql-null? f))) fields)))]
    [(BinaryRow? msg)
     (write-char #\B outport)
     (let [(fields (BinaryRow-fields msg))]
       (encode-NullFields (map sql-null? fields) outport)
       (for-each (lambda (field)
                   ;; BinaryRow lengths EXCLUDE the length word
                   ;; (parse-BinaryRow reads exactly this many bytes).
                   (write-int32 outport (string-length field))
                   (write-bytes field outport))
                 (filter (lambda (f) (not (sql-null? f))) fields)))]
    [(CompletedResponse? msg)
     (write-char #\C outport)
     (write-tstring outport (CompletedResponse-command msg))]
    [(FatalErrorResponse? msg)
     (write-char #\E outport)
     (cond [(and (MessageResponse-type msg) (FatalErrorResponse-level msg))
            ;; Severity, then level, then text: "FATAL 1: message".
            (write-tstring outport
                           (format "~a ~a: ~a"
                                   (MessageResponse-type msg)
                                   (FatalErrorResponse-level msg)
                                   (MessageResponse-message msg)))]
           [(MessageResponse-type msg)
            (write-tstring outport
                           (format "~a: ~a"
                                   (MessageResponse-type msg)
                                   (MessageResponse-message msg)))]
           [else
            (write-tstring outport (MessageResponse-message msg))])]
    [(ErrorResponse? msg)
     (write-char #\E outport)
     (cond [(MessageResponse-type msg)
            (write-tstring outport
                           (format "~a: ~a"
                                   (MessageResponse-type msg)
                                   (MessageResponse-message msg)))]
           [else
            (write-tstring outport (MessageResponse-message msg))])]
    [(CopyInResponse? msg)
     (write-char #\G outport)]
    [(CopyOutResponse? msg)
     (write-char #\H outport)
     (encode-message (make-CopyDataRows (CopyOutResponse-rows msg)) outport)]
    [(EmptyQueryResponse? msg)
     (write-char #\I outport)
     (write-tstring outport (EmptyQueryResponse-unused msg))]
    [(BackendKeyData? msg)
     (write-char #\K outport)
     (write-int32 outport (BackendKeyData-process-id msg))
     (write-int32 outport (BackendKeyData-secret-key msg))]
    [(NoticeResponse? msg)
     (write-char #\N outport)
     (cond [(MessageResponse-type msg)
            (write-tstring outport
                           (format "~a: ~a"
                                   (MessageResponse-type msg)
                                   (MessageResponse-message msg)))]
           [else
            (write-tstring outport (MessageResponse-message msg))])]
    [(CursorResponse? msg)
     (write-char #\P outport)
     (write-tstring outport (CursorResponse-name msg))]
    [(AuthenticationEncryptedPassword? msg)
     (write-char #\R outport)
     (write-int32 outport 4)
     (write-bytes (AuthenticationEncryptedPassword-salt msg) outport)]
    [(AuthenticationMD5Password? msg)
     (write-char #\R outport)
     (write-int32 outport 5)
     (write-bytes (AuthenticationMD5Password-salt msg) outport)]
    [(AuthenticationSCM? msg)
     (write-char #\R outport)
     (write-int32 outport 6)
     (write-bytes (AuthenticationSCM-data msg) outport)]
    [(Authentication? msg)
     (write-char #\R outport)
     (write-int32 outport
                  (case (Authentication-method msg)
                    [(ok) 0]
                    [(kerberosV4) 1]
                    [(kerberosV5) 2]
                    [(unencrypted-password) 3]))]
    [(RowDescription? msg)
     (write-char #\T outport)
     (write-int16 outport (length (RowDescription-fields msg)))
     (for-each (lambda (fi)
                 (write-tstring outport (FieldInfo-name fi))
                 (write-int32 outport (FieldInfo-oid fi))
                 (write-int16 outport (FieldInfo-tsize fi))
                 (write-int32 outport (FieldInfo-tmod fi)))
               (RowDescription-fields msg))]
    [(ReadyForQuery? msg)
     (write-char #\Z outport)])
  (flush-output outport))
;; encode-NullFields : (listof boolean) output-port -> exact-nonnegative-integer
;; Writes the null-value bitmap that precedes row field data.
;; NULL-FIELDS holds one boolean per column, #t meaning the column is NULL.
;; Columns are packed most-significant bit first, eight per byte.  A bit is
;; SET for a non-null column and CLEAR for a null one (the convention
;; select-bit reads back), and the last byte is zero-padded.
;; Returns the result of write-bytes.
(define (encode-NullFields null-fields outport)
  (let* [(fields-length (length null-fields))
         (bytes-needed (ceiling (/ fields-length 8)))
         (bytelist
          (let byteloop [(bytes-left bytes-needed)
                         (acc '())
                         (fields-left null-fields)]
            (if (zero? bytes-left)
                (reverse acc)
                (let bitloop [(bit 7) (byte 0) (fields-left fields-left)]
                  (if (or (< bit 0) (null? fields-left))
                      ;; Byte finished (or columns exhausted).  Count it
                      ;; against BYTES-LEFT, not the fixed total, so the
                      ;; loop terminates for more than eight columns.
                      (byteloop (sub1 bytes-left)
                                (cons byte acc)
                                fields-left)
                      (bitloop (sub1 bit)
                               (if (car fields-left)
                                   byte
                                   (bitwise-ior byte (arithmetic-shift 1 bit)))
                               (cdr fields-left)))))))]
    (write-bytes (apply bytes bytelist) outport)))
;; parse-response : protocol2 -> message
;; Reads one message from the connection's input port: a one-character
;; type code followed by that message's body.
;; Copy-in, copy-out and function-call responses (#\G #\H #\V) are not
;; supported and raise an error, as do unknown codes and end of input.
;; parse-message converts such exn:fail errors into a FatalErrorResponse.
(define (parse-response protocol)
  (let* [(inport (protocol2-inport protocol))
         (c (read-char inport))]
    (cond
      ;; Report a closed connection as such, instead of as
      ;; "unknown response code #<eof>".
      [(eof-object? c)
       (error 'protocol "backend closed the connection")]
      [(eq? c #\A)
       (make-NotificationResponse (read-int32 inport)
                                  (read-tstring inport))]
      [(eq? c #\B)
       (parse-BinaryRow inport protocol)]
      [(eq? c #\C)
       (make-CompletedResponse (read-tstring inport))]
      [(eq? c #\D)
       (parse-AsciiRow inport protocol)]
      [(eq? c #\E)
       (parse-ErrorResponse inport)]
      [(eq? c #\I)
       (make-EmptyQueryResponse (read-tstring inport))]
      [(eq? c #\K)
       (make-BackendKeyData (read-int32 inport) (read-int32 inport))]
      [(eq? c #\N)
       (parse-NoticeResponse inport)]
      [(eq? c #\P)
       (make-CursorResponse (read-tstring inport))]
      [(eq? c #\R)
       (parse-Authentication inport)]
      [(eq? c #\T)
       (parse-RowDescription inport protocol)]
      [(eq? c #\Z)
       (make-ReadyForQuery)]
      [(memq c '(#\G #\H #\V))
       (error 'protocol
              "unsupported feature: copy in, copy out, or function call")]
      [else
       (error (format "unknown response code ~a" c))])))
;; parse-Authentication : input-port -> Authentication (or a subtype)
;; Decodes the body of an 'R' message.  An Int32 method code selects the
;; variant.  Codes 4, 5 and 6 are followed by 2, 4 and 6 bytes of data
;; respectively.  An unrecognised code yields (make-Authentication 'unknown).
(define (parse-Authentication port)
  (let ([method-code (read-int32 port)])
    (define (code? k) (= method-code k))
    (cond
      [(code? 0) (make-Authentication 'ok)]
      [(code? 1) (make-Authentication 'kerberosV4)]
      [(code? 2) (make-Authentication 'kerberosV5)]
      [(code? 3) (make-Authentication 'unencrypted-password)]
      [(code? 4)
       (let ([salt (read-bytes 2 port)])
         (make-AuthenticationEncryptedPassword 'encrypted-password salt))]
      [(code? 5)
       (let ([salt (read-bytes 4 port)])
         (make-AuthenticationMD5Password 'md5-password salt))]
      [(code? 6)
       (let ([data (read-bytes 6 port)])
         (make-AuthenticationSCM 'scm data))]
      [else (make-Authentication 'unknown)])))
;; parse-ErrorResponse : input-port -> (or ErrorResponse FatalErrorResponse)
;; Reads the error text and splits a leading "SEVERITY: " prefix off it.
;;  - "FATAL 1", "FATAL 2" and "FATAL" become FatalErrorResponses with
;;    level 1, 2 and #f respectively.
;;  - Any other prefix gives an ErrorResponse of that type.
;;  - Text without a prefix gives an ErrorResponse with type #f and the
;;    whole text.
(define (parse-ErrorResponse port)
  (let* [(rawmsg (read-tstring port))
         ;; Anchored, with a non-empty severity.  Unanchored, a ": " inside
         ;; the message body (e.g. "relation foo: ...") matched with an
         ;; empty severity and dropped the text in front of it.
         (fmt (regexp-match "^([A-Z0-9 ]+): (.*)" rawmsg))]
    (cond
      [(not fmt) (make-ErrorResponse #f rawmsg)]
      [(equal? (cadr fmt) "FATAL 1")
       (make-FatalErrorResponse "FATAL" (caddr fmt) 1)]
      [(equal? (cadr fmt) "FATAL 2")
       (make-FatalErrorResponse "FATAL" (caddr fmt) 2)]
      [(equal? (cadr fmt) "FATAL")
       (make-FatalErrorResponse "FATAL" (caddr fmt) #f)]
      [else (make-ErrorResponse (cadr fmt) (caddr fmt))])))
;; parse-NoticeResponse : input-port -> NoticeResponse
;; Reads the notice text and splits a leading "TYPE: " prefix off it.
;; Text without such a prefix is reported with type "NOTICE".
(define (parse-NoticeResponse port)
  (let* [(rawmsg (read-tstring port))
         ;; Anchored, with a non-empty type, so that a ": " inside the
         ;; message body is not mistaken for the type separator.
         (fmt (regexp-match "^([A-Z0-9 ]+): (.*)" rawmsg))]
    (cond [(not fmt) (make-NoticeResponse "NOTICE" rawmsg)]
          [else (make-NoticeResponse (cadr fmt) (caddr fmt))])))
;; parse-RowDescription : input-port protocol2 -> RowDescription
;; Reads an Int16 column count, then one FieldInfo (name, oid, tsize, tmod)
;; per column.  The count is stored on PROTOCOL because parse-AsciiRow and
;; parse-BinaryRow need it to parse the rows that follow.
(define (parse-RowDescription port protocol)
  (define (read-field-info index)
    (let* ([name (read-tstring port)]
           [oid (read-int32 port)]
           [tsize (read-int16 port)]
           [tmod (read-int32 port)])
      (make-FieldInfo name oid tsize tmod)))
  (let ([column-count (read-int16 port)])
    (set-protocol2-last-field-count! protocol column-count)
    (make-RowDescription (build-list column-count read-field-info))))
;; parse-AsciiRow : input-port protocol2 -> AsciiRow
;; Reads a text-format row.  It has as many columns as the last
;; RowDescription announced: a null bitmap, then a length-prefixed value
;; for each non-null column.  Null columns become sql-null.
(define (parse-AsciiRow port protocol)
  (let* ([column-count (protocol2-last-field-count protocol)]
         [present (decode-NullFields port column-count)])
    (define (read-column index)
      (cond [(select-bit present index)
             ;; The length prefix includes its own 4 bytes
             ;; (encode-message writes (+ 4 len) for AsciiRow).
             (let ([payload-length (- (read-int32 port) 4)])
               (read-limstring port payload-length))]
            [else sql-null]))
    (make-AsciiRow (build-list column-count read-column))))
;; parse-BinaryRow : input-port protocol2 -> BinaryRow
;; Reads a binary-format row.  It has as many columns as the last
;; RowDescription announced: a null bitmap, then a length-prefixed value
;; for each non-null column.  Unlike AsciiRow, the length prefix is used
;; as-is as the value's size.  Null columns become sql-null.
(define (parse-BinaryRow port protocol)
  (let* ([column-count (protocol2-last-field-count protocol)]
         [present (decode-NullFields port column-count)])
    (define (read-column index)
      (cond [(select-bit present index)
             (let ([payload-length (read-int32 port)])
               (read-limstring port payload-length))]
            [else sql-null]))
    (make-BinaryRow (build-list column-count read-column))))
;; decode-NullFields : input-port exact-nonnegative-integer -> (vectorof byte)
;; Reads the null bitmap for a row of NUMFIELDS columns, ceiling(NUMFIELDS/8)
;; bytes, and returns the bytes as a vector of integers 0-255 for select-bit.
;; The bitmap is binary data, so it is read with read-bytes rather than
;; through a text/string helper.  Every byte value, including 0 and values
;; >= 128, is preserved.
;; Raises an error if the input ends before the whole bitmap is read.
(define (decode-NullFields port numfields)
  (let* ([nbytes (ceiling (/ numfields 8))]
         [raw (read-bytes nbytes port)])
    (when (or (eof-object? raw) (< (bytes-length raw) nbytes))
      (error 'protocol "truncated null-field bitmap"))
    (list->vector (bytes->list raw))))
;; select-bit : (vectorof byte) exact-nonnegative-integer -> boolean
;; Tests bit INDEX of a bitmap stored MSB-first, eight bits per vector
;; element: bit 0 is the high bit (128) of element 0, bit 8 is the high bit
;; of element 1, and so on.
(define (select-bit bits index)
  (let* ([bit-in-byte (modulo index 8)]
         [byte (vector-ref bits (quotient (- index bit-in-byte) 8))]
         [mask (arithmetic-shift 1 (- 7 bit-in-byte))])
    (positive? (bitwise-and mask byte))))
;; parse-FunctionResult/VoidResponse : input-port -> (or FunctionResultResponse FunctionVoidResponse)
;; Parses the body of a function-call result (what follows the 'V' code):
;;   'G' Int32(size) Byte[size] '0'  -> FunctionResultResponse of the value
;;   '0'                             -> FunctionVoidResponse
;; Raises an error on any other layout.
;; NOTE(review): not called from this file (parse-response rejects 'V').
;; The '0' (ASCII digit) marker follows the PostgreSQL 2.0 protocol
;; description of FunctionResultResponse / FunctionVoidResponse.
(define (parse-FunctionResult/VoidResponse port)
  (let ([c (read-char port)])
    (cond [(eqv? c #\G)
           (let* ([size (read-int32 port)]
                  [answer (make-FunctionResultResponse
                           (read-limstring port size))])
             ;; Compare as a character: read-byte would return an integer,
             ;; which is never eq? to a char, so the check always failed.
             (unless (eqv? (read-char port) #\0)
               (error 'protocol "Expected '0' at end of FunctionResult"))
             answer)]
          [(eqv? c #\0)
           (make-FunctionVoidResponse)]
          [else
           (error 'protocol "unexpected function result code ~a" c)])))
)