(module infer mzscheme
(require "unify.ss" "type-comparison.ss" "type-rep.ss" "effect-rep.ss" "subtype.ss"
"planet-requires.ss" "tc-utils.ss" "union.ss"
(lib "trace.ss")
(lib "plt-match.ss")
(lib "list.ss"))
(require-galore)
(provide infer infer/list infer/list/vararg combine table:un exn:infer?)
;; Inference failure signalling.
;;
;; Failure is reported by raising a fresh, uninterned symbol (`fail-sym`).
;; `exn:infer?` recognises exactly that symbol, so it can be used as a
;; `with-handlers` predicate to intercept inference failures and nothing else.
;;
;; No `exn:infer` structure type is defined here: `define-struct` would bind
;; `exn:infer?` as well, clashing with the predicate below, and nothing in
;; this module constructs such a structure.
(define-values (fail-sym exn:infer?)
  (let ([sym (gensym)])
    (values sym (lambda (s) (eq? s sym)))))

;; (fail! s t) : abort the current inference attempt.
;; `s` and `t` are the two things that could not be reconciled; they are
;; accepted for documentation purposes at the call site but are not evaluated
;; or attached to the raised value.
(define-syntax fail!
  (syntax-rules ()
    [(_ s t) (raise fail-sym)]))
;; alist->mapping : (listof var) -> table
;; Builds the initial inference table: every variable in `vars` is a key
;; whose value is the placeholder symbol 'fail.
(define (alist->mapping vars)
  (let ([unbound-entries (map (lambda (var) (cons var 'fail)) vars)])
    (table:alist->eq unbound-entries)))
;; mk-infer : (s t table flag -> table) -> (s t vars -> substitution-or-#f)
;; Wraps an internal inference procedure `f` into an entry point that
;;  - seeds the table from `vars`,
;;  - runs `f` with the 'co flag,
;;  - converts the resulting table to a substitution, and
;;  - returns #f if an inference failure is raised along the way.
(define (mk-infer f)
  (lambda (s t vars)
    (define initial-mapping (alist->mapping vars))
    (with-handlers ([exn:infer? (lambda ignored #f)])
      (mapping->subst (f s t initial-mapping 'co)))))
;; mapping->subst : table -> (listof (list key type))
;; Converts an inference table into a substitution.  Each entry of the
;; table's s-expression form is read as (key value); entries whose value is
;; not a list (e.g. the 'fail placeholder) are dropped, and for the others
;; the second element of the value is kept alongside the key.
(define (mapping->subst x)
  (let walk ([entries (table:to-sexp x)])
    (cond
      [(null? entries) '()]
      [(list? (cadar entries))
       (cons (list (caar entries) (cadr (cadar entries)))
             (walk (cdr entries)))]
      [else (walk (cdr entries))])))
;; combine : flag -> (entry entry -> entry)
;; Produces the procedure used to merge two table values for the same key.
;; An entry is either the symbol 'fail (nothing known yet) or a two-element
;; list (entry-flag type).
;;  - If either side is 'fail, the other side wins.
;;  - If both sides carry a flag and their types are type-equal?, the type is
;;    kept; the flag is kept when both flags are eq?, otherwise it is 'both.
;;  - A 'both flag on either side, or two different flags, is a failure.
;;  - Otherwise the direction is the first non-#f of the two entry flags and
;;    the default `flag`: under 'co the supertype of the two is chosen, under
;;    'contra the subtype; unrelated types (or any other direction) fail.
(define (combine flag)
  (lambda (s t)
    (match (list s t)
      [(list 'fail t) t]
      [(list t 'fail) t]
      [(list (list sf s) (list tf t))
       (cond
         [(and sf tf (type-equal? s t))
          (list (if (eq? sf tf) sf 'both) s)]
         [(memq 'both (list sf tf)) (fail! s t)]
         [(and sf tf (not (eq? sf tf))) (fail! s t)]
         [else
          (let ([direction (or sf tf flag)])
            (case direction
              [(co)
               (cond
                 [(subtype s t) (list 'co t)]
                 [(subtype t s) (list 'co s)]
                 [else (fail! s t)])]
              [(contra)
               (cond
                 [(subtype s t) (list 'contra s)]
                 [(subtype t s) (list 'contra t)]
                 [else (fail! s t)])]
              [else (fail! s t)]))])])))
;; table:un : flag -> (table table -> table)
;; Returns a two-argument procedure that unions tables `a` and `b`, using
;; (combine flag) to merge the values of keys present in both.
(define (table:un flag)
  (lambda (a b)
    (let ([merge-values (combine flag)])
      (table:union/value a b merge-values))))
;; infer/int/union : (listof type) (listof type) table flag -> table
;; Inference between the element lists of two unions.  Fails when the lists
;; differ in length.  Elements that also occur (by memq) in the other list
;; are dropped from each side; the remaining elements are inferred pairwise
;; against `mapping` and the resulting tables are merged with (table:un flag),
;; starting from an empty eq-table.
(define (infer/int/union ss ts mapping flag)
  (when (not (= (length ss) (length ts)))
    (fail! ss ts))
  (let* ([only-in-ss (filter (lambda (se) (not (memq se ts))) ss)]
         [only-in-ts (filter (lambda (te) (not (memq te ss))) ts)]
         [results (map (lambda (x y) (infer/int x y mapping flag))
                       only-in-ss
                       only-in-ts)])
    (foldl (table:un flag) (table:make-eq) results)))
;; infer/int/list : (listof type) (listof type) table flag -> table
;; Pairwise inference over two equal-length lists of types.  Fails when the
;; lengths differ.  Each pair is inferred independently against `mapping`,
;; and the per-pair tables are merged with (table:un flag), starting from an
;; empty eq-table.
(define (infer/int/list ss ts mapping flag)
  (when (not (= (length ss) (length ts)))
    (fail! ss ts))
  (let ([merge (table:un flag)]
        [seed (table:make-eq)]
        [results (map (lambda (x y) (infer/int x y mapping flag)) ss ts)])
    ;; Same accumulation as foldl: each result is merged as (merge result acc).
    (let accumulate ([acc seed] [remaining results])
      (if (null? remaining)
          acc
          (accumulate (merge (car remaining) acc) (cdr remaining))))))
;; infer/int/list/eff : (listof effect) (listof effect) table flag -> table
;; Pairwise inference over two lists of effects, mirroring infer/int/list.
;; A length mismatch is reported as an ordinary inference failure (via
;; `fail!`) so that callers guarding with `exn:infer?` can recover; it must
;; not escape as a generic error.
;; Each pair is inferred with infer/int/eff against `mapping`, and the
;; per-pair tables are merged with (table:un flag), starting from an empty
;; eq-table.
(define (infer/int/list/eff ss ts mapping flag)
  (unless (= (length ss) (length ts))
    (fail! ss ts))
  (let ([l (map (lambda (x y) (infer/int/eff x y mapping flag)) ss ts)])
    (foldl (table:un flag) (table:make-eq) l)))
;; infer/int/list/vararg : (listof type) (or type #f) (listof type) table flag -> table
;; Inference for an argument list with an optional rest type.
;;   ss   - the fixed types
;;   rest - the rest type, or #f when there is none
;;   ts   - the types to infer against
;; Unlike infer/int/list, the table is threaded through the calls: each
;; infer/int call receives the table produced by the previous one.
;; Fails when `ss` is longer than `ts`, and when `ts` has more elements than
;; `ss` but there is no rest type to absorb them.
(define (infer/int/list/vararg ss rest ts mapping flag)
  (unless (<= (length ss) (length ts))
    (fail! ss ts))
  (let loop-types ([ss ss]
                   [ts ts]
                   [tbl mapping])
    (cond
      [(null? ts) tbl]
      ;; Fixed types exhausted: remaining elements of ts go against `rest`.
      [(and rest (null? ss))
       (let ([tbl* (infer/int rest (car ts) tbl flag)])
         (loop-types ss (cdr ts) tbl*))]
      ;; Fixed types exhausted and no rest type: the surplus elements of ts
      ;; cannot be matched, so this is an inference failure (previously this
      ;; fell through to `(car ss)` on an empty list).
      [(null? ss) (fail! ss ts)]
      [else
       (let ([tbl* (infer/int (car ss) (car ts) tbl flag)])
         (loop-types (cdr ss) (cdr ts) tbl*))])))
;; infer/list/vararg : (listof type) (or type #f) (listof type) vars -> substitution-or-#f
;; Public entry point for inference over an argument list with an optional
;; rest type.  Seeds the table from `vars`, runs infer/int/list/vararg with
;; the 'co flag and converts the result to a substitution; returns #f when an
;; inference failure is raised.
(define (infer/list/vararg ss rest ts vars)
  (define initial-mapping (alist->mapping vars))
  (with-handlers ([exn:infer? (lambda ignored #f)])
    (let ([final-mapping
           (infer/int/list/vararg ss rest ts initial-mapping 'co)])
      (mapping->subst final-mapping))))
;; swap : flag -> flag
;; Flips the variance flag: 'co becomes 'contra and vice versa.
;; Any other value is reported through int-err.
(define (swap flag)
  (cond
    [(eq? flag 'co) 'contra]
    [(eq? flag 'contra) 'co]
    [else (int-err "bad flag: ~a" flag)]))
;; co? : any -> boolean
;; True exactly when `x` is the symbol 'co.
(define co?
  (lambda (x)
    (eq? 'co x)))
;; contra? : any -> boolean
;; True exactly when `x` is the symbol 'contra.
(define contra?
  (lambda (x)
    (eq? 'contra x)))
;; infer/int/eff : effect effect table flag -> table
;; Inference between two latent effects.
;;  - Effects that are equal under type-equal? leave `mapping` unchanged.
;;  - Two Latent-Restrict-Effects, or two Latent-Remove-Effects, are inferred
;;    by running infer/int on the types they carry.
;;  - Any other combination is an inference failure.  Without the final
;;    clause such a pair would escape as a pattern-match error, which the
;;    `exn:infer?` handlers used by the callers do not intercept.
(define (infer/int/eff s t mapping flag)
  ;; Local procedure form of the `fail!` macro; with no arguments it reports
  ;; the s and t this call was entered with.
  (let ([fail! (case-lambda [() (fail! s t)]
                            [(s t) (fail! s t)])])
    ;; Repeated pattern variables (as in `(list t t)`) are compared with
    ;; type-equal?.
    (parameterize ([match-equality-test type-equal?])
      (match (list s t)
        [(list t t) mapping]
        [(list (Latent-Restrict-Effect: t1) (Latent-Restrict-Effect: t2))
         (infer/int t1 t2 mapping flag)]
        [(list (Latent-Remove-Effect: t1) (Latent-Remove-Effect: t2))
         (infer/int t1 t2 mapping flag)]
        [_ (fail!)]))))
;; infer/int : type type table flag -> table
;; Core inference procedure.  Compares `s` against `t` structurally and
;; returns a table derived from `mapping` with the constraints discovered for
;; the type variables it contains.  `flag` is 'co or 'contra and selects the
;; variance in effect at this position.  Inference failure is signalled with
;; `fail!` (recognised by `exn:infer?`).
;;
;; Table values are either 'fail (variable not yet constrained) or
;; (list entry-flag type), where entry-flag is #f, 'co, 'contra or 'both.
(define (infer/int s t mapping flag)
  ;; Local procedure form of the `fail!` macro; with no arguments it reports
  ;; the s and t this call was entered with.
  (let ([fail! (case-lambda [() (fail! s t)]
                            [(s t) (fail! s t)])])
    ;; Repeated pattern variables (as in `(list t t)`) are compared with
    ;; type-equal?.
    (parameterize ([match-equality-test type-equal?])
      (match (list s t)
        ;; Equal types: nothing to learn.
        [(list t t) mapping]
        ;; A type variable on the left: record or refine its binding.
        [(list (F: v) t)
         (let ([cur (table:lookup v mapping)])
           (match cur
             ;; First constraint seen for v; stored with no flag.
             ['fail (table:insert v (list #f t) mapping)]
             ;; Lookup produced #f: v is not in the table.
             [#f (fail!)]
             [(list cur-flag cur-t)
              (cond
                ;; Existing binding is unflagged or has the same variance.
                [(or (not cur-flag) (eq? flag cur-flag))
                 (cond
                   [(type-equal? cur-t t) mapping]
                   ;; 'co: keep the larger type, or the union of both when
                   ;; they are unrelated.
                   [(and (eq? flag 'co) (subtype cur-t t))
                    (table:insert v (list flag t) mapping)]
                   [(and (eq? flag 'co) (subtype t cur-t))
                    (table:insert v (list flag cur-t) mapping)]
                   [(eq? flag 'co)
                    (table:insert v (list flag (Un t cur-t)) mapping)]
                   ;; 'contra: keep the smaller of the two types.
                   [(and (eq? flag 'contra) (subtype t cur-t))
                    (table:insert v (list flag t) mapping)]
                   [(and (eq? flag 'contra) (subtype cur-t t))
                    (table:insert v (list flag cur-t) mapping)]
                   ;; 'contra with unrelated types: no common subtype is
                   ;; computed here, so inference fails.
                   [(eq? flag 'contra) (fail! cur-t t)]
                   [else (int-err "bad flag value: ~a" flag)])]
                ;; Different variance but the same type: v is now fixed in
                ;; both directions.
                [(type-equal? cur-t t)
                 (table:insert v (list 'both cur-t) mapping)]
                [else
                 (fail! cur-t t)])]))]
        ;; NOTE(review): this clause uses `$`-style structure patterns and a
        ;; bare two-element list pattern; confirm that the match library in
        ;; use accepts them and that a `dynamic` structure type is in scope.
        [(or (_ ($ dynamic)) (($ dynamic) _)) mapping]
        ;; Constructors whose components are compared position by position,
        ;; all with the current flag.
        [(list (Vector: s) (Vector: t)) (infer/int s t mapping flag)]
        [(list (Pair: s1 s2) (Pair: t1 t2))
         (infer/int/list (list s1 s2) (list t1 t2) mapping flag)]
        [(list (Hashtable: s1 s2) (Hashtable: t1 t2))
         (infer/int/list (list s1 s2) (list t1 t2) mapping flag)]
        ;; Structs must agree on nm and p (repeated pattern variables).
        [(list (Struct: nm p flds proc) (Struct: nm p flds* proc*))
         (infer/int/list (cons proc flds) (cons proc* flds*) mapping flag)]
        [(list (Param: in1 out1) (Param: in2 out2))
         (infer/int/list (list in1 out1) (list in2 out2) mapping flag)]
        ;; Two recursive types: compare their bodies directly.
        [(list (Mu-unsafe: s) (Mu-unsafe: t))
         (infer/int s t mapping flag)]
        ;; A recursive type on one side only: unfold it and retry.
        [(list s (? Mu? t)) (infer/int s (unfold t) mapping flag)]
        [(list (? Mu? s) t) (infer/int (unfold s) t mapping flag)]
        ;; Two unions with the same number of elements; otherwise fall
        ;; through to the later clauses.
        [(list (Union: l1) (Union: l2))
         (=> unmatch)
         (unless (= (length l1) (length l2))
           (unmatch))
         (infer/int/union l1 l2 mapping flag)]
        ;; Two function types, compared arr by arr.
        [(list (Function: (list (arr: ts t t-rest t-thn-eff t-els-eff) ...))
               (Function: (list (arr: ss s s-rest s-thn-eff s-els-eff) ...)))
         (=> unmatch)
         ;; Corresponding arrs must either both have a rest type or both
         ;; lack one.
         (define (compatible-rest t-rest s-rest)
           (andmap (lambda (x y) (or (and x y) (and (not x) (not y)))) t-rest s-rest))
         (define (U a b) ((table:un flag) a b))
         ;; When the left function is a single arr whose then- and
         ;; else-effect lists are both empty, the right function's effects
         ;; are replaced by empty lists as well.
         (let-values ([(s-thn-eff s-els-eff)
                       (if (and (null? (car t-thn-eff)) (null? (cdr t-thn-eff))
                                (null? (car t-els-eff)) (null? (cdr t-els-eff)))
                           (values (list null) (list null))
                           (values s-thn-eff s-els-eff))])
           (unless (and (= (length ts) (length ss))
                        (= (length t-thn-eff) (length s-thn-eff))
                        (= (length t-els-eff) (length s-els-eff))
                        (compatible-rest t-rest s-rest))
             (unmatch))
           ;; Arguments are inferred with the flag swapped, results and
           ;; effects with the current flag; the four tables are then merged.
           (let ([arg-mapping (infer/int/list (apply append ts) (apply append ss) mapping (swap flag))]
                 [ret-mapping (infer/int/list t s mapping flag)]
                 [thn-mapping (infer/int/list/eff (apply append t-thn-eff) (apply append s-thn-eff) mapping flag)]
                 [els-mapping (infer/int/list/eff (apply append t-els-eff) (apply append s-els-eff) mapping flag)])
             (U (U arg-mapping ret-mapping) (U thn-mapping els-mapping))))]
        ;; Multi-arr function on the left against a single-arr function:
        ;; the first arr for which inference succeeds wins.
        [(list (Function: ftys) (and t (Function: (list (arr: ss s s-rest s-thn-eff s-els-eff)))))
         (=> unmatch)
         (when (= 1 (length ftys)) (unmatch))
         (or
          (ormap
           (lambda (fty)
             (with-handlers
                 ([exn:infer? (lambda _ #f)])
               (infer/int (make-Function (list fty)) t mapping flag)))
           ftys)
          (fail!))]
        ;; Mirror image: single-arr function on the left, multi-arr on the
        ;; right.
        [(list (and t (Function: (list (arr: ss s s-rest s-thn-eff s-els-eff)))) (Function: ftys))
         (=> unmatch)
         (when (= 1 (length ftys)) (unmatch))
         (or
          (ormap
           (lambda (fty)
             (with-handlers
                 ([exn:infer? (lambda _ #f)])
               (infer/int t (make-Function (list fty)) mapping flag)))
           ftys)
          (fail!))]
        ;; A union on the left: the first element for which inference
        ;; succeeds wins.
        [(list (Union: e1) t)
         (or
          (ormap
           (lambda (e)
             (with-handlers
                 ([exn:infer? (lambda _ #f)])
               (infer/int e t mapping flag)))
           e1)
          (fail!))]
        ;; No structural rule applies: fall back on plain subtyping in the
        ;; direction selected by the flag.
        [else (cond [(and (co? flag) (subtype t s)) mapping]
                    [(and (contra? flag) (subtype s t)) mapping]
                    [else (fail!)])]
        ))))
;; infer : type type vars -> substitution-or-#f
;; Public entry point: infers a substitution for `vars` from a single pair of
;; types, by way of mk-infer and infer/int.
(define (infer s t vars)
  ((mk-infer infer/int) s t vars))
;; infer/list : (listof type) (listof type) vars -> substitution-or-#f
;; Public entry point: infers a substitution for `vars` from two lists of
;; types, by way of mk-infer and infer/int/list.
(define (infer/list ss ts vars)
  ((mk-infer infer/int/list) ss ts vars))
)