#lang racket/base
(require racket/class
racket/match
openssl/sha1
"../generic/interfaces.rkt"
"../generic/prepared.rkt"
"../generic/sql-data.rkt"
"../generic/exceptions.rkt"
"message.rkt"
"dbsystem.rkt")
(provide connection%)
;; Debug switches: when #t, received packets / sent packets are traced
;; to the current error port.
(define DEBUG-RESPONSES #f)
(define DEBUG-SENT-MESSAGES #f)
;; Maximum packet size passed to the server in the client authentication
;; packet: 16 MiB.
(define MAX-PACKET-LENGTH (* 16 1024 1024))
;; connection% : client side of a MySQL connection, implementing
;; connection<%>.  Ports are supplied by attach-to-ports; protocol
;; exchanges started through the public query/prepare/free/disconnect
;; methods run while holding the semaphore wlock.
(define connection%
  (class* object% (connection<%>)
    ;; Ports to the server; #f before attach-to-ports and after disconnect.
    (define inport #f)
    (define outport #f)

    (super-new)

    ;; Evaluates body; if it raises an exn:fail, closes the connection via
    ;; (disconnect* #f) -- which does not acquire wlock -- and re-raises.
    (define-syntax with-disconnect-on-error
      (syntax-rules ()
        [(with-disconnect-on-error . body)
         (with-handlers ([exn:fail? (lambda (e) (disconnect* #f) (raise e))])
           . body)]))

    ;; ---- Locking ----

    (define wlock (make-semaphore 1))

    ;; Acquires wlock.  When require-connected? is true and there is no
    ;; output port, releases wlock again and raises a "not connected" error.
    (define/private (lock who require-connected?)
      (semaphore-wait wlock)
      (when (and require-connected? (not outport))
        (semaphore-post wlock)
        (error who "not connected")))

    ;; Releases wlock.
    (define/private (unlock)
      (semaphore-post wlock))

    ;; Runs (proc) with wlock held and returns its result(s); the lock is
    ;; released both on normal return and when proc raises.
    (define/private (call-with-lock who proc
                                    #:require-connected? [require-connected? #t])
      (lock who require-connected?)
      (with-handlers ([values (lambda (e) (unlock) (raise e))])
        (begin0 (proc) (unlock))))

    ;; ---- Packet I/O ----

    ;; Sequence number stamped on the next outgoing packet.  Reset to 0 at
    ;; the start of each exchange and resynchronized from each packet read.
    (define next-msg-num 0)

    (define/private (fresh-exchange)
      (set! next-msg-num 0))

    ;; Writes msg and flushes it immediately.
    (define/private (send-message msg)
      (buffer-message msg)
      (flush-message-buffer))

    ;; Writes msg with the current sequence number, then advances the
    ;; number.  A write failure closes the connection.
    (define/private (buffer-message msg)
      (when DEBUG-SENT-MESSAGES
        (eprintf " >> ~s\n" msg))
      (with-disconnect-on-error
        (write-packet outport msg next-msg-num)
        (set! next-msg-num (add1 next-msg-num))))

    ;; Flushes the output port; a failure closes the connection.
    (define/private (flush-message-buffer)
      (with-disconnect-on-error
        (flush-output outport)))

    ;; Reads the next packet.  Read or protocol failures close the
    ;; connection.  A server error packet is raised via raise-backend-error;
    ;; recv itself does not close the connection in that case.
    ;;   expectation : symbol naming the expected packet kind, or #f for any
    ;;   field-dvecs : column descriptions, passed through to parse-packet
    (define/private (recv fsym expectation [field-dvecs #f])
      (define r
        (with-disconnect-on-error
          (recv* fsym expectation field-dvecs)))
      (when DEBUG-RESPONSES
        (eprintf " << ~s\n" r))
      (when (error-packet? r)
        (raise-backend-error fsym r))
      r)

    ;; Parses one packet from inport and checks that its kind is allowed
    ;; for expectation; raises "internal error: unexpected packet" otherwise.
    (define/private (recv* fsym expectation field-dvecs)
      ;; Accepts when expectation is #f, when ss is empty (packet kind is
      ;; allowed in any state), or when expectation is one of ss.
      (define (advance . ss)
        (unless (or (not expectation)
                    (null? ss)
                    (memq expectation ss))
          (error fsym "internal error: unexpected packet")))
      (define (err packet)
        (error fsym "internal error: unexpected packet"))
      (let-values ([(msg-num next) (parse-packet inport expectation field-dvecs)])
        ;; Keep our sequence number in step with the server's.
        (set! next-msg-num (add1 msg-num))
        (match next
          [(? handshake-packet?)
           (advance 'handshake)]
          [(? ok-packet?)
           (advance)]
          [(? error-packet?)
           (advance)]
          [(struct result-set-header-packet (field-count _))
           (advance 'result)]
          [(? field-packet?)
           (advance 'field)]
          [(? row-data-packet?)
           (advance 'data)]
          [(? binary-row-data-packet?)
           (advance 'binary-data)]
          [(? ok-prepared-statement-packet?)
           (advance 'prep-ok)]
          [(? parameter-packet?)
           (advance 'prep-params)]
          [(? eof-packet?)
           (advance 'field 'data 'binary-data 'prep-params)]
          [_
           (err next)])
        next))

    ;; ---- Connection lifecycle ----

    (define/public (disconnect)
      (disconnect* #t))

    ;; Closes both ports and clears them.  With lock-not-held? = #t, acquires
    ;; wlock and first sends a quit command if still connected; with #f,
    ;; does not touch wlock and closes the ports without sending anything.
    (define/private (disconnect* lock-not-held?)
      (define politely? lock-not-held?)
      (define (go)
        (when DEBUG-SENT-MESSAGES
          (eprintf " ** Disconnecting\n"))
        (when outport
          (when politely?
            (fresh-exchange)
            (send-message (make-command-packet 'quit "")))
          (close-output-port outport)
          (set! outport #f))
        (when inport
          (close-input-port inport)
          (set! inport #f)))
      (cond [politely?
             (call-with-lock 'disconnect go #:require-connected? #f)]
            [else (go)]))

    ;; True when an output port is attached and has not been closed.
    (define/public (connected?)
      (let ([outport outport])
        (and outport (not (port-closed? outport)))))

    (define/public (get-dbsystem)
      dbsystem)

    (define/public (attach-to-ports in out)
      (set! inport in)
      (set! outport out))

    ;; Performs the handshake: reads the server greeting, checks required
    ;; capabilities, sends the credentials, and waits for the server's OK.
    ;; Any exn:fail raised during the handshake closes the connection.
    (define/public (start-connection-protocol dbname username password)
      (with-disconnect-on-error
        (fresh-exchange)
        (let ([r (recv 'mysql-connect 'handshake)])
          (match r
            [(struct handshake-packet (pver sver tid scramble capabilities charset status))
             (check-required-flags capabilities)
             (send-message
              (make-client-authentication-packet
               (desired-capabilities capabilities)
               MAX-PACKET-LENGTH
               'utf8-general-ci username
               (scramble-password scramble password)
               dbname))
             (expect-auth-confirmation)]
            [_
             (error 'mysql-connect
                    "internal error: unknown message during authentication")]))))

    ;; Raises an error naming the first REQUIRED-CAPABILITIES entry missing
    ;; from the server's advertised capabilities.
    (define/private (check-required-flags capabilities)
      (for-each (lambda (rf)
                  (unless (memq rf capabilities)
                    (error 'mysql-connect
                           "server does not support required capability: ~s"
                           rf)))
                REQUIRED-CAPABILITIES))

    ;; Server capabilities that are also in DESIRED-CAPABILITIES, plus
    ;; 'interactive.
    (define/private (desired-capabilities capabilities)
      (cons 'interactive
            (filter (lambda (c) (memq c DESIRED-CAPABILITIES))
                    capabilities)))

    ;; Waits for the OK packet that confirms authentication, then runs
    ;; after-connect.
    (define/private (expect-auth-confirmation)
      (let ([r (recv 'mysql-connect 'auth)])
        (match r
          [(struct ok-packet (_ _ status warnings message))
           (after-connect)]
          [_
           (error 'mysql-connect
                  "internal error: unknown message after authentication")])))

    ;; Switches the session character set to utf8.
    (define/private (after-connect)
      (query 'mysql-connect "set names 'utf8'"
             (lambda (fields ordered?) (values #f void void #f)))
      (void))

    ;; ---- Queries ----

    ;; Sends stmt (an SQL string or a statement binding) and reads the
    ;; complete response while holding wlock.  The response is then turned
    ;; into a recordset or simple-result via collector, outside the lock.
    (define/public (query fsym stmt collector)
      (let-values ([(stmt result)
                    (call-with-lock fsym
                      (lambda ()
                        (let ([stmt (check-statement fsym stmt)])
                          (values stmt (query1 fsym stmt)))))])
        (query1:process-result fsym collector result)))

    ;; One request/response exchange.  Prepared executions (non-strings)
    ;; yield binary rows.
    (define/private (query1 fsym stmt)
      (fresh-exchange)
      (query1:enqueue stmt)
      (query1:collect fsym (not (string? stmt))))

    ;; Checks that a statement binding belongs to this connection.  SQL
    ;; strings that force-prepare-sql? selects are prepared (close-on-exec)
    ;; and bound with no parameters.  Other strings pass through unchanged.
    (define/private (check-statement fsym stmt)
      (cond [(statement-binding? stmt)
             (let ([pst (statement-binding-pst stmt)])
               (send pst check-owner fsym this stmt)
               stmt)]
            [(and (string? stmt) (force-prepare-sql? fsym stmt))
             (let ([pst (prepare1 fsym stmt #t)])
               (send pst bind fsym null))]
            [else stmt]))

    ;; Sends either an execute packet (statement binding) or a plain query.
    (define/private (query1:enqueue stmt)
      (cond [(statement-binding? stmt)
             (let* ([pst (statement-binding-pst stmt)]
                    [id (send pst get-handle)]
                    [params (statement-binding-params stmt)]
                    [null-map (map sql-null? params)])
               (send-message
                (make-execute-packet id null null-map params)))]
            [else (send-message (make-command-packet 'query stmt))]))

    ;; Reads the response header.  An OK packet gives
    ;; #(command info-alist); a result-set header gives
    ;; #(recordset field-dvecs rows).
    (define/private (query1:collect fsym binary?)
      (let ([r (recv fsym 'result)])
        (match r
          [(struct ok-packet (affected-rows insert-id status warnings message))
           (vector 'command `((affected-rows . ,affected-rows)
                              (insert-id . ,insert-id)
                              (status . ,status)
                              (message . ,message)))]
          [(struct result-set-header-packet (fields extra))
           (query1:expect-fields fsym null binary?)])))

    ;; Reads field packets (accumulated in reverse) up to EOF, then the rows.
    (define/private (query1:expect-fields fsym r-field-dvecs binary?)
      (let ([r (recv fsym 'field)])
        (match r
          [(? field-packet?)
           (query1:expect-fields fsym (cons (parse-field-dvec r) r-field-dvecs) binary?)]
          [(struct eof-packet (warning-count status))
           (let ([field-dvecs (reverse r-field-dvecs)])
             (vector 'recordset field-dvecs (query1:get-rows fsym field-dvecs binary?)))])))

    ;; Reads row packets up to EOF and returns their data in arrival order.
    ;; Accumulates in reverse so memory use does not grow with recursion
    ;; depth on large result sets.
    (define/private (query1:get-rows fsym field-dvecs binary?)
      (define expectation (if binary? 'binary-data 'data))
      (let loop ([r-rows null])
        (match (recv fsym expectation field-dvecs)
          [(struct row-data-packet (data))
           (loop (cons data r-rows))]
          [(struct binary-row-data-packet (data))
           (loop (cons data r-rows))]
          [(struct eof-packet (warning-count status))
           (reverse r-rows)])))

    ;; Converts the raw result of query1:collect into a recordset (rows
    ;; folded with the collector's init/combine/finalize) or simple-result.
    (define/private (query1:process-result fsym collector result)
      (match result
        [(vector 'recordset field-dvecs rows)
         (let-values ([(init combine finalize headers?)
                       (collector (length field-dvecs) #t)])
           (recordset (and headers? (map field-dvec->field-info field-dvecs))
                      (finalize
                       (for/fold ([acc init]) ([row (in-list rows)])
                         (combine acc row)))))]
        [(vector 'command command-info)
         (simple-result command-info)]))

    ;; ---- Prepared statements ----

    (define/public (prepare fsym stmt close-on-exec?)
      (call-with-lock fsym
        (lambda ()
          (prepare1 fsym stmt close-on-exec?))))

    ;; Prepares stmt on the server (caller holds wlock).  Reads the
    ;; parameter descriptions, then the result-column descriptions, each
    ;; only when the server reports a nonzero count.
    (define/private (prepare1 fsym stmt close-on-exec?)
      (fresh-exchange)
      (send-message (make-command-packet 'statement-prepare stmt))
      (let ([r (recv fsym 'prep-ok)])
        (match r
          [(struct ok-prepared-statement-packet (id fields params))
           (let ([param-dvecs
                  (if (zero? params) null (prepare1:get-field-descriptions fsym))]
                 [field-dvecs
                  (if (zero? fields) null (prepare1:get-field-descriptions fsym))])
             (new prepared-statement%
                  (handle id)
                  (close-on-exec? close-on-exec?)
                  (param-typeids (map field-dvec->typeid param-dvecs))
                  (result-dvecs field-dvecs)
                  (owner this)))])))

    ;; Reads field packets up to EOF and returns their parsed descriptions
    ;; in arrival order.
    (define/private (prepare1:get-field-descriptions fsym)
      (let loop ([r-dvecs null])
        (match (recv fsym 'field)
          [(struct eof-packet (warning-count status))
           (reverse r-dvecs)]
          [(? field-packet? r)
           (loop (cons (parse-field-dvec r) r-dvecs))])))

    ;; If pst still has a server handle and the connection is open, clears
    ;; the handle and sends a statement-close for it.  Otherwise does nothing.
    (define/public (free-statement pst)
      (call-with-lock 'free-statement
                      #:require-connected? #f
                      (lambda ()
                        (let ([id (send pst get-handle)])
                          (when (and id outport)
                            (send pst set-handle #f)
                            (fresh-exchange)
                            (send-message
                             (make-command:statement-packet 'statement-close id)))))))))
;; Computes the authentication response sent in the client authentication
;; packet:
;;   SHA1(password) XOR SHA1(scramble ++ SHA1(SHA1(password)))
;; Returns #f when there is no scramble, no password, or an empty password.
;; In every #f case the caller handles the result like a missing password.
;; This matters for the empty password: a password-less account expects no
;; auth data rather than a hash of "".
;; The password is encoded as UTF-8, consistent with the utf8 charset the
;; client declares in the authentication packet.  Latin-1 encoding raised
;; on any character above U+00FF.
(define (scramble-password scramble password)
  (and scramble
       password
       (positive? (string-length password))
       (let* ([password (string->bytes/utf-8 password)]
              [stage1 (sha1-bytes (open-input-bytes password))]
              [stage2 (sha1-bytes (open-input-bytes stage1))]
              [stage3 (sha1-bytes (open-input-bytes (bytes-append scramble stage2)))])
         (bytes-xor stage1 stage3))))
;; Returns a fresh byte string, as long as a, whose i-th byte is the XOR
;; of the i-th bytes of a and b.  b must be at least as long as a.
(define (bytes-xor a b)
  (define out (make-bytes (bytes-length a)))
  (for ([i (in-range (bytes-length out))])
    (bytes-set! out i (bitwise-xor (bytes-ref a i) (bytes-ref b i))))
  out)
;; Capabilities the server must advertise; check-required-flags raises an
;; error during connection if any of these is missing.
(define REQUIRED-CAPABILITIES
  (list 'long-flag
        'connect-with-db
        'protocol-41
        'secure-connection))

;; Capabilities requested from the server when it advertises them
;; (intersected with the server's set in desired-capabilities).
(define DESIRED-CAPABILITIES
  (list 'long-password
        'long-flag
        'transactions
        'protocol-41
        'secure-connection
        'connect-with-db))
;; Raises an SQL error built from server error packet r.  The SQLSTATE is
;; used as the error code, and errno/code/message are attached as an
;; association list.
(define (raise-backend-error who r)
  (let* ([sqlstate (error-packet-sqlstate r)]
         [msg (error-packet-message r)]
         [info `((errno . ,(error-packet-errno r))
                 (code . ,sqlstate)
                 (message . ,msg))])
    (raise-sql-error who sqlstate msg info)))
;; Decides whether an SQL string must go through the prepared-statement
;; path: true for SELECT and CALL (case-insensitive), and also when no
;; leading keyword can be found.  fsym is unused.
(define (force-prepare-sql? fsym stmt)
  (define kw (get-sql-keyword stmt))
  (or (not kw)
      (and (member (string-downcase kw) '("select" "call")) #t)))
;; Matches, at the start of a statement, any run of whitespace, "#" line
;; comments and "-- " line comments, then captures the first word of
;; ASCII letters.
(define sql-statement-rx
  #rx"^(?:(?:#[^\n\r]*[\n\r])|(?:-- [^\n\r]*[\n\r])|[ \t\n\r])*([A-Za-z]+)")

;; Returns the leading keyword of stmt, with its original case, or #f if
;; the statement does not begin (after whitespace and line comments) with
;; a letter.
(define (get-sql-keyword stmt)
  (cond [(regexp-match sql-statement-rx stmt) => cadr]
        [else #f]))