(module sql-table
mzscheme
(provide (all-defined))
(require-for-syntax (lib "contract.ss"))
(require (lib "contract.ss"))
(require (planet "sqli.scm" ("oesterholt" "sqlid.plt" 1 0)))
(require-for-syntax "for-sql-syntax.ss")
;; sql-error: raised (with `raise`) by the generated getters, setters, finders
;; and adders when a query fails.  MESSAGE holds whatever
;; `sqli-error-message` reports for the connection handle.
(define-struct sql-error (message))
;; (define-sql-type name contract to-string from-string)
;;
;; Registers, at expansion time, a column type that field specs can refer to
;; by NAME.  The contract and the two converters are stored as syntax (via
;; make-sql-type) so they can later be spliced into generated definitions.
(define-syntax define-sql-type
  (syntax-rules ()
    ((_ type-name type-contract encode decode)
     (begin-for-syntax
       (hash-table-put! table-type-namespace
                        (quote type-name)
                        (make-sql-type (syntax type-contract)
                                       (syntax encode)
                                       (syntax decode)))))))
;; (define-sql-default-type contract to-string from-string)
;;
;; Like define-sql-type, but files the type under the key `default-type`,
;; which parse-field consults for fields written as a bare identifier.
(define-syntax define-sql-default-type
  (syntax-rules ()
    ((_ type-contract encode decode)
     (begin-for-syntax
       (hash-table-put! table-type-namespace
                        default-type
                        (make-sql-type (syntax type-contract)
                                       (syntax encode)
                                       (syntax decode)))))))
;; parse-field/type : syntax -> sql-type
;; Looks up the sql-type registered under TYPE-STX's name, raising a syntax
;; error that points at TYPE-STX when no such type exists.
(define-for-syntax (parse-field/type type-stx)
  (hash-table-get table-type-namespace
                  (syntax-object->datum type-stx)
                  (lambda ()
                    (raise-syntax-error #f "Unknown type" type-stx))))

;; parse-field/default-type : syntax -> sql-type
;; Looks up the type registered by define-sql-default-type; NAME-STX is only
;; used to locate the error when none has been registered.
(define-for-syntax (parse-field/default-type name-stx)
  (hash-table-get table-type-namespace
                  default-type
                  (lambda ()
                    (raise-syntax-error #f "No default type" name-stx))))

;; parse-field/column : syntax -> string
;; Converts a column spec to the column name used in generated SQL.
;; Identifiers keep their spelling; string literals are used verbatim.
(define-for-syntax (parse-field/column column-stx)
  (let ((column (syntax-object->datum column-stx)))
    (cond ((symbol? column) (symbol->string column))
          ((string? column) column)
          (else
           (raise-syntax-error
            #f "Column name must be an identifier or a string" column-stx)))))

;; parse-field/make : syntax syntax sql-type -> sql-field
;; Builds the sql-field record for field NAME-STX stored in column
;; COLUMN-STX, copying contract and converters out of TYPE.
(define-for-syntax (parse-field/make name-stx column-stx type)
  ;; sql-field-name must be a symbol (later code calls symbol->string on it),
  ;; so reject anything else here with a proper syntax error.
  (unless (identifier? name-stx)
    (raise-syntax-error #f "Field name must be an identifier" name-stx))
  (make-sql-field (syntax-object->datum name-stx)
                  (parse-field/column column-stx)
                  type
                  (sql-type-contract type)
                  (sql-type-->string type)
                  (sql-type-<-string type)))

;; parse-field : syntax -> sql-field
;; Parses one field spec of define-table.  Accepted forms:
;;   (name type)          column named after the field, registered TYPE
;;   (name column type)   explicit column (identifier or string), TYPE
;;   name                 column named after the field, default type
;; Raises a syntax error for unknown types, a missing default type, or a
;; field name that is not an identifier.
(define-for-syntax (parse-field stx)
  (syntax-case stx ()
    ((name type)
     (parse-field/make (syntax name)
                       (syntax name)
                       (parse-field/type (syntax type))))
    ((name column type)
     (parse-field/make (syntax name)
                       (syntax column)
                       (parse-field/type (syntax type))))
    (name
     (parse-field/make (syntax name)
                       (syntax name)
                       (parse-field/default-type (syntax name))))))
;; parse-index : syntax hash-table -> (listof sql-field)
;;
;; Resolves an index spec against FIELDS (field name -> sql-field).  The spec
;; is either a parenthesised list of field names or a single field name.
;; An empty list, or a name that is not a field, is a syntax error.
(define-for-syntax (parse-index stx fields)
  ;; resolve : syntax -> sql-field, blaming FIELD-STX when it is unknown.
  (define (resolve field-stx)
    (hash-table-get fields
                    (syntax-object->datum field-stx)
                    (lambda ()
                      (raise-syntax-error #f "Index is not a field"
                                          field-stx))))
  (syntax-case stx ()
    (()
     (raise-syntax-error #f "Empty index is not allowed" stx))
    ((column ...)
     (map resolve (syntax->list (syntax (column ...)))))
    (_
     (list (resolve stx)))))
;; make-where-string : (listof sql-field) integer -> string
;;
;; Builds the body of a WHERE clause that equates positional placeholders,
;; numbered from N upwards, with the columns of INDEX, joined by " AND ".
;; For fields whose columns are "a" and "b":
;;   (make-where-string (list a b) 1)  =>  "$1 = a AND $2 = b"
;; Raises a syntax error when INDEX is empty.
(define-for-syntax (make-where-string index n)
  ;; raise-syntax-error's first argument must be a symbol or #f; passing the
  ;; (empty) index there produced an unrelated contract error instead.
  (unless (pair? index)
    (raise-syntax-error 'make-where-string "Empty index is not allowed"))
  (let loop ((fields index) (k n))
    (let ((clause (string-append "$" (number->string k) " = "
                                 (sql-field-column (car fields)))))
      (if (pair? (cdr fields))
          (string-append clause " AND " (loop (cdr fields) (+ k 1)))
          clause))))
;; make-selector-string : (listof sql-field) -> string
;;
;; Comma-separated list of the columns of INDEX, for use after SELECT,
;; e.g. "id, name".  Raises a syntax error when INDEX is empty.
(define-for-syntax (make-selector-string index)
  ;; raise-syntax-error's first argument must be a symbol or #f.
  (unless (pair? index)
    (raise-syntax-error 'make-selector-string "Empty index is not allowed"))
  (let loop ((fields index))
    (if (pair? (cdr fields))
        (string-append (sql-field-column (car fields))
                       ", "
                       (loop (cdr fields)))
        (sql-field-column (car fields)))))
;; make-extractor-syntax : identifier list -> (listof syntax)
;;
;; For an INDEX of length k, returns the expressions
;;   (vector-ref ID 0) ... (vector-ref ID k-1)
;; in order, i.e. one accessor per key column of the vector bound to ID.
(define-for-syntax (make-extractor-syntax id index)
  (let build ((remaining index) (position 0) (acc '()))
    (if (not (pair? remaining))
        (reverse acc)
        (build (cdr remaining)
               (+ position 1)
               (cons (quasisyntax
                      (vector-ref (unsyntax id)
                                  (unsyntax
                                   (datum->syntax-object id position))))
                     acc)))))
;; make-index-name : (listof sql-field) -> string
;;
;; Joins the field names of INDEX with "-" (e.g. "first-last"); used to name
;; the generated get-<table>-by-<index> finders.  Raises a syntax error when
;; INDEX is empty.
(define-for-syntax (make-index-name index)
  ;; raise-syntax-error's first argument must be a symbol or #f.
  (unless (pair? index)
    (raise-syntax-error 'make-index-name "Empty index is not allowed"))
  (let loop ((fields index))
    (let ((part (symbol->string (sql-field-name (car fields)))))
      (if (pair? (cdr fields))
          (string-append part "-" (loop (cdr fields)))
          part))))
;; make-index-contract-list : (listof sql-field) -> (listof syntax)
;; The contract syntax of each field of INDEX, in the same order; used as the
;; argument contracts of the generated finders.
(define-for-syntax (make-index-contract-list index)
  (let collect ((fields index))
    (if (null? fields)
        '()
        (cons (sql-field-contract (car fields))
              (collect (cdr fields))))))
;; make-index-contract : list -> syntax
;; Builds (vector/c string/c ...) with one string/c per element of INDEX:
;; the contract for the key vector carried in a row.
(define-for-syntax (make-index-contract index)
  (let ((slot-contracts (map (lambda (ignored) (syntax string/c)) index)))
    (quasisyntax (vector/c (unsyntax-splicing slot-contracts)))))
;; make-field-functions : syntax symbol sql-field (listof sql-field) -> syntax
;;
;; Produces a (begin ...) holding two contracted definitions for FIELD of TABLE:
;;   <table>-<field>       : row -> (or/c field-contract #f)
;;       Runs "SELECT <column> FROM <table> WHERE $1 = <key1> AND ..." with
;;       the row's primary-key values.  It returns the first element of the
;;       first row from sqli-fetchrow converted with the field's <-string, or
;;       #f when no row comes back.
;;   set-<table>-<field>!  : row value -> any
;;       Runs "UPDATE <table> SET <column> = $1 WHERE $2 = <key1> AND ...",
;;       passing VALUE through the field's ->string first.
;; A row is (cons sqli-handle key-vector), where key-vector holds one string
;; per PRIMARY field (see row-contract).
;; Both definitions raise an sql-error carrying sqli-error-message whenever
;; sqli-query returns a true value.  NOTE(review): this assumes sqli-query
;; signals failure with a true result; confirm against the sqlid API.
;; STX supplies the lexical context of the generated names and SQL literals.
(define-for-syntax (make-field-functions stx table field primary)
(let* ((row-contract (quasisyntax
(cons/c sqli/c
(unsyntax (make-index-contract
primary)))))
(field-name (sql-field-name field))
(column-name (sql-field-column field))
(contract (sql-field-contract field))
(->string (sql-field-->string field))
(<-string (sql-field-<-string field))
(getter-name (string->symbol
(string-append
(symbol->string table)
"-"
(symbol->string field-name))))
(setter-name (string->symbol
(string-append
"set-"
(symbol->string table)
"-"
(symbol->string field-name)
"!")))
(getter-string (string-append
"SELECT "
column-name
" FROM "
(symbol->string table)
" WHERE "
(make-where-string primary 1)))
;; Key placeholders start at $2 because $1 is the new column value.
(setter-string (string-append
"UPDATE "
(symbol->string table)
" SET "
column-name
" = $1 WHERE "
(make-where-string primary 2))))
(quasisyntax
(begin
(define/contract (unsyntax (datum->syntax-object stx getter-name))
(-> (unsyntax row-contract) (or/c (unsyntax contract) false/c))
(lambda (obj)
(let ((hook (car obj))
(id (cdr obj)))
;; The key values are spliced in as (vector-ref id 0), (vector-ref id 1), ...
(if (sqli-query hook (unsyntax
(datum->syntax-object stx
getter-string))
(unsyntax-splicing
(make-extractor-syntax
(syntax id) primary)))
(raise (make-sql-error (sqli-error-message hook)))
(let ((lst (sqli-fetchrow hook)))
(and (pair? lst) ((unsyntax <-string) (car lst))))))))
(define/contract (unsyntax (datum->syntax-object stx setter-name))
(-> (unsyntax row-contract) (unsyntax contract) any)
(lambda (obj val)
(let ((hook (car obj))
(id (cdr obj)))
;; Arguments: converted value for $1, then the key values for $2 ...
(if (sqli-query hook (unsyntax
(datum->syntax-object stx
setter-string))
((unsyntax ->string) val)
(unsyntax-splicing
(make-extractor-syntax
(syntax id) primary)))
(raise (make-sql-error (sqli-error-message
hook)))))))))))
;; make-getter-function : syntax symbol (listof sql-field) (listof sql-field) -> syntax
;;
;; Defines get-<table>-by-<index-name>, where make-index-name joins the INDEX
;; field names with "-".  The generated procedure:
;;   - takes an sqli handle plus one argument per INDEX field, each checked
;;     against that field's contract and converted with its ->string;
;;   - runs "SELECT <primary columns> FROM <table> WHERE $1 = <idx1> AND ...";
;;   - returns every fetched row as (cons handle (list->vector row)), the
;;     same row shape (row/c) that the per-field getters/setters accept.
;; It raises an sql-error when sqli-query returns a true value.
;; NOTE(review): assumes sqli-query signals failure with a true result.
(define-for-syntax (make-getter-function stx table index primary)
(let* ((row/c (quasisyntax (cons/c sqli/c
(unsyntax
(make-index-contract primary)))))
(getter-name (string->symbol
(string-append
"get-"
(symbol->string table)
"-by-"
(make-index-name index))))
(getter-string (string-append
"SELECT "
(make-selector-string primary)
" FROM "
(symbol->string table)
" WHERE "
(make-where-string index 1)))
;; Fresh parameter names, one per index field.
(args (map (lambda (x) (datum->syntax-object stx (gensym)))
index))
(->string-args (map sql-field-->string index))
;; Each parameter wrapped in its field's ->string converter.
(use-args (map (lambda (x ->string)
(quasisyntax ((unsyntax ->string)
(unsyntax x))))
args ->string-args)))
(quasisyntax
(define/contract (unsyntax (datum->syntax-object stx getter-name))
(-> sqli/c (unsyntax-splicing (make-index-contract-list index))
(listof (unsyntax row/c)))
(lambda (hook (unsyntax-splicing args))
(if (sqli-query hook (unsyntax
(datum->syntax-object stx getter-string))
(unsyntax-splicing use-args))
(raise (make-sql-error (sqli-error-message hook)))
;; Drain the result set until sqli-fetchrow stops returning a pair.
(let loop ((lst (sqli-fetchrow hook)))
(if (pair? lst)
(cons (cons hook (list->vector lst))
(loop (sqli-fetchrow hook)))
'()))))))))
;; make-adder-string : syntax symbol (listof (symbol . string)) list -> string
;;
;; Builds the INSERT statement for an add-<table> use.  FORMAL-LIST maps field
;; names to column names; ACTUAL-LIST is the field names given at the use site:
;;   "INSERT INTO <table> SET <col1> = $1, <col2> = $2;"
;; Raises a syntax error (located at STX) for an empty ACTUAL-LIST or for a
;; name missing from FORMAL-LIST; fields are checked left to right.
(define-for-syntax (make-adder-string stx table formal-list actual-list)
  ;; column-of : symbol -> string
  (define (column-of field)
    (cond ((assq field formal-list) => cdr)
          (else (raise-syntax-error field "Field not found" stx))))
  (unless (pair? actual-list)
    (raise-syntax-error #f "Empty adds are not allowed" stx))
  (let collect ((pending (cdr actual-list))
                (position 2)
                (sql (string-append "INSERT INTO "
                                    (symbol->string table)
                                    " SET "
                                    (column-of (car actual-list))
                                    " = $1")))
    (if (pair? pending)
        (collect (cdr pending)
                 (+ position 1)
                 (string-append sql
                                ", "
                                (column-of (car pending))
                                " = $"
                                (number->string position)))
        (string-append sql ";"))))
;; make-adder-syntax : syntax symbol (listof sql-field) -> syntax
;;
;; Produces a define-syntax for a macro named add-<table>.  A use
;;   (add-<table> field-name ...)
;; is expanded by calling make-adder-string with the table name, an alist of
;; (field-name . column) for every field, and the listed names.  The result is
;;   (lambda (hook arg ...) ...)
;; with one fresh argument per listed field.  Calling it runs the INSERT with
;; the arguments as $1, $2, ... and raises an sql-error if sqli-query returns
;; a true value.
;; NOTE(review): the lambda is built with datum->syntax-object in the use
;; site's context (stx-prime), so sqli-query, raise, make-sql-error and
;; sqli-error-message are resolved where add-<table> is used, not here.
;; NOTE(review): unlike the generated getters/setters, the arguments are
;; passed to sqli-query as-is, without the field's ->string conversion.
;; NOTE(review): the generated transformer refers to make-adder-string, a
;; define-for-syntax binding of this module; confirm it is reachable at the
;; phase where add-<table> is expanded in client modules.
(define-for-syntax (make-adder-syntax stx table field-list)
(let* ((adder-name (string->symbol
(string-append "add-"
(symbol->string table))))
(col-list (map (lambda (x) (cons (sql-field-name x)
(sql-field-column x)))
field-list)))
(quasisyntax
(define-syntax ((unsyntax (datum->syntax-object stx adder-name))
stx-prime)
(syntax-case stx-prime ()
;; '... is injected via unsyntax so that this quasisyntax emits a literal
;; ellipsis for the generated syntax-case pattern and template.
((_ fields (unsyntax (datum->syntax-object stx '...)))
(let* ((actual-list (syntax-object->datum
(syntax (fields (unsyntax (datum->syntax-object stx '...))))))
(adder-string
(make-adder-string
stx-prime
(quote (unsyntax (datum->syntax-object stx table)))
'(unsyntax (datum->syntax-object stx col-list))
actual-list))
(args-list (map (lambda (x) (gensym)) actual-list)))
(datum->syntax-object
stx-prime
`(lambda (hook ,@args-list)
(if (sqli-query hook ,adder-string ,@args-list)
(raise (make-sql-error (sqli-error-message
hook)))))))))))))
;; (define-table name (field-spec ...) (primary-key index ...))
;;
;; Expands into:
;;   - for every field, a getter and a setter (make-field-functions);
;;   - for the primary key and each further index, a finder
;;     get-<name>-by-<fields> (make-getter-function);
;;   - an add-<name> macro (make-adder-syntax).
;; Field specs are parsed by parse-field and index specs by parse-index.  The
;; first index spec is the primary key; its values identify a row in the
;; getters and setters.  Duplicate field names and a non-identifier table
;; name are reported as syntax errors at the offending sub-form.
(define-syntax (define-table stx)
  (syntax-case stx ()
    ((define-table name (field ...) (primary-key index ...))
     (begin
       ;; The table name is converted with symbol->string when building SQL
       ;; and definition names, so it must be an identifier.
       (unless (identifier? (syntax name))
         (raise-syntax-error #f "Table name must be an identifier"
                             stx (syntax name)))
       (let* ((fields (syntax->list (syntax (field ...))))
              (field-table (make-hash-table))
              (field-list (map parse-field fields))
              (table (syntax-object->datum (syntax name))))
         ;; Index field names into FIELD-TABLE, rejecting duplicates; the
         ;; parallel syntax list lets the error point at the offending spec.
         (for-each (lambda (parsed spec)
                     (let ((key (sql-field-name parsed)))
                       (when (hash-table-get field-table key (lambda () #f))
                         (raise-syntax-error key "Duplicate field" stx spec))
                       (hash-table-put! field-table key parsed)))
                   field-list
                   fields)
         (let* ((indices (syntax->list (syntax (primary-key index ...))))
                (index-list (map (lambda (x)
                                   (parse-index x field-table))
                                 indices))
                (primary (car index-list)))
           (quasisyntax
            (begin
              (unsyntax-splicing
               (map (lambda (x)
                      (make-field-functions stx table x primary))
                    field-list))
              (unsyntax-splicing
               (map (lambda (x)
                      (make-getter-function stx table x primary))
                    index-list))
              (unsyntax (make-adder-syntax stx table field-list))))))))))
;; identity-function : any -> any
;; Returns its argument unchanged; serves as both converters of the string
;; column type.
(define (identity-function x) x)
;; Flat contracts for column values and connection handles.  They are
;; exported via (provide (all-defined)) and referenced by the generated
;; definitions (sqli/c in row contracts, string/c for key vectors).
(define-values (string/c integer/c sqli/c)
  (values (flat-contract string?)
          (flat-contract integer?)
          (flat-contract sqli?)))
;; Built-in column types.  `string` and `integer` can be named in field specs;
;; bare-identifier fields get the default type, which is registered with the
;; same contract and converters as `string`.
;; The registration order is kept as is: the default goes under the key
;; `default-type`, defined outside this module, so whether it could collide
;; with a named key cannot be checked here.
(define-sql-type string string/c identity-function identity-function)
(define-sql-default-type string/c identity-function identity-function)
;; Integers are sent with number->string and read back with string->number.
(define-sql-type integer integer/c number->string string->number)
)