#lang scheme/base
(require scheme/contract
scheme/match
srfi/19/time
srfi/26/cut
(planet untyped/unlib:3/list)
(planet untyped/unlib:3/symbol)
"../../base.ss"
"lang-internal.ss"
"struct.ss"
"type.ss")
;; Shared type instances used as the result types of the SQL functions
;; defined below.  Each one is constructed once, when the module is
;; instantiated, in the same order as listed here.

;; Boolean and numeric types.
(define-values (type:boolean type:integer type:real)
  (values (make-boolean-type)
          (make-integer-type)
          (make-real-type)))

;; Character types.
;; NOTE(review): the #f argument is presumably "no maximum length" --
;; confirm against the constructors' definitions.
(define-values (type:string type:symbol)
  (values (make-string-type #f)
          (make-symbol-type #f)))

;; Temporal types.
(define-values (type:time-utc type:time-tai)
  (values (make-time-utc-type)
          (make-time-tai-type)))
;; Boolean connectives.
;;
;; Each `define-function` clause pairs a test over the arguments with the
;; type given to the resulting function expression.
;; NOTE(review): `define-function` is defined outside this part of the file;
;; how it lifts arguments and what it does when no clause matches should be
;; confirmed at its definition.

;; sql:and : typed as boolean when every argument satisfies
;; boolean-expression?.
(define-function (sql:and . args)
[(andmap boolean-expression? args) type:boolean])
;; sql:or : typed as boolean when every argument satisfies
;; boolean-expression?.
(define-function (sql:or . args)
[(andmap boolean-expression? args) type:boolean])
;; sql:not : typed as boolean when its single argument satisfies
;; boolean-expression?.
(define-function (sql:not arg)
[(boolean-expression? arg) type:boolean])
;; Variadic arithmetic operators.
;;
;; Clauses are listed from the narrower predicate to the wider one (integer
;; before numeric, time-tai before temporal).
;; NOTE(review): this ordering only matters if `define-function` tries the
;; clauses top to bottom and takes the first match -- presumably it does;
;; confirm at the macro's definition.

;; sql:+ : integer / real / time-tai / time-utc, depending on which test all
;; of the arguments satisfy.
(define-function (sql:+ . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
;; sql:- : same clause table as sql:+.
(define-function (sql:- . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
;; sql:* : numeric arguments only; there are no temporal clauses.
(define-function (sql:* . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real])
;; Division and unary numeric functions.

;; sql:/ : takes exactly two arguments; the result is typed as real even when
;; both arguments are integer expressions (there is no integer clause).
(define-function (sql:/ arg1 arg2)
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:real])
;; sql:abs : keeps the integer type for an integer argument, otherwise real
;; for any other numeric argument.
(define-function (sql:abs arg)
[(integer-expression? arg) type:integer]
[(numeric-expression? arg) type:real])
;; sql:floor, sql:ceiling, sql:round : any numeric argument, typed as integer.
(define-function (sql:floor arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:ceiling arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:round arg)
[(numeric-expression? arg) type:integer])
;; Pattern-matching predicates.  Each takes exactly two arguments, both of
;; which must satisfy character-expression?, and is typed as boolean.

;; sql:like : SQL LIKE-style match.
(define-function (sql:like arg1 arg2)
  [(andmap character-expression? (list arg1 arg2)) type:boolean])

;; sql:regexp-match : regular-expression match.
(define-function (sql:regexp-match arg1 arg2)
  [(andmap character-expression? (list arg1 arg2)) type:boolean])

;; sql:regexp-match-ci : as sql:regexp-match; the "-ci" suffix presumably
;; means case-insensitive (not established by this file).
(define-function (sql:regexp-match-ci arg1 arg2)
  [(andmap character-expression? (list arg1 arg2)) type:boolean])
;; sql:string-append : variadic concatenation.
;; Typed as symbol when every argument satisfies symbol-expression?;
;; otherwise typed as string when every argument satisfies
;; character-expression?.
;; NOTE(review): the symbol clause comes first, presumably because
;; character-expression? also accepts symbol expressions and clauses are
;; tried in order -- confirm against the predicate and macro definitions.
(define-function (sql:string-append . args)
[(andmap symbol-expression? args) type:symbol]
[(andmap character-expression? args) type:string])
;; Three-argument replacement functions.  All five share one clause table:
;; typed as symbol when all three arguments satisfy symbol-expression?,
;; otherwise typed as string when all three satisfy character-expression?.

;; sql:string-replace
(define-function (sql:string-replace arg1 arg2 arg3)
  [(and (symbol-expression? arg1)
        (symbol-expression? arg2)
        (symbol-expression? arg3))
   type:symbol]
  [(and (character-expression? arg1)
        (character-expression? arg2)
        (character-expression? arg3))
   type:string])

;; sql:regexp-replace
(define-function (sql:regexp-replace arg1 arg2 arg3)
  [(and (symbol-expression? arg1)
        (symbol-expression? arg2)
        (symbol-expression? arg3))
   type:symbol]
  [(and (character-expression? arg1)
        (character-expression? arg2)
        (character-expression? arg3))
   type:string])

;; sql:regexp-replace-ci
(define-function (sql:regexp-replace-ci arg1 arg2 arg3)
  [(and (symbol-expression? arg1)
        (symbol-expression? arg2)
        (symbol-expression? arg3))
   type:symbol]
  [(and (character-expression? arg1)
        (character-expression? arg2)
        (character-expression? arg3))
   type:string])

;; sql:regexp-replace*
(define-function (sql:regexp-replace* arg1 arg2 arg3)
  [(and (symbol-expression? arg1)
        (symbol-expression? arg2)
        (symbol-expression? arg3))
   type:symbol]
  [(and (character-expression? arg1)
        (character-expression? arg2)
        (character-expression? arg3))
   type:string])

;; sql:regexp-replace*-ci
(define-function (sql:regexp-replace*-ci arg1 arg2 arg3)
  [(and (symbol-expression? arg1)
        (symbol-expression? arg2)
        (symbol-expression? arg3))
   type:symbol]
  [(and (character-expression? arg1)
        (character-expression? arg2)
        (character-expression? arg3))
   type:string])
;; Equality tests.  The clause test is the constant #t, so no type check is
;; made on either argument here; the result is always typed as boolean.

;; sql:= : equality of two expressions.
(define-function (sql:= arg1 arg2)
[#t type:boolean])
;; sql:<> : inequality of two expressions.
(define-function (sql:<> arg1 arg2)
[#t type:boolean])
;; Ordering comparisons.  All four share one clause table: both arguments
;; must fall in the same family (boolean, numeric, character or temporal),
;; and the result is always typed as boolean.

;; sql:< : less than.
(define-function (sql:< arg1 arg2)
  [(andmap boolean-expression?   (list arg1 arg2)) type:boolean]
  [(andmap numeric-expression?   (list arg1 arg2)) type:boolean]
  [(andmap character-expression? (list arg1 arg2)) type:boolean]
  [(andmap temporal-expression?  (list arg1 arg2)) type:boolean])

;; sql:> : greater than.
(define-function (sql:> arg1 arg2)
  [(andmap boolean-expression?   (list arg1 arg2)) type:boolean]
  [(andmap numeric-expression?   (list arg1 arg2)) type:boolean]
  [(andmap character-expression? (list arg1 arg2)) type:boolean]
  [(andmap temporal-expression?  (list arg1 arg2)) type:boolean])

;; sql:>= : greater than or equal.
(define-function (sql:>= arg1 arg2)
  [(andmap boolean-expression?   (list arg1 arg2)) type:boolean]
  [(andmap numeric-expression?   (list arg1 arg2)) type:boolean]
  [(andmap character-expression? (list arg1 arg2)) type:boolean]
  [(andmap temporal-expression?  (list arg1 arg2)) type:boolean])

;; sql:<= : less than or equal.
(define-function (sql:<= arg1 arg2)
  [(andmap boolean-expression?   (list arg1 arg2)) type:boolean]
  [(andmap numeric-expression?   (list arg1 arg2)) type:boolean]
  [(andmap character-expression? (list arg1 arg2)) type:boolean]
  [(andmap temporal-expression?  (list arg1 arg2)) type:boolean])
;; sql:null? : null test on a single expression.
;; The clause test is the constant #t, so an argument of any type is
;; accepted; the result is typed as boolean.
(define-function (sql:null? arg)
[#t type:boolean])
;; sql:coalesce : variadic; the result type is chosen by the first test that
;; every argument satisfies (boolean, integer, real, symbol, string,
;; time-tai, time-utc).
;; NOTE(review): the string clause tests string-expression?, whereas sql:if
;; and sql:string-append use character-expression? for their string result.
;; If character-expression? also accepts symbol expressions, a mix of symbol
;; and string arguments is rejected here but accepted by those functions.
;; Confirm whether that difference is intentional.
(define-function (sql:coalesce . args)
[(andmap boolean-expression? args) type:boolean]
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap symbol-expression? args) type:symbol]
[(andmap string-expression? args) type:string]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
;; Conversion functions.  Only the second argument is type-checked: it must
;; satisfy string-expression?.  The first argument is not constrained by the
;; clause test.
;; NOTE(review): arg2 is presumably a format/pattern string for the
;; conversion -- confirm against the SQL generation code.

;; sql:->string : result typed as string.
(define-function (sql:->string arg1 arg2)
[(string-expression? arg2) type:string])
;; sql:->symbol : result typed as symbol.
(define-function (sql:->symbol arg1 arg2)
[(string-expression? arg2) type:symbol])
;; sql:in : quotable (U select (listof quotable)) -> function
;;
;; Builds the boolean `in` function expression testing whether arg1 is a
;; member of arg2, where arg2 is either a list of values or a subselect.
;;
;; arg1 is lifted with sql-lift; when arg2 is a list, so is each element.
;;
;; Raises exn:fail:contract when:
;;   - the elements of a list arg2 are not all type-compatible with the
;;     first element;
;;   - a select arg2 does not have exactly one column;
;;   - arg2 is neither a list nor a select;
;;   - the type of arg1 is not compatible with the type derived from arg2.
(define (sql:in arg1 arg2)
  (let* ([arg1  (sql-lift arg1)]
         [arg2  (if (list? arg2)
                    (map sql-lift arg2)
                    arg2)]
         [type1 (expression-type arg1)]
         [type2 (cond
                  ;; An empty list imposes no type of its own, so it takes
                  ;; arg1's type and the compatibility check below passes.
                  [(null? arg2) type1]
                  ;; Non-empty list: the first element fixes the type and
                  ;; every element must be compatible with it.
                  [(pair? arg2)
                   (let ([item-type (expression-type (car arg2))])
                     (unless (andmap (cut type-compatible? item-type <>)
                                     (map expression-type arg2))
                       (raise-exn exn:fail:contract
                         (format "sql:in: list elements must all be of the same type: ~a" arg2)))
                     item-type)]
                  ;; Subselect: must produce exactly one column, whose type
                  ;; is the type compared against arg1.
                  [(select? arg2)
                   (let ([columns (select-what arg2)])
                     (unless (= (length columns) 1)
                       (raise-exn exn:fail:contract
                         (format "sql:in: subselect must have exactly one column: ~a" arg2)))
                     (expression-type (car columns)))]
                  ;; Without this branch the cond would return void, which
                  ;; would then be reported as a confusing type mismatch.
                  [else
                   (raise-exn exn:fail:contract
                     (format "sql:in: second argument must be a list or a select: ~a" arg2))])])
    (unless (type-compatible? type1 type2)
      (raise-exn exn:fail:contract
        (format "sql:in: type mismatch: argument types do not match: ~a ~a" type1 type2)))
    (make-function type:boolean 'in (list arg1 arg2))))
;; sql:if : quotable quotable [quotable] -> function
;;
;; Builds an `if` function expression.
;;
;;   (sql:if test pos)      -- the negative branch defaults to a null literal
;;                             of the same type as pos.
;;   (sql:if test pos neg)  -- all three arguments are lifted with sql-lift.
;;
;; test must be a boolean expression, and pos and neg must both fall in the
;; same family.  The result type is the first of boolean, integer, real,
;; symbol, string, time-tai, time-utc whose predicate both branches satisfy.
;; If test is not boolean or no family matches, an error is raised that
;; names sql:if and lists the three argument types.
(define sql:if
  (case-lambda
    [(test pos)
     (let ([pos (sql-lift pos)])
       (sql:if test pos (sql:null (expression-type pos))))]
    [(test pos neg)
     (let* ([test  (sql-lift test)]
            [pos   (sql-lift pos)]
            [neg   (sql-lift neg)]
            ;; Does the predicate hold for both branches?
            [both? (lambda (pred) (and (pred pos) (pred neg)))]
            ;; The message used to print the literal symbol `id`; it now
            ;; names this function, like the "sql:in: ..." messages.
            [fail  (lambda ()
                     (error (format "~a not defined for the types: ~s"
                                    'sql:if
                                    (map expression-type (list test pos neg)))))]
            [type  (if (boolean-expression? test)
                       ;; Narrower predicates are tried before wider ones.
                       (cond [(both? boolean-expression?)   type:boolean]
                             [(both? integer-expression?)   type:integer]
                             [(both? numeric-expression?)   type:real]
                             [(both? symbol-expression?)    type:symbol]
                             [(both? character-expression?) type:string]
                             [(both? time-tai-expression?)  type:time-tai]
                             [(both? temporal-expression?)  type:time-utc]
                             [else (fail)])
                       (fail))])
       (make-function type 'if (list test pos neg)))]))
;; sql:cond : syntax
;;
;; cond-style sugar that expands into nested sql:if calls:
;;
;;   (sql:cond [t1 e1] [t2 e2] [else e3])
;;     ==> (sql:if t1 e1 (sql:if t2 e2 e3))
;;
;; When there is no else clause, the innermost call is the two-argument
;; form of sql:if.
(define-syntax sql:cond
  (syntax-rules (else)
    ;; A single clause with no fallback.
    [(_ [condition result])
     (sql:if condition result)]
    ;; A final clause followed by an else clause.
    [(_ [condition result] [else fallback])
     (sql:if condition result fallback)]
    ;; General case: peel off the first clause and recur on the rest.
    [(_ [condition result] [next-condition next-result] ...)
     (sql:if condition result (sql:cond [next-condition next-result] ...))]))
;; sql:literal : alias for make-literal.
(define sql:literal make-literal)

;; sql:null : type -> literal
;; Creates a null literal of the given type.
(define (sql:null type)
  (make-null-literal type))

;; sql:order : quotable (U 'asc 'desc) -> order
;; Lifts expr with sql-lift and pairs it with the sort direction dir.
(define (sql:order expr dir)
  (let ([lifted (sql-lift expr)])
    (make-order lifted dir)))

;; sql:asc : quotable -> order
;; Shorthand for (sql:order expr 'asc).
(define sql:asc
  (lambda (expr)
    (sql:order expr 'asc)))

;; sql:desc : quotable -> order
;; Shorthand for (sql:order expr 'desc).
(define sql:desc
  (lambda (expr)
    (sql:order expr 'desc)))
;; Exports.  Every binding is exported with its "sql:" prefix removed, so
;; several of the external names (cond, if, in, null, min, max, count, ...)
;; coincide with standard Scheme names.
;; NOTE(review): clients presumably import this module with a prefix to
;; avoid those clashes -- confirm against the importing modules.

;; Exported without a contract.  sql:cond is a macro; sql:alias is not
;; defined in the visible part of this file.
(provide (rename-out [sql:alias alias]
[sql:cond cond]))
;; Contracted exports.
(provide/contract
;; select: #:from is the only mandatory keyword; all others are optional.
[rename sql:select select (->* (#:from (or/c relation? select?))
(#:what (or/c expression? relation?
(listof (or/c expression? relation?))
false/c)
#:distinct (or/c expression? (listof expression?) boolean?)
#:where (or/c expression? false/c)
#:group (listof (or/c column? relation?))
#:order (listof order?)
#:having (or/c expression? false/c)
#:limit (or/c integer? false/c)
#:offset (or/c integer? false/c))
select?)]
;; select/internal: positional form taking nine arguments.  Going by the
;; matching contracts, the order is presumably what, distinct, from, where,
;; group, order, having, limit, offset -- confirm at the definition.
[rename sql:select/internal select/internal (-> (or/c expression? relation? (listof (or/c expression? relation?)) false/c)
(or/c expression? (listof expression?) boolean?)
(or/c relation? select?)
(or/c expression? false/c)
(listof (or/c column? relation?))
(listof order?)
(or/c expression? false/c)
(or/c integer? false/c)
(or/c integer? false/c)
select?)]
;; Aggregates.
[rename sql:count count (-> column-alias? aggregate?)]
[rename sql:count* count* (->* () ((or/c table-alias? select-alias?)) aggregate?)]
[rename sql:min min (-> quotable? aggregate?)]
[rename sql:max max (-> quotable? aggregate?)]
[rename sql:average average (-> quotable? aggregate?)]
;; Joins.  inner/left/right take a third quotable argument (presumably the
;; join condition); outer takes only the two relations.
[rename sql:inner inner (-> relation+select? relation+select? quotable? join?)]
[rename sql:left left (-> relation+select? relation+select? quotable? join?)]
[rename sql:right right (-> relation+select? relation+select? quotable? join?)]
[rename sql:outer outer (-> relation+select? relation+select? join?)]
;; Functions defined with plain `define` in this file.
[rename sql:in in (-> quotable? (or/c select? (listof quotable?)) function?)]
[rename sql:if if (->* (quotable? quotable?) (quotable?) function?)]
;; Literals.
[rename sql:literal literal (-> quotable? literal?)]
[rename sql:null null (-> type? literal?)]
;; Ordering terms.
[rename sql:order order (-> quotable? (symbols 'asc 'desc) order?)]
[rename sql:asc asc (-> quotable? order?)]
[rename sql:desc desc (-> quotable? order?)])