private/check-program.ss
(module check-program mzscheme
  
  (require (prefix kernel: (lib "kerncase.ss" "syntax"))
           (lib "contract.ss")
           (lib "list.ss")
           "arity-table.ss"
           (lib "my-macros.ss" "stepper" "private")
           (lib "shared.ss" "stepper" "private"))
  
  (provide/contract [check-program (-> syntax? table? (listof (cons/c symbol? (cons/c syntax? any/c))))])

  ;; equal-with-anys? : any any -> boolean
  ;;  structural equality in which the symbol 'any, on either side, matches
  ;;  anything.  Pairs are compared component-wise (car, then cdr); every other
  ;;  value is compared with equal?.
  (define (equal-with-anys? a b)
    (or (eq? a 'any)
        (eq? b 'any)
        (if (and (cons? a) (cons? b))
            (and (equal-with-anys? (car a) (car b))
                 (equal-with-anys? (cdr a) (cdr b)))
            (equal? a b))))
  
  ;; test : any procedure any ... -> void
  ;;  apply p to args and compare the outcome against desired using
  ;;  equal-with-anys?.  On a mismatch, write a diagnostic (desired value, actual
  ;;  value, and the procedure/argument list) to the current error port; on a
  ;;  match, do nothing.
  (define (test desired p . args)
    (let ([actual (apply p args)])
      (if (equal-with-anys? desired actual)
          (void)
          (fprintf (current-error-port)
                   "test failed: desired: ~v\ngot: ~v\ntest: ~v\n"
                   desired
                   actual
                   (cons p args)))))
  
  
  ;; check-program : syntax? table? -> (listof (cons symbol? (cons syntax? any)))
  ;;  analyze a fully-expanded program against the arity table and return just
  ;;  the diagnostics list, each entry of the form (tag stx . details).
  (define (check-program stx table)
    (let* ([top (top-level-expr-iterator stx table)]
           [analysis (top-result-result top)])
      (result-error-messages analysis)))
  
  ;; this analysis computes two things simultaneously: error messages and the identifiers occurring in the code
  
  ;; a result has two fields:
  ;;  error-messages : list of diagnostics, each of the form (tag stx . details)
  ;;  ids-occurring  : identifiers that occur in the analyzed code
  (define-struct result (error-messages ids-occurring))
  
  ;; the result with no diagnostics and no identifiers.
  (define empty-result (make-result '() '()))
  
  ;; build a result from just one of its two fields, leaving the other empty.
  (define make-err-result
    (lambda (errors) (make-result errors '())))
  (define make-id-result
    (lambda (ids) (make-result '() ids)))
  
  ;; combine : result? result? -> result?
  ;;  fieldwise merge: [e1, i1] + [e2, i2] = [e1 ++ e2, varref-set-union of i1 and i2]
  ;;  (result1's error messages come first).
  (define (combine result1 result2)
    (let ([errors (append (result-error-messages result1)
                          (result-error-messages result2))]
          [ids (varref-set-union (list (result-ids-occurring result1)
                                       (result-ids-occurring result2)))])
      (make-result errors ids)))

  ;; a top-result extends a result with module-level bookkeeping:
  ;;  result             : the result for the forms seen so far
  ;;  defined-ids        : identifiers bound by define-values
  ;;  all-defined-except : one list of excluded identifiers per all-defined-style
  ;;                       provide spec ((all-defined) contributes '())
  (define-struct top-result (result defined-ids all-defined-except))
  
  (define empty-top-result (make-top-result empty-result null null))
  
  ;; constructors that fill in a single field, leaving the others empty.
  (define (make-regular-top-result result) (make-top-result result null null))
  (define (make-defined-id-result ids) (make-top-result empty-result ids null))
  ;; make-all-defined-except-result : (listof identifier?) -> top-result?
  ;;  records ONE all-defined(-except) spec.  The exception list is wrapped in its
  ;;  own list so that combine-top's append keeps one list per spec:
  ;;  module-level-checks folds over the all-defined-except field one element
  ;;  (i.e. one exception list) at a time.  Storing it unwrapped would flatten the
  ;;  ids of all specs together and make a bare (all-defined) contribute nothing.
  (define (make-all-defined-except-result a-d-e)
    (make-top-result empty-result null (list a-d-e)))
  
  ;; combine-top : top-result? top-result? -> top-result?
  ;;  fieldwise merge: results via combine, defined ids via varref-set-union,
  ;;  all-defined-except lists by appending (result1's first).
  (define (combine-top result1 result2)
    (make-top-result (combine (top-result-result result1)
                              (top-result-result result2))
                     (varref-set-union (list (top-result-defined-ids result1)
                                             (top-result-defined-ids result2)))
                     (append (top-result-all-defined-except result1)
                             (top-result-all-defined-except result2))))
  
  ; check-ids-used : (listof identifier?) symbol? syntax? result? -> result?
  ;  return result, extended with an (unused-bindings stx kind not-used) diagnostic
  ;  when some identifiers in 'sought' are not module-identifier=? to any identifier
  ;  in result's ids-occurring.
  (define (check-ids-used sought kind stx result)
    (combine 
     result
     (let ([not-used (remove* (result-ids-occurring result) sought module-identifier=?)])
       (if (null? not-used)
           empty-result
           (make-err-result `((unused-bindings ,stx ,kind ,not-used)))))))
  
  ;; provide-keyword=? : identifier? identifier? -> boolean?
  ;;  provide sub-form keywords (rename, all-defined, ...) are recognized by name,
  ;;  not by binding, so literals are compared by their symbols.
  (define (provide-keyword=? a b)
    (eq? (syntax-e a) (syntax-e b)))
  
  ;; scan-provide-spec : syntax? -> top-result?
  ;;  summarize one provide spec:
  ;;   - a plain identifier, or the local name in (rename local export), counts as
  ;;     an occurrence of that identifier;
  ;;   - (all-defined) / (prefix-all-defined p) record an empty exception list;
  ;;   - (all-defined-except id ...) / (prefix-all-defined-except p id ...) record
  ;;     the excepted ids;
  ;;   - struct, all-from, all-from-except, and any other keyword form are punted on.
  ;;  syntax-case* (not kernel-syntax-case) is required: the provide keywords are
  ;;  not kernel forms, so under kernel-syntax-case they were pattern variables and
  ;;  e.g. (all-defined-except a b) was caught by the all-from-except clause.
  (define (scan-provide-spec provide-spec)
    (syntax-case* provide-spec (rename struct all-from all-from-except
                                all-defined all-defined-except
                                prefix-all-defined prefix-all-defined-except)
                  provide-keyword=?
      [id
       (identifier? #'id)
       (make-regular-top-result (make-id-result (list #'id)))]
      [(rename local-id export-id)
       (make-regular-top-result (make-id-result (list #'local-id)))]
      [(struct struct-identifier (field ...))
       ; just punt on this for now... there's probably a nice utility function that does this right
       empty-top-result]
      [(all-from module-name)
       empty-top-result]
      [(all-from-except module-name id ...)
       empty-top-result]
      [(all-defined)
       (make-all-defined-except-result null)]
      [(all-defined-except id ...)
       (make-all-defined-except-result (syntax->list #'(id ...)))]
      [(prefix-all-defined prefix)
       (make-all-defined-except-result null)]
      [(prefix-all-defined-except prefix id ...)
       (make-all-defined-except-result (syntax->list #'(id ...)))]
      [(keyword . rest)
       ; some other keyword form (e.g. from a newer provide vocabulary): punt, as
       ; the previous implementation did for every list-shaped spec.
       (identifier? #'keyword)
       empty-top-result]
      [else
       (error 'scan-provide-spec "unknown provide-spec: ~v" (syntax-object->datum provide-spec))]))

  ;; module-level-checks : syntax? top-result? -> top-result?
  ;;  add a module-wide (unused-bindings stx module ids) diagnostic for defined
  ;;  identifiers that are not module-identifier=? to any occurring identifier,
  ;;  after narrowing that set with binding-set-varref-set-intersect against each
  ;;  element of the all-defined-except field.  Returns top-result unchanged
  ;;  (merged with an empty result) when nothing remains.
  (define (module-level-checks stx top-result)
    (let* ([occurring (result-ids-occurring (top-result-result top-result))]
           [unreferenced (remove* occurring
                                  (top-result-defined-ids top-result)
                                  module-identifier=?)]
           ;; foldl calls its procedure as (element accumulator): the first
           ;; parameter is an all-defined-except entry, the second the running
           ;; set of candidate unused definitions.
           [unused-defines
            (foldl (lambda (a-d-e-entry candidates)
                     (binding-set-varref-set-intersect a-d-e-entry candidates))
                   unreferenced
                   (top-result-all-defined-except top-result))]
           [extra (if (null? unused-defines)
                      empty-top-result
                      (make-regular-top-result
                       (make-err-result `((unused-bindings ,stx module ,unused-defines)))))])
      (combine-top top-result extra)))
  
  
  
  ;; arglist-bindings : syntax? -> (listof syntax?)
  ;;  collect the identifiers named by a lambda formals list, which may be a lone
  ;;  rest identifier, a proper list of identifiers, or an improper list whose
  ;;  tail is a rest identifier.
  (define (arglist-bindings arglist-stx)
    (let loop ([formals arglist-stx])
      (syntax-case formals ()
        [rest-id
         (identifier? formals)
         (list formals)]
        [(id ...)
         (syntax->list formals)]
        [(first-id . more)
         (cons #'first-id (loop #'more))])))

  

  ;; TEMPLATE FUNCTIONS:
  ;;  these functions' definitions follow the data definitions presented in the Syntax
  ;;  chapter of the MzScheme Manual.
  
  ;; top-level-expr-iterator : syntax? table? -> top-result?
  ;;  dispatch on a fully-expanded top-level form:
  ;;   - (module ...): analyze every module-level form, merge the results, then
  ;;     run module-level-checks over the module;
  ;;   - (begin ...): analyze each sub-form as a top-level form and merge;
  ;;   - anything else: hand off to general-top-level-expr-iterator.
  (define (top-level-expr-iterator stx table)
    (let ([merge-all
           (lambda (iterate forms-stx)
             (foldl combine-top
                    empty-top-result
                    (map (lambda (form) (iterate form table))
                         (syntax->list forms-stx))))])
      (kernel:kernel-syntax-case stx #f
        [(module identifier name (#%plain-module-begin . module-level-exprs))
         (module-level-checks
          stx
          (merge-all module-level-expr-iterator #'module-level-exprs))]
        [(begin . exps)
         (merge-all top-level-expr-iterator #'exps)]
        [else-stx
         (general-top-level-expr-iterator #t stx table)])))

  ;; module-level-expr-iterator : syntax? table? -> top-result?
  ;;  analyze one form directly inside a module body: each spec of a provide form
  ;;  goes through scan-provide-spec, begin forms are analyzed piecewise at module
  ;;  level, and everything else is handed to general-top-level-expr-iterator.
  (define (module-level-expr-iterator stx table)
    (let ([merge (lambda (top-results)
                   (foldl combine-top empty-top-result top-results))])
      (kernel:kernel-syntax-case stx #f
        [(provide . provide-specs)
         (merge (map scan-provide-spec (syntax->list #'provide-specs)))]
        [(begin . exps)
         (merge (map (lambda (form) (module-level-expr-iterator form table))
                     (syntax->list #'exps)))]
        [else-stx
         (general-top-level-expr-iterator #f stx table)])))
  
  
  ;; general-top-level-expr-iterator : boolean? syntax? table? -> top-result?
  ;;  analyze a definition, begin, require, or plain expression:
  ;;   - define-values records its ids as defined and analyzes its right-hand side;
  ;;   - define-syntaxes, require and require-for-syntax contribute nothing;
  ;;   - begin analyzes each sub-form with top-level-expr-iterator;
  ;;   - any other form is analyzed as an expression.
  ;;  really-top-level? is accepted but not consulted by this function.
  (define (general-top-level-expr-iterator really-top-level? stx table)
    (kernel:kernel-syntax-case stx #f
      [(define-values (var ...) expr)
       (let ([defined (make-defined-id-result (syntax->list #'(var ...)))]
             [rhs (make-regular-top-result (expr-iterator #'expr table))])
         (combine-top defined rhs))]
      [(define-syntaxes (var ...) expr) empty-top-result]
      [(require . require-specs) empty-top-result]
      [(require-for-syntax . require-specs) empty-top-result]
      [(begin . top-level-exprs)
       (foldl combine-top
              empty-top-result
              (map (lambda (form) (top-level-expr-iterator form table))
                   (syntax->list #'top-level-exprs)))]
      [else
       (make-regular-top-result (expr-iterator stx table))]))
 
  ; expr-iterator : syntax? table? -> result?
  ;  Walk a fully-expanded expression.  The returned result's ids-occurring holds
  ;  the identifiers referenced in the expression (bare or #%top-wrapped); its
  ;  error-messages hold:
  ;   - (unused-bindings stx kind ids) for each lambda / case-lambda clause
  ;     (kind 'lambda) or let-values / letrec-values form (kind 'let/rec) whose
  ;     bound identifiers do not all occur in it (see check-ids-used);
  ;   - one entry for every (#%app ...) form:
  ;       (application-ok stx)          the operator is an identifier with an entry
  ;                                     in table, and arity-match accepts the
  ;                                     entry's arity (its cadr) for the arg count;
  ;       (bad-application stx arity)   entry found, but arity-match rejects it;
  ;       (unknown-id-application stx)  the operator identifier has no entry;
  ;       (non-id-application stx)      the operator is not a (possibly
  ;                                     #%top-wrapped) identifier.
  ;  The target of set! is not recorded as an occurrence; only the new value is
  ;  walked.  Raises an error for any form not matched below.
  (define (expr-iterator stx table)
    (let* ([recur (lambda (expr) (expr-iterator expr table))]
           ; analyze every element of a syntax list and merge the results
           [recur-on-pieces (lambda (exprs-stx) (foldl combine empty-result (map recur (syntax->list exprs-stx))))]
           ; one (formals . bodies) clause of a lambda or case-lambda
           [lambda-clause-abstraction
            (lambda (clause)
              (kernel:kernel-syntax-case clause #f
                [(arglist . bodies)
                 (check-ids-used (arglist-bindings #`arglist)
                                 'lambda
                                 stx
                                 (recur-on-pieces #'bodies))]
                [else
                 (error 'expr-syntax-object-iterator 
                        "unexpected (case-)lambda clause: ~a" 
                        (syntax-object->datum stx))]))]
           ; a let-values or letrec-values form
           [let-values-abstraction
            (lambda (stx)
              (kernel:kernel-syntax-case stx #f
                [(kwd (((variable ...) rhs) ...) body ...)
                 ; note: check-ids-used compares identifiers with module-identifier=?,
                 ; so let and letrec need no separate handling here; the right-hand
                 ; sides and bodies are analyzed together.
                 (check-ids-used (apply append (map syntax->list (syntax->list #`((variable ...) ...))))
                                 'let/rec
                                 stx
                                 (recur-on-pieces #'(rhs ... body ...)))]
                [else
                 (error 'expr-syntax-object-iterator 
                        "unexpected let(rec) expression: ~a"
                        stx
                        ;(syntax-object->datum stx)
                        )]))]) 
      (kernel:kernel-syntax-case stx #f
        [var-stx
         (identifier? (syntax var-stx))
         (make-id-result (list #`var-stx))]
        [(lambda . clause)
         (lambda-clause-abstraction #'clause)]
        [(case-lambda . clauses)
         (foldl combine empty-result (map lambda-clause-abstraction (syntax->list #'clauses)))]
        [(if test then)
         (recur-on-pieces #'(test then))]
        [(if test then else)
         (recur-on-pieces #'(test then else))]
        [(begin . bodies)
         (recur-on-pieces #'bodies)]
        [(begin0 . bodies)
         (recur-on-pieces #'bodies)]
        [(let-values . _)
         (let-values-abstraction stx)]
        [(letrec-values . _)
         (let-values-abstraction stx)]
        [(set! var val)
         (recur-on-pieces #'(val))]
        [(quote _)
         empty-result]
        [(quote-syntax _)
         empty-result]
        [(with-continuation-mark key mark body)
         (recur-on-pieces #'(key mark body))]
        [(#%app . exprs)
         (let* ([expr-list (syntax->list #'exprs)])
           (combine
            ; occurrences and diagnostics from the operator and all operands
            (recur-on-pieces #'exprs)
            ; plus one diagnostic classifying the application itself
            (if (null? expr-list)
                empty-result
                (let* ([fn-pos (car expr-list)])
                  (cond [(syntax-case fn-pos (#%top)
                           ; the operator's identifier, looking through #%top, or #f
                           [var
                            (identifier? #'var)
                            #'var]
                           [(#%top . var)
                            (identifier? #'var)
                            #'var] 
                           [else #f])
                         =>
                         (lambda (var)
                           (let* ([match (find-match var table)])
                             (if match
                                 (if (arity-match (cadr match) (length (cdr expr-list)))
                                     (make-err-result `((application-ok ,stx)))
                                     (make-err-result `((bad-application ,stx ,(cadr match)))))
                                 (make-err-result `((unknown-id-application ,stx))))))]
                        [else
                         (make-err-result `((non-id-application ,stx)))])))))]
        [(#%datum . _)
         empty-result]
        [(#%top . var)
         (make-id-result (list #`var))]
        [else
         (error 'expr-iterator "unknown expr: ~a" 
                (syntax-object->datum stx))])))
 
  

  ;; not bad testing
  
;  (define a-id (expand #'a))
;  (define a-id-stripped (syntax-case a-id (#%top)
;                          [(#%top . a)
;                           #'a]))
;  (define b-id (expand #'b))
;  (define b-id-stripped (syntax-case b-id (#%top)
;                          [(#%top . b)
;                           #'b]))
;  (define c-id (expand #'c))
;  (define c-id-stripped (syntax-case c-id (#%top)
;                          [(#%top . c)
;                           #'c]))
;  (define id-list (list a-id-stripped b-id-stripped c-id-stripped))
;  (define arities-list `(((2 2)) ((1 3) (5 inf)) ((3 3))))
; 
;  (define arity-table
;    (map list id-list arities-list))
; 
;  (define (check-program-test expected stx)
;    (test expected
;          make-testable
;          (check-program (expand stx) arity-table)))
; 
;  (define (make-testable result)
;    (map (lx (cons (car _) (cons (syntax-object->datum (cadr _)) (map (lx (if (pair? _)
;                                                                              (map (lx (if (syntax? _)
;                                                                                           (syntax-object->datum _)
;                                                                                           _))
;                                                                                   _)
;                                                                              _))
;                                                                      (cddr _)))))
;         result))
; 
;  (define d1 `(#%datum . 1))
;  (define d2 `(#%datum . 2))
;  (define d3 `(#%datum . 3))
;  (define d4 `(#%datum . 4))
;  (define d5 `(#%datum . 5))
;  (define ta `(#%top . a))
;  (define tb `(#%top . b))
;  (define tc `(#%top . c))
; 
;  (check-program-test `((application-ok (#%app ,ta ,d3 ,d4))) `(,a-id 3 4))
;  (check-program-test `((bad-application (#%app ,ta ,d3) ((2 2)))) `(,a-id 3))
;  (check-program-test `((unknown-id-application (#%app (#%top . f) ,d3))) `(f 3))
;  (check-program-test `((unknown-id-application (#%app (#%top . +) x ,d1))
;                        (non-id-application (#%app (lambda (x) (#%app (#%top . +) x ,d1)) ,d3)))
;                      `((lambda (x) (+ x 1)) 3))
;  (check-program-test `((application-ok (#%app ,tc ,d1 ,d2 ,d3))
;                        (bad-application (#%app ,tc ,d1 ,d2) ((3 3)))
;                        (non-id-application (#%app (#%app ,tc ,d1 ,d2) ,d3))
;                        (application-ok (#%app ,tb ,d1 ,d2 ,d3 ,d4 ,d5)))
;                      `(if (,b-id 1 2 3 4 5) ((,c-id 1 2) 3) (,c-id 1 2 3)))
;  (check-program-test `() `(begin 3 4))
;  (let* ([stx (expand `(module foo mzscheme (define (h x) (h x))))]
;         [id (syntax-case stx (#%plain-module define-values)
;               [(module dc1 dc2 (#%plain-module-begin dc4 (define-values (id) . dc3)))
;                #'id])]
;         [table `((,id ((1 1))))])
;    (test `((application-ok (#%app h x)))
;          make-testable (check-program stx table)))
; 
;  (check-program-test `((unused-bindings any lambda (x)))
;                      `(if 3 (lambda (x) 4) 8))
; 
;  (check-program-test `((unknown-id-application any)
;                        (unused-bindings any let/rec (z)))
;                      `(lambda (y) (let ([z 3] [q y]) (+ q y))))
; 
;  (check-program-test `((unused-bindings any module (q a b)))
;                      `(module foo mzscheme
;                     (provide (all-defined-except a b c d)
;                              c)
;                     (define z 1)
;                     (define b 13)
;                     (define a d)
;                     (define c 287)
;                     (define d 9)
;                     (define q z)))
 
  )