private/protocol2.ss
;; Copyright 2000-2005 Ryan Culpepper
;; Released under the terms of the modified BSD license (see the file
;; COPYRIGHT for terms).

;; Implements the frontend/backend interface specified in PostgreSQL
;; documentation.  Defines many structure types to encode messages
;; from the Postgres backend, and provides methods to write messages
;; back to the backend.

(module protocol2 mzscheme
  (require (lib "etc.ss")
           (lib "list.ss")
           "bitbang.ss"
           "protocol-structures.ss"
           "exceptions.ss"
           "sql-types.ss")

  (provide protocol2:new
           protocol2:reset
           protocol2:close
           protocol2:lock
           protocol2:lock/key
           protocol2:unlock
           protocol2:encode
           
           message-generator:current
           message-generator:next
           message-generator:current/next
           message-generator:done?
           
           protocol:lock:disconnected
           protocol:lock:auth-required
           protocol:lock:copy-in
           protocol:lock:ready)
  
  ;; lock-counter : number
  ;; Module-private counter (not provided) from which protocol2:lock/key
  ;; draws lock keys; within this file it is only ever incremented.
  (define lock-counter 0)

  ;; Lock states.  These symbols are stored in a protocol2's `status' field
  ;; and compared with eq? by protocol2:unlock.
  (define protocol:lock:disconnected 'disconnected)
  (define protocol:lock:auth-required 'auth-required)
  (define protocol:lock:copy-in 'copy-in)
  (define protocol:lock:ready 'ready)
  
  ;; protocol2: per-connection protocol state.
  ;;   inport, outport   -- ports to read from / write to the backend
  ;;   last-field-count  -- field count of the most recent RowDescription;
  ;;                        used to decode subsequent AsciiRow/BinaryRow rows
  ;;   message-generator -- current message-generator; #f until the first
  ;;                        protocol2:reset
  ;;   status, key       -- lock state and lock key (see protocol2:lock,
  ;;                        protocol2:lock/key, protocol2:unlock)
  (define-struct protocol2 
    (inport outport last-field-count message-generator status key))
  ;; message-generator: one position in a lazily read stream of messages.
  ;;   protocol -- the owning protocol2
  ;;   promise  -- forces to (values current-message next-generator);
  ;;               #f for the terminal generator
  ;;   done?    -- #t only for the terminal generator that follows an
  ;;               end-of-exchange message
  (define-struct message-generator (protocol promise done?))
  
  ;; end-of-exchange-message? : Response -> boolean
  ;; True for the messages that terminate a message-generator stream:
  ;; ReadyForQuery and FatalErrorResponse.  After one of these, the next
  ;; generator is the terminal (done) one.
  (define (end-of-exchange-message? msg)
    (cond [(ReadyForQuery? msg) #t]
          [(FatalErrorResponse? msg) #t]
          [else #f]))
  
  ;; protocol2:new : input-port output-port -> protocol
  ;; Creates protocol state for a backend connection over the given ports:
  ;; disconnected lock state, no message generator yet, field count 0.
  ;; The lock key starts at 0, the same default key that protocol2:lock
  ;; installs and that protocol2:unlock assumes (and uses for a #f protocol).
  ;; With #f here, protocol2:unlock on a fresh protocol would apply `=' to #f
  ;; and raise a contract error instead of performing its check.
  (define (protocol2:new inport outport)
    (make-protocol2 inport outport 0 #f protocol:lock:disconnected 0))
  
  ;; protocol2:reset : protocol2 -> message-generator
  ;; Finishes off any exchange still pending on the current message
  ;; generator by reading and discarding messages until a done generator is
  ;; reached.  Then it installs a fresh generator for the next exchange and
  ;; returns it.
  (define (protocol2:reset protocol)
    ;; Drain the previous exchange, if any (the generator is #f before the
    ;; first reset).
    (let drain ([mg (protocol2-message-generator protocol)])
      (when (and mg (not (message-generator:done? mg)))
        (drain (message-generator:next mg))))
    (let ([new-mg (message-generator:new protocol)])
      (set-protocol2-message-generator! protocol new-mg)
      new-mg))

  ;; protocol2:close : protocol2 -> void
  ;; Closes both backend ports.  Closing an output port may flush buffered
  ;; output, so close-output-port can raise on a broken connection.
  ;; dynamic-wind makes sure the input port is closed in that case too,
  ;; instead of being leaked.
  (define (protocol2:close protocol)
    (dynamic-wind
     void
     (lambda () (close-output-port (protocol2-outport protocol)))
     (lambda () (close-input-port (protocol2-inport protocol)))))
  
  ;; protocol2:lock : protocol2 symbol -> integer
  ;; Puts the protocol into lock state `status' under the default key 0 and
  ;; returns that key.
  (define (protocol2:lock protocol status)
    (let ([default-key 0])
      (set-protocol2-status! protocol status)
      (set-protocol2-key! protocol default-key)
      default-key))
  
  ;; protocol2:lock/key : protocol2 symbol -> integer
  ;; Puts the protocol into lock state `status' under a freshly minted key
  ;; and returns the key.  A matching protocol2:unlock must pass that key.
  ;; The counter is bumped BEFORE use, so minted keys start at 1.  Key 0 is
  ;; the default used by protocol2:lock and by the two-argument
  ;; protocol2:unlock; minting 0 would let the first keyed lock be released
  ;; without its key.
  (define (protocol2:lock/key protocol status)
    (set! lock-counter (add1 lock-counter))
    (let ([key lock-counter])
      (set-protocol2-status! protocol status)
      (set-protocol2-key! protocol key)
      key))
  
  ;; protocol2:unlock : (or protocol2 #f) symbol [integer] -> void
  ;; Verifies that the protocol is locked in state `status' under key `key'
  ;; (default 0).  A #f protocol counts as disconnected with key 0.  On
  ;; mismatch, raises a user error naming the state actually held.
  (define protocol2:unlock 
    (case-lambda 
      [(protocol status)
       (protocol2:unlock protocol status 0)]
      [(protocol status key)
       (let-values ([(held-status held-key)
                     (if protocol
                         (values (protocol2-status protocol)
                                 (protocol2-key protocol))
                         (values protocol:lock:disconnected 0))])
         (when (or (not (eq? held-status status))
                   (not (= held-key key)))
           (raise-sp-user-error 'lock "backend link is locked on state ~a"
                                held-status)))]))
  
  ;; protocol2:encode : protocol2 message -> void
  ;; Serializes `message' onto the protocol's output port (encode-message
  ;; flushes the port afterwards).
  (define (protocol2:encode protocol message)
    (let ([out (protocol2-outport protocol)])
      (encode-message message out)))
  

  ;; message-generator:current/next
  ;;   : message-generator -> (values message message-generator)
  ;; Forces the generator's promise.  The first force reads one message from
  ;; the backend; delay/force caches the result, so later calls yield the
  ;; same two values without reading again.
  (define (message-generator:current/next mg)
    (force (message-generator-promise mg)))

  ;; message-generator:current : message-generator -> message
  ;; The message at this position of the stream.
  (define (message-generator:current mg)
    (call-with-values
     (lambda () (force (message-generator-promise mg)))
     (lambda (message successor) message)))

  ;; message-generator:next : message-generator -> message-generator
  ;; The generator for the position after this one.
  (define (message-generator:next mg)
    (call-with-values
     (lambda () (force (message-generator-promise mg)))
     (lambda (message successor) successor)))

  ;; message-generator:done? : message-generator -> boolean
  ;; #t only for the terminal generator (which has no promise to force).
  (define (message-generator:done? mg)
    (message-generator-done? mg))
  
  ;; message-generator:new : protocol2 -> message-generator
  ;; Creates a not-done generator whose promise, when forced, reads one
  ;; message and yields it together with the following generator.  That is
  ;; the terminal (done, promise #f) generator when the message ends the
  ;; exchange, and another fresh generator otherwise.
  (define (message-generator:new protocol)
    (let ([pending
           (delay
             (let* ([message (parse-message protocol)]
                    [successor
                     (if (end-of-exchange-message? message)
                         (make-message-generator protocol #f #t)
                         (message-generator:new protocol))])
               (values message successor)))])
      (make-message-generator protocol pending #f)))
  
  ;; parse-message : protocol -> Response
  ;; Reads one backend message.  Any exn:fail raised while reading or
  ;; decoding (I/O failure, unexpected input, unsupported message type) is
  ;; converted into a FatalErrorResponse.  end-of-exchange-message? treats
  ;; that as the end of the current exchange.
  (define (parse-message protocol)
    (with-handlers ([exn:fail?
                     (lambda (e)
                       ;; Use the exception's message text.  Formatting the
                       ;; exn struct itself with ~a prints the struct rather
                       ;; than a readable message.
                       (make-FatalErrorResponse
                        "FATAL"
                        (format "Error communicating with backend: ~a"
                                (exn-message e))
                        0))])
      (parse-response protocol)))
  
  ;; encode-message : message output-port -> void
  ;; Writes `msg' to `outport' in PostgreSQL frontend/backend protocol 2.0
  ;; wire format, then flushes the port.  Covers the client->backend messages
  ;; as well as backend->client messages, whose encodings correspond to the
  ;; parse-* decoders in this file.  A message matching no clause writes
  ;; nothing; only the flush happens.
  (define (encode-message msg outport)
    (cond
      ;; Messages sent by client
      [(CancelRequest? msg)
       ;; Fixed 16-byte packet: length, cancel request code, pid, secret key.
       (write-int32 outport 16)
       (write-int32 outport 80877102)
       (write-int32 outport (CancelRequest-process-id msg))
       (write-int32 outport (CancelRequest-secret-key msg))]
;      [(CopyDataRows? msg)
;       (for-each (lambda (row)
;                   (write-astring outport row)
;                   (write-char #\newline outport))
;                 (CopyDataRows-rows msg))
;       (write-astring outport "\\.")
;       (write-char #\newline outport)]
      [(PasswordPacket? msg)
       ;; Length = 4 (the length word) + password + 1; the extra byte is
       ;; presumably the terminator write-tstring appends (see bitbang.ss).
       (let ([ep (PasswordPacket-password msg)])
         (write-int32 outport (+ 5 (string-length ep)))
         (write-tstring outport ep))]
;      [(FunctionCall? msg)
;       (write-char #\F outport)
;       (write-tstring outport "")
;       (write-int32 outport (FunctionCall-oid msg))
;       (write-int32 outport (length (FunctionCall-arglist msg)))
;       (for-each
;        (lambda (arg)
;          (write-int32 outport (string-length arg))
;          (write-bytes arg outport))
;        (FunctionCall-arglist msg))]
      [(Query? msg)
       (write-char #\Q outport)
       (write-tstring outport (Query-sql msg))]
      [(StartupPacket? msg)
       ;; Fixed size: 4 (length) + 2 + 2 (version) + 64 + 32 + 64 + 64 + 64
       ;; = 296 bytes.
       (write-int32 outport 296)
       (write-int16 outport (car (StartupPacket-ver msg)))
       (write-int16 outport (cdr (StartupPacket-ver msg)))
       (write-limstring outport 64 (StartupPacket-db msg))
       (write-limstring outport 32 (StartupPacket-user msg))
       (write-limstring outport 64 (StartupPacket-cmdline msg))
       (write-limstring outport 64 (StartupPacket-unused msg))
       (write-limstring outport 64 (StartupPacket-tty msg))]
      [(Terminate? msg)
       (write-char #\X outport)]
      
      ;; Structures only sent by backend
      [(NotificationResponse? msg)
       (write-char #\A outport)
       (write-int32 outport (NotificationResponse-process-id msg))
       (write-tstring outport (NotificationResponse-condition msg))]
      [(AsciiRow? msg)
       (write-char #\D outport)
       (let ([fields (AsciiRow-fields msg)])
         (encode-NullFields (map sql-null? fields) outport)
         ;; AsciiRow field lengths INCLUDE the 4-byte length word
         ;; (parse-AsciiRow subtracts 4).
         (for-each (lambda (field)
                     (write-int32 outport (+ 4 (string-length field)))
                     (write-astring outport field))
                   (filter (lambda (f) (not (sql-null? f))) fields)))]
      [(BinaryRow? msg)
       (write-char #\B outport)
       (let ([fields (BinaryRow-fields msg)])
         (encode-NullFields (map sql-null? fields) outport)
         ;; BinaryRow field lengths EXCLUDE the length word: parse-BinaryRow
         ;; reads exactly that many bytes.  Fields are strings (their length
         ;; is taken with string-length), so they are written the same way
         ;; as AsciiRow fields rather than with write-bytes.
         (for-each (lambda (field)
                     (write-int32 outport (string-length field))
                     (write-astring outport field))
                   (filter (lambda (f) (not (sql-null? f))) fields)))]
      [(CompletedResponse? msg)
       (write-char #\C outport)
       (write-tstring outport (CompletedResponse-command msg))]
      [(FatalErrorResponse? msg)
       (write-char #\E outport)
       ;; Emit "TYPE LEVEL: MESSAGE" (e.g. "FATAL 1: ..."), the layout that
       ;; parse-ErrorResponse recognizes as a fatal error with a level.
       (cond [(and (MessageResponse-type msg) (FatalErrorResponse-level msg))
              (write-tstring outport
                             (format "~a ~a: ~a"
                                     (MessageResponse-type msg)
                                     (FatalErrorResponse-level msg)
                                     (MessageResponse-message msg)))]
             [(MessageResponse-type msg)
              (write-tstring outport
                             (format "~a: ~a"
                                     (MessageResponse-type msg)
                                     (MessageResponse-message msg)))]
             [else
              (write-tstring outport (MessageResponse-message msg))])]
      [(ErrorResponse? msg) 
       (write-char #\E outport)
       (cond [(MessageResponse-type msg)
              (write-tstring outport
                             (format "~a: ~a"
                                     (MessageResponse-type msg)
                                     (MessageResponse-message msg)))]
             [else
              (write-tstring outport (MessageResponse-message msg))])]
      [(CopyInResponse? msg) 
       (write-char #\G outport)]
      [(CopyOutResponse? msg) 
       (write-char #\H outport)
       ;; NOTE: the CopyDataRows clause above is commented out, so this
       ;; recursive call currently writes no row data.
       (encode-message (make-CopyDataRows (CopyOutResponse-rows msg)) outport)]
      [(EmptyQueryResponse? msg)
       (write-char #\I outport)
       (write-tstring outport (EmptyQueryResponse-unused msg))]
      [(BackendKeyData? msg)
       (write-char #\K outport)
       (write-int32 outport (BackendKeyData-process-id msg))
       (write-int32 outport (BackendKeyData-secret-key msg))]
      [(NoticeResponse? msg)
       (write-char #\N outport)
       (cond [(MessageResponse-type msg)
              (write-tstring outport
                             (format "~a: ~a"
                                     (MessageResponse-type msg)
                                     (MessageResponse-message msg)))]
             [else
              (write-tstring outport (MessageResponse-message msg))])]
      [(CursorResponse? msg)
       (write-char #\P outport)
       (write-tstring outport (CursorResponse-name msg))]
      [(AuthenticationEncryptedPassword? msg)
       (write-char #\R outport)
       (write-int32 outport 4)
       (write-bytes (AuthenticationEncryptedPassword-salt msg) outport)]
      [(AuthenticationMD5Password? msg)
       (write-char #\R outport)
       (write-int32 outport 5)
       (write-bytes (AuthenticationMD5Password-salt msg) outport)]
      [(AuthenticationSCM? msg)
       (write-char #\R outport)
       (write-int32 outport 6)
       (write-bytes (AuthenticationSCM-data msg) outport)]
      [(Authentication? msg)
       ;; Resolve the method code before writing anything, so an unknown
       ;; method fails without leaving a partial message on the port.
       (let ([code (case (Authentication-method msg)
                     [(ok) 0]
                     [(kerberosV4) 1]
                     [(kerberosV5) 2]
                     [(unencrypted-password) 3]
                     [else
                      (error 'encode-message
                             "cannot encode authentication method ~a"
                             (Authentication-method msg))])])
         (write-char #\R outport)
         (write-int32 outport code))]
      [(RowDescription? msg)
       (write-char #\T outport)
       (write-int16 outport (length (RowDescription-fields msg)))
       (for-each (lambda (fi)
                   (write-tstring outport (FieldInfo-name fi))
                   (write-int32 outport (FieldInfo-oid fi))
                   (write-int16 outport (FieldInfo-tsize fi))
                   (write-int32 outport (FieldInfo-tmod fi)))
                 (RowDescription-fields msg))]
      [(ReadyForQuery? msg)
       (write-char #\Z outport)])
    (flush-output outport))
  
  ;; encode-NullFields : (listof boolean) output-port -> void
  ;; Writes the null bitmap that precedes AsciiRow/BinaryRow field data.
  ;; `null-fields' has one entry per field, #t when the field is SQL NULL.
  ;; One bit per field, most significant bit first, 8 fields per byte, so
  ;; ceiling(n/8) bytes in total; a bit is SET when the field is NOT null,
  ;; which is the convention select-bit decodes.  The last byte is
  ;; zero-padded.
  (define (encode-NullFields null-fields outport)
    (let* ([bytes-needed (ceiling (/ (length null-fields) 8))]
           [bytelist
            ;; `acc' instead of `bytes', so the `bytes' procedure used below
            ;; is never shadowed.
            (let byte-loop ([bytes-left bytes-needed]
                            [acc '()]
                            [fields-left null-fields])
              (if (zero? bytes-left)
                  (reverse acc)
                  (let bit-loop ([bit 7] [byte 0] [fields-left fields-left])
                    (if (or (< bit 0) (null? fields-left))
                        ;; Count down bytes-LEFT so the loop stops after
                        ;; exactly bytes-needed bytes.  Recurring on
                        ;; bytes-needed never reaches 0 for more than 8 fields.
                        (byte-loop (sub1 bytes-left)
                                   (cons byte acc)
                                   fields-left)
                        (bit-loop (sub1 bit)
                                  (if (car fields-left)
                                      byte
                                      (bitwise-ior byte (arithmetic-shift 1 bit)))
                                  (cdr fields-left))))))])
      (write-bytes (apply bytes bytelist) outport)))
  
  ;; Message Parsing
  
  ;; parse-response : protocol2 -> Response
  ;; Reads one message-type character from the backend and dispatches to
  ;; the matching decoder.  Raises an error on end of input, on the
  ;; unsupported copy-in ('G'), copy-out ('H') and function-result ('V')
  ;; messages, and on any other unknown type code.
  (define (parse-response protocol)
    (let* ([inport (protocol2-inport protocol)]
           [c (read-char inport)])
      ;; `case' compares with eqv?, the reliable equality for characters
      ;; (eq? on characters is unspecified by R5RS).  An eof object matches
      ;; no datum and falls through to `else'.
      (case c
        [(#\A)
         (make-NotificationResponse (read-int32 inport)
                                    (read-tstring inport))]
        [(#\B) (parse-BinaryRow inport protocol)]
        [(#\C) (make-CompletedResponse (read-tstring inport))]
        [(#\D) (parse-AsciiRow inport protocol)]
        [(#\E) (parse-ErrorResponse inport)]
        [(#\I) (make-EmptyQueryResponse (read-tstring inport))]
        [(#\K) (make-BackendKeyData (read-int32 inport) (read-int32 inport))]
        [(#\N) (parse-NoticeResponse inport)]
        [(#\P) (make-CursorResponse (read-tstring inport))]
        [(#\R) (parse-Authentication inport)]
        [(#\T) (parse-RowDescription inport protocol)]
        [(#\Z) (make-ReadyForQuery)]
        [(#\G #\H #\V)
         (error 'protocol 
                "unsupported feature: copy in, copy out, or function call")]
        [else
         (if (eof-object? c)
             (error 'protocol "unexpected end of input from backend")
             (error (format "unknown response code ~a" c)))])))
  
  ;; parse-Authentication : input-port -> msg
  ;; Decodes the body of an 'R' (authentication request) message: an int32
  ;; method code, followed by a 2-byte salt for code 4 (crypt) or a 4-byte
  ;; salt for code 5 (MD5).  Unrecognized codes yield
  ;; (make-Authentication 'unknown).
  (define (parse-Authentication port)
    (let ([n (read-int32 port)])
      (cond [(= n 0)
             (make-Authentication 'ok)]
            [(= n 1)
             (make-Authentication 'kerberosV4)]
            [(= n 2)
             (make-Authentication 'kerberosV5)]
            [(= n 3)
             (make-Authentication 'unencrypted-password)]
            [(= n 4)
             (make-AuthenticationEncryptedPassword 
              'encrypted-password
              (read-bytes 2 port))]
            [(= n 5)
             (make-AuthenticationMD5Password 
              'md5-password
              (read-bytes 4 port))]
            [(= n 6)
             ;; AuthenticationSCMCredential carries no payload after the
             ;; code.  Reading extra bytes here would consume data belonging
             ;; to the next message, or block waiting for data that never
             ;; comes.
             (make-AuthenticationSCM 
              'scm
              #"")]
            [else
             (make-Authentication 'unknown)])))
  
;  ;; parse-CopyDataRows : input-port -> list<string>
;  (define (parse-CopyDataRows port)
;    (let [(line (read-line port))]
;      (if (string=? line (string #\\ #\.))
;          null
;          (cons line (parse-CopyDataRows port)))))
  
  ;; parse-ErrorResponse : input-port -> ErrorResponse
  ;; Reads the error text and splits a leading "TYPE: " prefix (TYPE made of
  ;; capitals, digits and spaces) from the message:
  ;;   "FATAL 1" / "FATAL 2" / "FATAL" -> FatalErrorResponse, level 1 / 2 / #f
  ;;   any other TYPE                  -> ErrorResponse with that type
  ;;   no prefix                       -> ErrorResponse with type #f
  (define (parse-ErrorResponse port)
    (let* ([rawmsg (read-tstring port)]
           ;; Anchored at the start.  Unanchored, any later ": " in the text
           ;; matched with an empty or partial type, and everything before it
           ;; was silently dropped from the message.
           [fmt (regexp-match "^([A-Z0-9 ]*): (.*)" rawmsg)])
      ;(printf "Parsing error: ~s~ngot ~s" rawmsg fmt)
      (if (not fmt)
          (make-ErrorResponse #f rawmsg)
          (let ([type (cadr fmt)]
                [text (caddr fmt)])
            (cond
              [(equal? type "FATAL 1") (make-FatalErrorResponse "FATAL" text 1)]
              [(equal? type "FATAL 2") (make-FatalErrorResponse "FATAL" text 2)]
              [(equal? type "FATAL") (make-FatalErrorResponse "FATAL" text #f)]
              [else (make-ErrorResponse type text)])))))
  
  ;; parse-NoticeResponse : input-port -> NoticeResponse
  ;; Reads the notice text and splits a leading "TYPE: " prefix (capitals,
  ;; digits, spaces) from the message.  Text without such a prefix gets
  ;; type "NOTICE".
  (define (parse-NoticeResponse port)
    (let* ([rawmsg (read-tstring port)]
           ;; Anchored so that a ": " later in the text is not taken as the
           ;; type separator, which would drop everything before it.
           [fmt (regexp-match "^([A-Z0-9 ]*): (.*)" rawmsg)])
      (if fmt
          (make-NoticeResponse (cadr fmt) (caddr fmt))
          (make-NoticeResponse "NOTICE" rawmsg))))
  
  ;; parse-RowDescription : input-port protocol -> RowDescription
  ;; Reads an int16 field count, then for each field its name, type oid
  ;; (int32), type size (int16) and type modifier (int32).  The count is
  ;; recorded in the protocol so that following AsciiRow/BinaryRow messages
  ;; know how many fields to decode.
  (define (parse-RowDescription port protocol)
    (let ([field-count (read-int16 port)])
      (set-protocol2-last-field-count! protocol field-count)
      (make-RowDescription
       (build-list field-count
                   (lambda (index)
                     ;; let* spells out the wire order of the four reads.
                     (let* ([name (read-tstring port)]
                            [oid (read-int32 port)]
                            [tsize (read-int16 port)]
                            [tmod (read-int32 port)])
                       (make-FieldInfo name oid tsize tmod)))))))
  
  ;; parse-AsciiRow : input-port protocol -> AsciiRow
  ;; Decodes a text-format data row.  The field count comes from the most
  ;; recent RowDescription; a null bitmap marks which fields are present.
  ;; Each present field is an int32 length that INCLUDES its own 4 bytes,
  ;; followed by the text.  Absent fields become sql-null.
  (define (parse-AsciiRow port protocol)
    (let* ([field-count (protocol2-last-field-count protocol)]
           [present-bits (decode-NullFields port field-count)])
      ;;(fprintf (current-error-port) "NonNullFields: ~a~n" present-bits)
      (make-AsciiRow
       (build-list
        field-count
        (lambda (i)
          (cond
            [(select-bit present-bits i)
             (let ([len-with-header (read-int32 port)])
               (read-limstring port (- len-with-header 4)))]
            [else sql-null]))))))
  
  ;; parse-BinaryRow : input-port protocol -> BinaryRow
  ;; Decodes a binary-format data row.  As for text rows, the field count
  ;; comes from the most recent RowDescription and a null bitmap marks
  ;; present fields, but here each int32 length EXCLUDES its own 4 bytes.
  ;; Absent fields become sql-null.
  (define (parse-BinaryRow port protocol)
    (let* ([field-count (protocol2-last-field-count protocol)]
           [present-bits (decode-NullFields port field-count)])
      (make-BinaryRow
       (build-list
        field-count
        (lambda (i)
          (if (not (select-bit present-bits i))
              sql-null
              (read-limstring port (read-int32 port))))))))
  
  ;; decode-NullFields : input-port number -> (vector-of [0..255])
  ;; Reads the null bitmap of a row with `numfields' fields:
  ;; ceiling(numfields / 8) raw bytes, returned as a vector of byte values.
  ;; The bitmap is read with read-bytes so that bytes >= 128 come back
  ;; unchanged, whatever character decoding a string reader would apply.
  ;; Bitmap bytes use all 8 bits, so the range is 0..255, not 0..127.
  (define (decode-NullFields port numfields)
    (let ([raw (read-bytes (ceiling (/ numfields 8)) port)])
      (if (eof-object? raw)
          (error 'decode-NullFields "unexpected end of input reading null bitmap")
          (build-vector (bytes-length raw)
                        (lambda (i) (bytes-ref raw i))))))

  ;; select-bit : (vector-of byte) integer -> boolean
  ;; Tests bit `index' of a bitmap stored most-significant-bit first:
  ;; index 0 is the #x80 bit of element 0, index 7 its #x01 bit, and
  ;; index 8 the #x80 bit of element 1.
  (define (select-bit bits index)
    (let* ([bit-in-byte (modulo index 8)]
           [byte-index (quotient (- index bit-in-byte) 8)]
           [mask (arithmetic-shift 1 (- 7 bit-in-byte))])
      (positive? (bitwise-and mask (vector-ref bits byte-index)))))
  
  ;; parse-FunctionResult/VoidResponse : input-port -> Response
  ;; Decodes the body of a 'V' (function call result) message.  parse-response
  ;; does not currently dispatch to this.  The protocol 2.0 body is either
  ;;   'G' int32-size value '0'  -> FunctionResultResponse
  ;;   '0'                       -> FunctionVoidResponse
  ;; where '0' is the ASCII digit zero, not a NUL byte.
  (define (parse-FunctionResult/VoidResponse port)
    (let ([c (read-char port)])
      (cond [(eqv? c #\G)
             (let* ([size (read-int32 port)]
                    [answer (make-FunctionResultResponse
                             (read-limstring port size))])
               ;; Read a character and compare it with a character.
               ;; read-byte returns an integer, which never matches a char.
               (unless (eqv? (read-char port) #\0)
                 (error "Expected '0' at end of FunctionResult"))
               answer)]
            [(eqv? c #\0)
             (make-FunctionVoidResponse)]
            [else
             (error 'protocol "unexpected function result code ~a" c)])))
                    
  )