;;; sql-table.ss
(module sql-table
  mzscheme
  
  (provide (all-defined))
  
  (require-for-syntax (lib "contract.ss"))
  (require (lib "contract.ss"))
  (require (planet "sqli.scm" ("oesterholt" "sqlid.plt" 1 0)))
  
; Type entry (provided by "for-sql-syntax.ss"): keeps the contract used for
; checking a value, plus its encoder to and decoder from SQL strings:
;  (define-struct sql-type (contract ->string <-string))
  (require-for-syntax "for-sql-syntax.ss")
  
;;; Errors: the generated accessors raise an sql-error when a query fails.
;;; SQLI only reports failures as a string (via sqli-error-message), so the
;;; single field `message' holds exactly that string.
  (define-struct sql-error (message))
  
  ;;; (define-sql-type name contract ->string <-string)
  ;;; Registers, at expansion time, a column type called `name' in
  ;;; table-type-namespace.  The remaining three forms are kept as syntax
  ;;; objects and later spliced into the accessors that define-table builds.
  (define-syntax define-sql-type
    (syntax-rules ()
      ((_ type-name contract-expr encoder decoder)
       (begin-for-syntax
         (hash-table-put! table-type-namespace
                          (syntax-object->datum (syntax type-name))
                          (make-sql-type (syntax contract-expr)
                                         (syntax encoder)
                                         (syntax decoder)))))))
    
  ;;; (define-sql-default-type contract ->string <-string)
  ;;; Registers the type used for fields declared without an explicit type.
  ;;; It is stored under the `default-type' key (gensymmed), so it cannot
  ;;; collide with a named type.
  (define-syntax define-sql-default-type
    (syntax-rules ()
      ((_ contract-expr encoder decoder)
       (begin-for-syntax
         (hash-table-put! table-type-namespace
                          default-type
                          (make-sql-type (syntax contract-expr)
                                         (syntax encoder)
                                         (syntax decoder)))))))

;;; Turn one field specification into an sql-field record.  Accepted forms:
;;;   (name type)          column named after the field, explicit type
;;;   (name column type)   explicit column name, explicit type
;;;   name                 column named after the field, default type
;;; `name' and `column' must be identifiers (they go through symbol->string);
;;; anything else is reported as a syntax error on the specification itself.
  (define-for-syntax (parse-field stx)
    ;; Look up `type-key' in table-type-namespace (raising `missing-message'
    ;; against `blame' when absent) and build the record for the field.
    (define (build name-stx column-stx type-key blame missing-message)
      (let ((type (hash-table-get table-type-namespace
                                  type-key
                                  (lambda ()
                                    (raise-syntax-error #f missing-message
                                                        blame)))))
        (make-sql-field (syntax-object->datum name-stx)
                        (symbol->string (syntax-object->datum column-stx))
                        type
                        (sql-type-contract type)
                        (sql-type-->string type)
                        (sql-type-<-string type))))
    (syntax-case stx ()
      ((name type)
       (identifier? (syntax name))
       (build (syntax name) (syntax name)
              (syntax-object->datum (syntax type))
              (syntax type) "Unknown type"))
      ((name column type)
       (and (identifier? (syntax name)) (identifier? (syntax column)))
       (build (syntax name) (syntax column)
              (syntax-object->datum (syntax type))
              (syntax type) "Unknown type"))
      (name
       (identifier? (syntax name))
       (build (syntax name) (syntax name) default-type
              (syntax name) "No default type"))
      ;; Previously such specs crashed inside symbol->string at expansion.
      (_
       (raise-syntax-error #f "Bad field specification" stx))))

;;; Parse one index specification -- a single field name, or a non-empty
;;; list of field names -- into the list of sql-field records it names,
;;; looked up by field name in the hash table `fields'.
  (define-for-syntax (parse-index stx fields)
    (let ((field-for
           (lambda (field-stx)
             (hash-table-get fields
                             (syntax-object->datum field-stx)
                             (lambda ()
                               (raise-syntax-error #f "Index is not a field"
                                                   field-stx))))))
      (syntax-case stx ()
        (()
         (raise-syntax-error #f "Empty index is not allowed" stx))
        ((component ...)
         (map field-for (syntax->list (syntax (component ...)))))
        (other
         (list (field-for stx))))))

  
;;; Build the WHERE clause matching a (possibly compound) index against
;;; numbered placeholders starting at $n, e.g. for columns a, b and n = 2:
;;;   "$2 = a AND $3 = b"
;;; `index' is a non-empty list of sql-field records (define-table obtains
;;; it from parse-index, which rejects empty indices).
  (define-for-syntax (make-where-string index n)
    (cond
      ((not (pair? index))
       ;; Should be unreachable.  raise-syntax-error's first argument is a
       ;; name (symbol or #f); passing the index itself raised a contract
       ;; error instead of this message.
       (raise-syntax-error 'define-table "Empty index is not allowed"))
      ((pair? (cdr index))
       (string-append "$" (number->string n) " = "
                      (sql-field-column (car index))
                      " AND " (make-where-string (cdr index) (+ n 1))))
      (else
       (string-append "$" (number->string n) " = "
                      (sql-field-column (car index))))))
  
;;; Build the comma-separated column list of a (possibly compound) index,
;;; e.g. "a, b"; used as a SELECT list, so no placeholders are involved.
  (define-for-syntax (make-selector-string index)
    (cond
      ((not (pair? index))
       ;; Should be unreachable: parse-index rejects empty indices.  The
       ;; first argument of raise-syntax-error must be a symbol or #f.
       (raise-syntax-error 'define-table "Empty index is not allowed"))
      ((pair? (cdr index))
       (string-append (sql-field-column (car index))
                      ", " (make-selector-string (cdr index))))
      (else (sql-field-column (car index)))))

;;; Produce one (vector-ref <id> k) expression per component of `index',
;;; k counting up from 0, for pulling the stored key values out of the
;;; vector that `id' (a syntax object) refers to.
  (define-for-syntax (make-extractor-syntax id index)
    (let collect ((remaining index) (position 0) (acc '()))
      (if (not (pair? remaining))
          (reverse acc)
          (collect (cdr remaining)
                   (+ position 1)
                   (cons (quasisyntax
                          (vector-ref (unsyntax id)
                                      (unsyntax
                                       (datum->syntax-object id position))))
                         acc)))))
  
;;; Build the hyphen-joined field names of an index, e.g. "a-b"; used to
;;; name the generated get-<table>-by-<index> procedure.
  (define-for-syntax (make-index-name index)
    (cond
      ((not (pair? index))
       ;; Should be unreachable: parse-index rejects empty indices.  The
       ;; first argument of raise-syntax-error must be a symbol or #f.
       (raise-syntax-error 'define-table "Empty index is not allowed"))
      ((pair? (cdr index))
       (string-append (symbol->string (sql-field-name (car index)))
                      "-" (make-index-name (cdr index))))
      (else (symbol->string (sql-field-name (car index))))))
  
;;; List the contract syntax of each field of an index, in index order.
  (define-for-syntax (make-index-contract-list index)
    (let gather ((fields index))
      (if (null? fields)
          '()
          (cons (sql-field-contract (car fields))
                (gather (cdr fields))))))
  
  ;;; Contract syntax for a stored key: (vector/c string/c ...), with one
  ;;; string/c slot per index component regardless of the fields' types.
  (define-for-syntax (make-index-contract index)
    (let ((slots (map (lambda (ignored) (syntax string/c)) index)))
      (quasisyntax (vector/c (unsyntax-splicing slots)))))
  
  
;;; Generate the per-field accessors for `field' of table `table' (a symbol):
;;;   <table>-<field>       (row) -> decoded column value, or #f when
;;;                         sqli-fetchrow yields no list
;;;   set-<table>-<field>!  (row val) -> writes (->string val) to the column
;;; A row is a pair (hook . key): an SQLI handle and a vector holding the
;;; values of the `primary' index columns, which select the row via
;;;   ... WHERE $k = <pk-col-1> AND $k+1 = <pk-col-2> ...
;;; Both accessors raise an sql-error carrying SQLI's message when
;;; sqli-query returns a true value.  NOTE(review): that treats a true
;;; sqli-query result as failure -- confirm against the SQLI docs.
;;; `stx' supplies the lexical context for the generated definition names.
  (define-for-syntax (make-field-functions stx table field primary)
    ;; (cons/c sqli/c (vector/c string/c ...)): one slot per primary column.
    (let* ((row-contract (quasisyntax 
                          (cons/c sqli/c 
                                  (unsyntax (make-index-contract 
                                             primary)))))
           (field-name (sql-field-name field))
           (column-name (sql-field-column field))
           (contract (sql-field-contract field))
           (->string (sql-field-->string field))
           (<-string (sql-field-<-string field))
           (getter-name (string->symbol
                         (string-append
                          (symbol->string table)
                          "-"
                          (symbol->string field-name))))
           (setter-name (string->symbol
                         (string-append
                          "set-"
                          (symbol->string table)
                          "-"
                          (symbol->string field-name)
                          "!")))
           ;; SELECT <column> FROM <table> WHERE $1 = <pk-col-1> AND ...
           (getter-string (string-append
                           "SELECT "
                           column-name
                           " FROM "
                           (symbol->string table)
                           " WHERE "
                           (make-where-string primary 1)))
           ;; UPDATE <table> SET <column> = $1 WHERE $2 = <pk-col-1> AND ...
           ;; $1 is the new value, so the key placeholders start at 2.
           (setter-string (string-append
                          "UPDATE "
                          (symbol->string table)
                          " SET "
                          column-name
                          " = $1 WHERE "
                          (make-where-string primary 2))))
      (quasisyntax
       (begin
         (define/contract (unsyntax (datum->syntax-object stx getter-name))
           (-> (unsyntax row-contract) (or/c (unsyntax contract) false/c))
           (lambda (obj)
             (let ((hook (car obj))
                   (id (cdr obj)))
               ;; The key-vector entries are the query arguments $1, $2, ...
               (if (sqli-query hook (unsyntax 
                                     (datum->syntax-object stx 
                                                           getter-string))
                               (unsyntax-splicing 
                                (make-extractor-syntax 
                                 (syntax id) primary)))
                   (raise (make-sql-error (sqli-error-message hook)))
                   ;; Decode the first column of the first fetched row.
                   (let ((lst (sqli-fetchrow hook)))
                     (and (pair? lst) ((unsyntax <-string) (car lst))))))))
         (define/contract (unsyntax (datum->syntax-object stx setter-name))
           (-> (unsyntax row-contract) (unsyntax contract) any)
           (lambda (obj val)
             (let ((hook (car obj))
                   (id (cdr obj)))
               ;; Encoded value first ($1), then the key values ($2, ...).
               (if (sqli-query hook (unsyntax
                                     (datum->syntax-object stx 
                                                           setter-string))
                               ((unsyntax ->string) val)
                               (unsyntax-splicing
                                (make-extractor-syntax
                                 (syntax id) primary)))
                   (raise (make-sql-error (sqli-error-message 
                                           hook)))))))))))
  
;;; Generate get-<table>-by-<index-name>.  It takes an SQLI handle plus one
;;; value per component of `index'.  Each value is checked against that
;;; field's contract and encoded with its ->string.  It then runs
;;;   SELECT <primary cols> FROM <table> WHERE $1 = <idx-col-1> AND ...
;;; and returns a list with one element per fetched row, each of the form
;;; (hook . vector-of-primary-key-values) -- the row shape the per-field
;;; accessors accept.  It raises sql-error when sqli-query returns a true
;;; value (NOTE(review): confirm that is SQLI's failure signal).
  (define-for-syntax (make-getter-function stx table index primary)
    (let* ((row/c (quasisyntax (cons/c sqli/c 
                                       (unsyntax 
                                        (make-index-contract primary)))))
           (getter-name (string->symbol
                         (string-append
                          "get-"
                          (symbol->string table)
                          "-by-"
                          (make-index-name index))))
           (getter-string (string-append
                           "SELECT "
                           (make-selector-string primary)
                          " FROM "
                          (symbol->string table)
                          " WHERE "
                          (make-where-string index 1)))
           ;; Fresh parameter names, one per index component.
           (args (map (lambda (x) (datum->syntax-object stx (gensym))) 
                      index))
           (->string-args (map sql-field-->string index))
           ;; (encoder arg) for each parameter, in index order.
           (use-args (map (lambda (x ->string)
                            (quasisyntax ((unsyntax ->string) 
                                          (unsyntax x))))
                          args ->string-args)))
      (quasisyntax
       (define/contract (unsyntax (datum->syntax-object stx getter-name))
         (-> sqli/c (unsyntax-splicing (make-index-contract-list index))
             (listof (unsyntax row/c)))
         (lambda (hook (unsyntax-splicing args))
           (if (sqli-query hook (unsyntax 
                                 (datum->syntax-object stx getter-string))
                           (unsyntax-splicing use-args))
               (raise (make-sql-error (sqli-error-message hook)))
               ;; Collect fetched rows until sqli-fetchrow yields a non-pair.
               (let loop ((lst (sqli-fetchrow hook)))
                 (if (pair? lst) 
                     (cons (cons hook (list->vector lst))
                           (loop (sqli-fetchrow hook)))
                     '()))))))))
  
;;; Build the INSERT statement for add-<table>.  `formal-list' is an alist
;;; of (field-name . column-name); `actual-list' lists the field names being
;;; set, in argument order.  Example:
;;;   "INSERT INTO t SET a = $1, b = $2;"
;;; Unknown field names and an empty `actual-list' are syntax errors
;;; reported against `stx'.
  (define-for-syntax (make-adder-string stx table formal-list actual-list)
    (let ((column-for
           (lambda (field)
             (cond ((assq field formal-list) => cdr)
                   (else (raise-syntax-error field "Field not found" stx))))))
      (if (not (pair? actual-list))
          (raise-syntax-error #f "Empty adds are not allowed" stx)
          (let extend ((rest (cdr actual-list))
                       (position 2)
                       (sql (string-append "INSERT INTO "
                                           (symbol->string table)
                                           " SET "
                                           (column-for (car actual-list))
                                           " = $1")))
            (if (pair? rest)
                (extend (cdr rest)
                        (+ position 1)
                        (string-append sql
                                       ", "
                                       (column-for (car rest))
                                       " = $"
                                       (number->string position)))
                (string-append sql ";"))))))
           
    
;;; Generate the add-<table> macro.  A use
;;;   (add-<table> field ...)
;;; names fields of the table by field name and expands into
;;;   (lambda (hook v ...) ...)
;;; which takes an SQLI handle plus one value per named field, in order.
;;; It runs the INSERT built by make-adder-string and raises sql-error
;;; when sqli-query returns a true value.
;;; NOTE(review): unlike the getters/setters, the values are passed to
;;; sqli-query as-is -- no ->string encoding and no contract check.
;;; NOTE(review): the expansion is built with datum->syntax-object on the
;;; use site, so sqli-query, raise, make-sql-error and sqli-error-message
;;; resolve in the caller's scope; callers presumably need them bound.
  (define-for-syntax (make-adder-syntax stx table field-list)
    (let* ((adder-name (string->symbol
                        (string-append "add-"
                                       (symbol->string table))))
           ;; Alist of (field-name . column-name) for make-adder-string.
           (col-list (map (lambda (x) (cons (sql-field-name x)
                                            (sql-field-column x)))
                          field-list)))
      (quasisyntax
       (define-syntax ((unsyntax (datum->syntax-object stx adder-name)) 
                       stx-prime)
         (syntax-case stx-prime ()
           ;; The ellipsis is injected via unsyntax, presumably so the
           ;; outer quasisyntax template does not treat it as its own.
           ((_ fields (unsyntax (datum->syntax-object stx '...)))
            (let* ((actual-list (syntax-object->datum 
                                 (syntax (fields (unsyntax (datum->syntax-object stx '...))))))
                   (adder-string 
                    (make-adder-string 
                     stx-prime 
                     (quote (unsyntax (datum->syntax-object stx table)))
                     '(unsyntax (datum->syntax-object stx col-list))
                     actual-list))
                   ;; One fresh parameter per named field.
                   (args-list (map (lambda (x) (gensym)) actual-list)))
              (datum->syntax-object
               stx-prime
               `(lambda (hook ,@args-list)
                  (if (sqli-query hook ,adder-string ,@args-list)
                      (raise (make-sql-error (sqli-error-message 
                                              hook)))))))))))))
  
;;; Pull everything together:
;;;   (define-table name (field ...) (primary-key index ...))
;;; The first index is the primary key: every row handed out is a pair of
;;; the SQLI handle and a vector holding that key's column values.
;;; Expands into a getter and setter per field, a get-<name>-by-<index>
;;; query per index (the primary key included), and an add-<name> macro.
  (define-syntax (define-table stx)
    (syntax-case stx ()
      ((_ name (field ...) (primary-key index ...))
       ;; The table name is spliced into SQL and generated identifiers via
       ;; symbol->string, so it must be an identifier.
       (identifier? (syntax name))
       (let* ((fields (syntax->list (syntax (field ...))))
              (field-table (make-hash-table))
              (field-list (map parse-field fields))
              (table (syntax-object->datum (syntax name))))
         ;; Reject repeated field names.  The offending field spec is
         ;; passed along so the error points at its source location.
         (for-each (lambda (field-stx x)
                     (when (hash-table-get field-table (sql-field-name x)
                                           (lambda () #f))
                       (raise-syntax-error (sql-field-name x)
                                           "Duplicate field"
                                           stx
                                           field-stx))
                     (hash-table-put! field-table (sql-field-name x) x))
                   fields field-list)
         (let* ((indices (syntax->list (syntax (primary-key index ...))))
                (index-list (map (lambda (x)
                                   (parse-index x field-table))
                                 indices))
                (primary (car index-list)))
           (quasisyntax
            (begin
              ;; <name>-<field> and set-<name>-<field>! for every field.
              (unsyntax-splicing
               (map (lambda (x)
                      (make-field-functions stx table x primary))
                    field-list))
              ;; get-<name>-by-<index> for every index.
              (unsyntax-splicing
               (map (lambda (x)
                      (make-getter-function stx table x primary))
                    index-list))
              ;; The add-<name> macro.
              (unsyntax (make-adder-syntax stx table field-list)))))))))
  
  ;;; Encoder/decoder for values stored verbatim as strings.
  ;;; TODO(review): is extra anti-injection escaping needed here, beyond
  ;;; whatever SQLID does with query arguments?
  (define (identity-function x) x)

  ;;; Flat contracts used by the type registrations and generated accessors.
  (define string/c (flat-contract string?))
  (define integer/c (flat-contract integer?))
  (define sqli/c (flat-contract sqli?))

  ;;; Built-in column types; fields declared without a type are strings.
  (define-sql-type string string/c identity-function identity-function)
  (define-sql-default-type string/c identity-function identity-function)
  (define-sql-type integer integer/c number->string string->number)
    
  )