#lang scheme/base
(require scheme/contract
scheme/match
srfi/19/time
srfi/26/cut
(planet untyped/unlib:3/list)
(planet untyped/unlib:3/symbol)
"../../base.ss"
"lang-internal.ss"
"struct.ss"
"type.ss")
(define sql:alias
(match-lambda*
[(list (? symbol? id) (? table? item))
(make-table-alias id item)]
[(list (? symbol? id) (? select? item))
(make-select-alias id item)]
[(list (? symbol? id) (? non-alias-expression? item))
(make-expression-alias id item)]
[(list (and alias (struct table-alias (name table))) (? column? col))
(if (eq? (column-table col) table)
(make-column-alias alias col)
(raise-exn exn:fail:contract (format "table does not contain column: ~s ~s" table col)))]
[any (raise-exn exn:fail:contract (format "bad arguments: ~s" any))]))
(define (non-alias-expression? item)
(and (expression? item)
(not (expression-alias? item))
(not (column-alias? item))))
(define (sql:select #:what [what #f]
#:distinct [distinct #f]
#:from from
#:where [where #f]
#:group [group null]
#:order [order null]
#:having [having #f]
#:limit [limit #f]
#:offset [offset #f])
(sql:select/internal what distinct from where group order having limit offset))
(define (sql:select/internal what* distinct* from* where group* order having limit offset)
(define from
(sql-lift from*))
(define-values (what expand-info)
(expand-what-argument (if what* what* (make-default-what-argument from))))
(define distinct
(cond [(expression? distinct*) (list distinct*)]
[(eq? distinct* #t) (list)]
[(eq? distinct* #f) #f]
[else (raise-select-exn #:distinct "(U expression (listof expression) #t #f)" distinct*)]))
(define group
(expand-group-argument group*))
(unless (relation? from)
(raise-select-exn #:from "relation" from))
(unless (or (expression? where) (not where))
(raise-select-exn #:where "(U expression #f)" where))
(unless (and (list? group) (andmap expression? group))
(raise-select-exn #:group "(listof expression)" group))
(unless (and (list? order) (andmap order? order))
(raise-select-exn #:order "(listof order)" order))
(unless (or (expression? having) (not having))
(raise-select-exn #:having "(U expression #f)" having))
(unless (or (integer? limit) (not limit))
(raise-select-exn #:limit "(U integer #f)" limit))
(unless (or (integer? offset) (not offset))
(raise-select-exn #:offset "(U integer #f)" offset))
(let*-values ( [(relations) (relation->relations from)]
[(local-columns imported-columns) (relation->columns from)]
[(columns) (append local-columns imported-columns)]
[(columns*) (append what columns)])
(check-what-clause what relations columns)
(check-distinct-clause distinct relations columns)
(check-where-clause where relations columns)
(check-group-clause group relations columns*)
(check-order-clause order relations columns*)
(check-having-clause having relations columns*)
(make-select what distinct from where group order having limit offset local-columns imported-columns expand-info)))
(define (raise-select-exn kw expected received)
(raise-exn exn:fail:contract
(format "~a argument to select: expected ~a, received ~s" kw expected received)))
(define (sql:inner left right op)
(create-join 'inner left right (sql-lift op)))
(define (sql:left left right op)
(create-join 'left left right (sql-lift op)))
(define (sql:right left right op)
(create-join 'right left right (sql-lift op)))
(define (sql:outer left right)
(create-join 'outer left right #f))
(define (create-join op left right on)
(define ans (make-join op (sql-lift left) (sql-lift right) on))
(check-join ans)
ans)
(define type:boolean (make-boolean-type))
(define type:integer (make-integer-type))
(define type:real (make-real-type))
(define type:string (make-string-type #f))
(define type:symbol (make-symbol-type #f))
(define type:time-utc (make-time-utc-type))
(define type:time-tai (make-time-tai-type))
(define-function (sql:count arg)
(let ([arg (sql-lift arg)])
(make-aggregate type:integer 'count (list arg))))
(define sql:count*
(case-lambda
[() (make-aggregate type:integer 'count* null)]
[(alias) (make-aggregate type:integer 'count* (list alias))]))
(define (sql:min arg)
(let ([arg (sql-lift arg)])
(make-aggregate (expression-type arg) 'min (list arg))))
(define (sql:max arg)
(let ([arg (sql-lift arg)])
(make-aggregate (expression-type arg) 'max (list arg))))
(define sql:average
(let ([type type:real])
(lambda (alias)
(let ([arg (sql-lift alias)])
(if (numeric-expression? arg)
(make-aggregate type 'average (list arg))
(raise-type-error 'average (list arg)))))))
(define-function (sql:and . args)
[(andmap boolean-expression? args) type:boolean])
(define-function (sql:or . args)
[(andmap boolean-expression? args) type:boolean])
(define-function (sql:not arg)
[(boolean-expression? arg) type:boolean])
(define-function (sql:+ . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
(define-function (sql:- . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
(define-function (sql:* . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real])
(define-function (sql:/ arg1 arg2)
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:real])
(define-function (sql:abs arg)
[(integer-expression? arg) type:integer]
[(numeric-expression? arg) type:real])
(define-function (sql:floor arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:ceiling arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:round arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:like arg1 arg2)
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean])
(define-function (sql:regexp-match arg1 arg2)
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean])
(define-function (sql:regexp-match-ci arg1 arg2)
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean])
(define-function (sql:string-append . args)
[(andmap symbol-expression? args) type:symbol]
[(andmap character-expression? args) type:string])
(define-function (sql:string-replace arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace-ci arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace* arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace*-ci arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:= arg1 arg2)
[#t type:boolean])
(define-function (sql:<> arg1 arg2)
[#t type:boolean])
(define-function (sql:< arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:> arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:>= arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:<= arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:null? arg)
[#t type:boolean])
(define-function (sql:coalesce . args)
[(andmap boolean-expression? args) type:boolean]
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap symbol-expression? args) type:symbol]
[(andmap string-expression? args) type:string]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
(define-function (sql:->string arg1 arg2)
[(string-expression? arg2) type:string])
(define-function (sql:->symbol arg1 arg2)
[(string-expression? arg2) type:symbol])
(define (sql:in arg1 arg2)
(let* ([arg1 (sql-lift arg1)]
[arg2 (if (list? arg2)
(map sql-lift arg2)
arg2)]
[type1 (expression-type arg1)]
[type2 (cond [(null? arg2) type1] [(pair? arg2) (let ([type2 (expression-type (car arg2))])
(unless (andmap (cut type-compatible? type2 <>)
(map expression-type arg2))
(raise-exn exn:fail:contract
(format "sql:in: list elements must all be of the same type: ~a" arg2)))
type2)]
[(select? arg2) (let ([columns (select-what arg2)])
(unless (= (length columns) 1)
(raise-exn exn:fail:contract
(format "sql:in: subselect must have exactly one column: ~a" arg2)))
(expression-type (car columns)))])])
(unless (type-compatible? type1 type2)
(raise-exn exn:fail:contract
(format "sql:in: type mismatch: argument types do not match: ~a ~a" type1 type2)))
(make-function type:boolean 'in (list arg1 arg2))))
(define sql:if
(case-lambda
[(test pos)
(define pos* (sql-lift pos))
(sql:if test pos* (sql:null (expression-type pos*)))]
[(test pos neg)
(let* ([test (sql-lift test)]
[pos (sql-lift pos)]
[neg (sql-lift neg)]
[type (cond [(and (boolean-expression? test) (andmap boolean-expression? (list pos neg))) type:boolean]
[(and (boolean-expression? test) (andmap integer-expression? (list pos neg))) type:integer]
[(and (boolean-expression? test) (andmap numeric-expression? (list pos neg))) type:real]
[(and (boolean-expression? test) (andmap symbol-expression? (list pos neg))) type:symbol]
[(and (boolean-expression? test) (andmap character-expression? (list pos neg))) type:string]
[(and (boolean-expression? test) (andmap time-tai-expression? (list pos neg))) type:time-tai]
[(and (boolean-expression? test) (andmap temporal-expression? (list pos neg))) type:time-utc]
[else (error (format "~a not defined for the types: ~s"
'id
(map expression-type (list test pos neg))))])])
(make-function type 'if (list test pos neg)))]))
(define-syntax sql:cond
(syntax-rules (else)
[(_ [test expr1])
(sql:if test expr1)]
[(_ [test expr1] [else expr2])
(sql:if test expr1 expr2)]
[(_ [test1 expr1] [test2 expr2] ...)
(sql:if test1 expr1 (sql:cond [test2 expr2] ...))]))
(define sql:literal make-literal)
(define (sql:null type)
(make-null-literal type))
(define (sql:order expr dir)
(make-order (sql-lift expr) dir))
(define sql:asc (cut sql:order <> 'asc))
(define sql:desc (cut sql:order <> 'desc))
(provide (rename-out [sql:alias alias]
[sql:cond cond]))
(provide/contract
[rename sql:select select (->* (#:from (or/c relation? select?))
(#:what (or/c expression? relation?
(listof (or/c expression? relation?))
false/c)
#:distinct (or/c expression? (listof expression?) boolean?)
#:where (or/c expression? false/c)
#:group (listof (or/c column? relation?))
#:order (listof order?)
#:having (or/c expression? false/c)
#:limit (or/c integer? false/c)
#:offset (or/c integer? false/c))
select?)]
[rename sql:select/internal select/internal (-> (or/c expression? relation? (listof (or/c expression? relation?)) false/c)
(or/c expression? (listof expression?) boolean?)
(or/c relation? select?)
(or/c expression? false/c)
(listof (or/c column? relation?))
(listof order?)
(or/c expression? false/c)
(or/c integer? false/c)
(or/c integer? false/c)
select?)]
[rename sql:count count (-> column-alias? aggregate?)]
[rename sql:count* count* (->* () ((or/c table-alias? select-alias?)) aggregate?)]
[rename sql:min min (-> quotable? aggregate?)]
[rename sql:max max (-> quotable? aggregate?)]
[rename sql:average average (-> quotable? aggregate?)]
[rename sql:inner inner (-> relation+select? relation+select? quotable? join?)]
[rename sql:left left (-> relation+select? relation+select? quotable? join?)]
[rename sql:right right (-> relation+select? relation+select? quotable? join?)]
[rename sql:outer outer (-> relation+select? relation+select? join?)]
[rename sql:in in (-> quotable? (or/c select? (listof quotable?)) function?)]
[rename sql:if if (->* (quotable? quotable?) (quotable?) function?)]
[rename sql:literal literal (-> quotable? literal?)]
[rename sql:null null (-> type? literal?)]
[rename sql:order order (-> quotable? (symbols 'asc 'desc) order?)]
[rename sql:asc asc (-> quotable? order?)]
[rename sql:desc desc (-> quotable? order?)])