compiler/hoist.ss
(module hoist mzscheme
  (require (planet "contract-utils.ss" ("cobbe" "contract-utils.plt" 1 0))
           (all-except (planet "list.ss" ("dherman" "list.plt" 1 0)) any)
           (lib "contract.ss")
           (lib "match.ss")
           "../syntax/ast.ss"
           "../syntax/token.ss")

  ;; TODO: abstract out the list functions

  ;; Hoisted variants of the three scope-introducing AST nodes. Each extends its
  ;; parent struct with two extra fields:
  ;;   functions : the FunctionDeclaration/hoisted nodes lifted out of its body
  ;;   variables : the de-duplicated variable identifiers collected from its
  ;;               body, minus the names of the hoisted functions (as computed
  ;;               by unique-vars)
  (define-struct (FunctionDeclaration/hoisted FunctionDeclaration) (functions variables))
  (define-struct (FunctionExpression/hoisted FunctionExpression) (functions variables))
  (define-struct (LetExpression/hoisted LetExpression) (functions variables))

  ;; a (continuation a -> b) is:
  ;;
  ;;   - ((listof FunctionDeclaration/hoisted) (listof Identifier) a -> b)

  ;; continuation/c : contract -> contract
  ;; Contract for a hoisting continuation whose third argument must satisfy c.
  ;; The hoisted functions come first, then the hoisted variables. The
  ;; continuation's own result is unconstrained.
  (define (continuation/c c)
    (-> (listof FunctionDeclaration/hoisted?) (listof Identifier?) c any))

  ;; an (a ->k b) is:
  ;;
  ;;   - (a (continuation b -> c) -> c)

  ;; provide/contract/k : exports CPS functions. Each clause [f (dom ->k rng)]
  ;; becomes [f (-> dom (continuation/c rng) any)]. The infix form
  ;; (dom . ->k . rng) reads as (->k dom rng), which is the shape matched here.
  (define-syntax provide/contract/k
    (syntax-rules (->k)
      [(_ [fn (->k arg/c result/c)] ...)
       (provide/contract
         [fn (-> arg/c (continuation/c result/c) any)]
         ...)]))

  ;; unique-vars : (listof FunctionDeclaration) (listof Identifier) -> (listof Identifier)
  ;; Removes duplicate identifiers from vars, then drops every identifier that
  ;; names one of the hoisted functions (comparison by Identifier=?).
  ;; TODO: just directly use lset-difference instead of unique-vars at all
  (define (unique-vars funs vars)
    (let* ([fun-names (map FunctionDeclaration-name funs)]
           [distinct-vars (delete-duplicates vars Identifier=?)])
      (lset-difference Identifier=? distinct-vars fun-names)))

  ;; ===========================================================================
  ;; TOP-LEVEL HOISTING FUNCTIONS
  ;; ===========================================================================

  ;; hoist-function-declaration : FunctionDeclaration -> FunctionDeclaration/hoisted
  ;; Hoists the body of a function declaration. The functions and variables found
  ;; in the body are stored on the resulting node rather than escaping it.
  (define (hoist-function-declaration decl)
    (match decl
      [($ FunctionDeclaration where fname params body0)
       (let ([build (lambda (funs vars body)
                      (make-FunctionDeclaration/hoisted where fname params body
                                                        funs (unique-vars funs vars)))])
         (hoist-source-elements body0 build))]))

  ;; hoist-script : (listof SourceElement) -> (listof FunctionDeclaration/hoisted)
  ;;                                          (listof Identifier)
  ;;                                          (listof Statement)
  ;; Hoists a whole script and returns three values: the top-level hoisted
  ;; functions, the de-duplicated top-level variables (without function names),
  ;; and the remaining statements.
  (define (hoist-script elts)
    (let ([finish (lambda (funs vars stmts)
                    (let ([distinct (unique-vars funs vars)])
                      (values funs distinct stmts)))])
      (hoist-source-elements elts finish)))

  ;; ===========================================================================
  ;; COMPOUND HOISTING FUNCTIONS
  ;; ===========================================================================

  ;; map-k : (a ->k b) (listof a) ->k (listof b)
  ;; Hoists the elements left to right. k receives every element's functions
  ;; concatenated in element order, likewise their variables, and the list of
  ;; per-element results.
  (define (map-k hoist1 elts k)
    (let loop ([pending elts] [fun-chunks null] [var-chunks null] [outs null])
      (if (null? pending)
          (k (apply append (reverse fun-chunks))
             (apply append (reverse var-chunks))
             (reverse outs))
          (hoist1 (car pending)
            (lambda (funs vars result)
              (loop (cdr pending)
                    (cons funs fun-chunks)
                    (cons vars var-chunks)
                    (cons result outs)))))))

  ;; optional-map-k : (a ->k (optional b)) (listof a) ->k (listof b)
  ;; Like map-k, but any element whose result is #f is left out of the result
  ;; list. Its functions and variables are still collected.
  (define (optional-map-k hoist1 elts k)
    (let loop ([pending elts] [fun-chunks null] [var-chunks null] [kept null])
      (if (null? pending)
          (k (apply append (reverse fun-chunks))
             (apply append (reverse var-chunks))
             (reverse kept))
          (hoist1 (car pending)
            (lambda (funs vars result)
              (loop (cdr pending)
                    (cons funs fun-chunks)
                    (cons vars var-chunks)
                    (if result (cons result kept) kept)))))))
      
  ;; append-map-k : (a ->k (listof b)) (listof a) ->k (listof b)
  ;; Like map-k, but each element produces a list of results, and those lists
  ;; are concatenated in element order.
  (define (append-map-k hoist1 elts k)
    (let loop ([pending elts] [fun-chunks null] [var-chunks null] [out-chunks null])
      (if (null? pending)
          (k (apply append (reverse fun-chunks))
             (apply append (reverse var-chunks))
             (apply append (reverse out-chunks)))
          (hoist1 (car pending)
            (lambda (funs vars results)
              (loop (cdr pending)
                    (cons funs fun-chunks)
                    (cons vars var-chunks)
                    (cons results out-chunks)))))))

  ;; hoist-source-elements : (listof SourceElement) ->k (listof Statement)
  ;; Each source element maps to exactly one statement (see hoist-source-element).
  (define hoist-source-elements
    (lambda (elts k)
      (map-k hoist-source-element elts k)))

  ;; hoist-expressions : (listof Expression) ->k (listof Expression)
  ;; Hoists each expression in order, one result per input.
  (define hoist-expressions
    (lambda (exprs k)
      (map-k hoist-expression exprs k)))

  ;; hoist-optional-expression : (optional Expression) ->k (optional Expression)
  ;; A missing expression (#f) stays #f and contributes nothing.
  (define (hoist-optional-expression expr k)
    (if expr
        (hoist-expression expr k)
        (k null null #f)))

  ;; hoist-optional-expressions : (listof (optional Expression)) ->k (listof (optional Expression))
  ;; Missing entries (#f) are kept in place, so the result has the same length
  ;; as the input.
  (define hoist-optional-expressions
    (lambda (exprs k)
      (map-k hoist-optional-expression exprs k)))

  ;; hoist-substatements : (listof SourceElement) ->k (listof Statement)
  ;; Each substatement may expand to zero or more statements; the per-element
  ;; lists are concatenated.
  (define hoist-substatements
    (lambda (stmts k)
      (append-map-k hoist-substatement stmts k)))

  ;; hoist-variable-declarations : (listof VariableDeclaration) ->k (listof Expression)
  ;; Declarations without an initializer produce no expression and are dropped
  ;; from the result list. Their identifiers are still reported as variables.
  (define hoist-variable-declarations
    (lambda (decls k)
      (optional-map-k hoist-variable-declaration decls k)))

  ;; hoist-case-clauses : (listof CaseClause) ->k (listof CaseClause)
  ;; Hoists every clause of a switch statement, in order.
  (define hoist-case-clauses
    (lambda (cases k)
      (map-k hoist-case-clause cases k)))

  ;; hoist-catch-clauses : (listof CatchClause) ->k (listof CatchClause)
  ;; Hoists every catch clause of a try statement, in order.
  (define hoist-catch-clauses
    (lambda (catches k)
      (map-k hoist-catch-clause catches k)))

  ;; ===========================================================================
  ;; CORE HOISTING FUNCTIONS
  ;; ===========================================================================

  ;; hoist-source-element : SourceElement ->k Statement
  ;; A function declaration is lifted out whole and replaced by an empty
  ;; statement at the same location. Any other element is hoisted as a statement,
  ;; and its replacement list is collapsed back into a single statement.
  (define (hoist-source-element src0 k)
    (cond
      [(not (FunctionDeclaration? src0))
       (hoist-statement src0
         (lambda (funs vars stmts)
           (k funs vars (statements->statement stmts src0))))]
      [else
       (let ([lifted (hoist-function-declaration src0)]
             [placeholder (make-EmptyStatement (Term-location src0))])
         (k (list lifted) null placeholder))]))

  ;; hoist-substatement : SourceElement ->k (listof Statement)
  ;; A function declaration is lifted out and leaves no statement behind.
  ;; Anything else is hoisted by hoist-statement.
  (define (hoist-substatement src0 k)
    (cond
      [(not (FunctionDeclaration? src0)) (hoist-statement src0 k)]
      [else (k (list (hoist-function-declaration src0)) null null)]))

  ;; hoist-variable-declaration : VariableDeclaration ->k (optional Expression)
  ;; The declared identifier is always reported as a variable. Without an
  ;; initializer the result is #f. With one, the result is the assignment
  ;; `id = init`, using the hoisted initializer.
  (define (hoist-variable-declaration decl k)
    (match decl
      [($ VariableDeclaration loc id init0)
       (if (not init0)
           (k null (list id) #f)
           (hoist-expression init0
             (lambda (funs vars init)
               (let ([target (make-VarReference (Term-location id) id)])
                 (k funs
                    (cons id vars)
                    (make-AssignmentExpression loc target '= init))))))]))

  ;; hoist-case-clause : CaseClause ->k CaseClause
  ;; Hoists the case test (absent, i.e. #f, for `default:`) and then the clause's
  ;; statements. Functions and variables are reported test-first.
  (define (hoist-case-clause case k)
    (match case
      [($ CaseClause loc question0 answer0)
       (let ([with-question
              (lambda (funs1 vars1 question)
                (hoist-substatements answer0
                  (lambda (funs2 vars2 answer)
                    (k (append funs1 funs2)
                       (append vars1 vars2)
                       (make-CaseClause loc question answer)))))])
         (if question0
             (hoist-expression question0 with-question)
             (with-question null null #f)))]))

  ;; hoist-catch-clause : CatchClause ->k CatchClause
  ;; Hoists the catch body and collapses its replacement statements back into
  ;; a single statement. The caught identifier is kept as-is.
  (define (hoist-catch-clause catch k)
    (match catch
      [($ CatchClause where id body0)
       (let ([rebuild (lambda (funs vars stmts)
                        (let ([body (statements->statement stmts body0)])
                          (k funs vars (make-CatchClause where id body))))])
         (hoist-statement body0 rebuild))]))

  ;; to-location : (optional (union position has-location)) -> (optional location)
  ;; #f and positions are returned unchanged; anything else yields its
  ;; ast-location.
  (define (to-location x)
    (if (or (not x) (position? x))
        x
        (ast-location x)))

  ;; statements->statement : (listof Statement) (optional has-location) -> Statement
  ;; Collapses a list of statements into one. A singleton is returned as-is.
  ;; An empty list becomes an EmptyStatement, and a longer list becomes a
  ;; BlockStatement; both take their location from loc.
  (define (statements->statement ls loc)
    (if (and (pair? ls) (null? (cdr ls)))
        (car ls)
        (let ([where (to-location loc)])
          (if (null? ls)
              (make-EmptyStatement where)
              (make-BlockStatement where ls)))))

  ;; hoist-statement : Statement ->k (listof Statement)
  ;; Hoists one statement. The continuation receives three things: the function
  ;; declarations collected from the statement, the variable identifiers (each
  ;; appended in source order), and the list of statements that replace it.
  ;; Nested sub-bodies are collapsed back to a single statement with
  ;; statements->statement, which takes its location from the original
  ;; sub-body. Statement kinds with no clause here are passed through unchanged
  ;; and contribute no functions or variables.
  (define (hoist-statement stmt k)
    (match stmt
      ;; hoist-substatements lifts function declarations that sit directly in
      ;; the block; the block keeps the remaining statements.
      [($ BlockStatement loc stmts)
       (hoist-substatements stmts
         (lambda (funs vars stmts)
           (k funs vars (list (make-BlockStatement loc stmts)))))]
      ;; `var` declarations: each declaration with an initializer becomes an
      ;; ExpressionStatement wrapping the assignment built by
      ;; hoist-variable-declaration. Declarations without one produce no
      ;; statement, but every declared identifier is reported in vars.
      [($ VariableStatement loc decls)
       (hoist-variable-declarations decls
         (lambda (funs vars exprs)
           (k funs vars (map (lambda (expr)
                               (make-ExpressionStatement (Term-location expr) expr))
                             exprs))))]
      [($ ExpressionStatement loc expr)
       (hoist-expression expr
         (lambda (funs vars expr)
           (k funs vars (list (make-ExpressionStatement loc expr)))))]
      ;; The alternate branch is optional (#f when absent).
      [($ IfStatement loc test consequent0 alternate0)
       (hoist-expression test
         (lambda (funs1 vars1 test)
           (hoist-substatement consequent0
             (lambda (funs2 vars2 consequent)
               (if alternate0
                   (hoist-substatement alternate0
                     (lambda (funs3 vars3 alternate)
                       (k (append funs1 funs2 funs3)
                          (append vars1 vars2 vars3)
                          (list (make-IfStatement loc
                                                  test
                                                  (statements->statement consequent consequent0)
                                                  (statements->statement alternate alternate0))))))
                   (k (append funs1 funs2)
                      (append vars1 vars2)
                      (list (make-IfStatement loc test (statements->statement consequent consequent0) #f))))))))]
      [($ DoWhileStatement loc body0 test)
       (hoist-substatement body0
         (lambda (funs1 vars1 body)
           (hoist-expression test
             (lambda (funs2 vars2 test)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (list (make-DoWhileStatement loc (statements->statement body body0) test)))))))]
      [($ WhileStatement loc test body0)
       (hoist-expression test
         (lambda (funs1 vars1 test)
           (hoist-substatement body0
             (lambda (funs2 vars2 body)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (list (make-WhileStatement loc test (statements->statement body body0))))))))]
      ;; The init slot holds either expressions or variable declarations; its
      ;; first element decides which hoister is used, and an empty init is
      ;; treated as expressions. After hoisting, declarations become assignment
      ;; expressions.
      [($ ForStatement loc init test incr body0)
       (let ([hoist (if (or (null? init) (Expression? (car init)))
                        hoist-expressions
                        hoist-variable-declarations)])
         (hoist init
           (lambda (funs1 vars1 init)
             (hoist-optional-expression test
               (lambda (funs2 vars2 test)
                 (hoist-expressions incr
                   (lambda (funs3 vars3 incr)
                     (hoist-substatement body0
                       (lambda (funs4 vars4 body)
                         (k (append funs1 funs2 funs3 funs4)
                            (append vars1 vars2 vars3 vars4)
                            (list (make-ForStatement loc init test incr (statements->statement body body0)))))))))))))]
      ;; for (lhs in container) where the target is an arbitrary expression.
      [($ ForInStatement loc (and lhs (? Expression?)) container body0)
       (hoist-expression lhs
         (lambda (funs1 vars1 lhs)
           (hoist-expression container
             (lambda (funs2 vars2 container)
               (hoist-substatement body0
                 (lambda (funs3 vars3 body)
                   (k (append funs1 funs2 funs3)
                      (append vars1 vars2 vars3)
                      (list (make-ForInStatement loc lhs container (statements->statement body body0))))))))))]
      ;; for (var id in container): id is reported as a variable of the
      ;; enclosing scope, and the target is rewritten to a plain VarReference.
      ;; NOTE(review): a `var` target that has an initializer matches neither
      ;; ForInStatement clause, so it falls through to the default case without
      ;; being hoisted. Confirm the parser never produces that shape.
      [($ ForInStatement loc ($ VariableDeclaration v-loc id #f) container body0)
       (hoist-expression container
         (lambda (funs1 vars1 container)
           (hoist-substatement body0
             (lambda (funs2 vars2 body)
               (k (append funs1 funs2)
                  (cons id (append vars1 vars2))
                  (list (make-ForInStatement loc
                                             (make-VarReference v-loc id)
                                             container
                                             (statements->statement body body0))))))))]
      ;; Only a return carrying an Expression is matched here; a return
      ;; without one falls through to the default case.
      [($ ReturnStatement loc (and expr (? Expression?)))
       (hoist-expression expr
         (lambda (funs vars expr)
           (k funs vars (list (make-ReturnStatement loc expr)))))]
      [($ WithStatement loc test body0)
       (hoist-expression test
         (lambda (funs1 vars1 test)
           (hoist-substatement body0
             (lambda (funs2 vars2 body)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (list (make-WithStatement loc test (statements->statement body body0))))))))]
      [($ SwitchStatement loc expr cases)
       (hoist-expression expr
         (lambda (funs1 vars1 expr)
           (hoist-case-clauses cases
             (lambda (funs2 vars2 cases)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (list (make-SwitchStatement loc expr cases)))))))]
      [($ LabelledStatement loc label stmt0)
       (hoist-substatement stmt0
         (lambda (funs vars stmt)
           (k funs vars (list (make-LabelledStatement loc label (statements->statement stmt stmt0))))))]
      [($ ThrowStatement loc expr)
       (hoist-expression expr
         (lambda (funs vars expr)
           (k funs vars (list (make-ThrowStatement loc expr)))))]
      ;; The try body and the optional finally block are hoisted with
      ;; hoist-statement (not hoist-substatement); the finally block is #f when
      ;; absent.
      [($ TryStatement loc body0 catch0 finally0)
       (hoist-statement body0
         (lambda (funs1 vars1 body)
           (hoist-catch-clauses catch0
             (lambda (funs2 vars2 catch)
               (if finally0
                   (hoist-statement finally0
                     (lambda (funs3 vars3 finally)
                       (k (append funs1 funs2 funs3)
                          (append vars1 vars2 vars3)
                          (list (make-TryStatement loc
                                                   (statements->statement body body0)
                                                   catch
                                                   (statements->statement finally finally0))))))
                   (k (append funs1 funs2)
                      (append vars1 vars2)
                      (list (make-TryStatement loc (statements->statement body body0) catch #f))))))))]
      ;; Anything else is kept unchanged.
      [_ (k null null (list stmt))]))

  ;; hoist-expression : Expression ->k Expression
  ;; Rebuilds an expression so that every nested function expression and let
  ;; expression is replaced by its hoisted form. Expressions never contribute
  ;; functions or variables to the enclosing scope: the fallback case passes
  ;; null/null, and nested function/let expressions keep what they collect
  ;; inside their own hoisted node. The compound cases therefore only
  ;; concatenate what their children report. Expression kinds with no clause
  ;; here are returned unchanged.
  (define (hoist-expression expr k)
    (match expr
      ;; Array holes are represented as #f and are preserved.
      [($ ArrayLiteral loc elts)
       (hoist-optional-expressions elts
         (lambda (funs vars elts)
           (k funs vars (make-ArrayLiteral loc elts))))]
      ;; Only the property values are hoisted; the property names are kept.
      [($ ObjectLiteral loc ([props . vals] ...))
       (hoist-expressions vals
         (lambda (funs vars vals)
           (k funs vars (make-ObjectLiteral loc (map cons props vals)))))]
      [($ BracketReference loc container key)
       (hoist-expression container
         (lambda (funs1 vars1 container)
           (hoist-expression key
             (lambda (funs2 vars2 key)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (make-BracketReference loc container key))))))]
      [($ DotReference loc container id)
       (hoist-expression container
         (lambda (funs vars container)
           (k funs vars (make-DotReference loc container id))))]
      [($ NewExpression loc constructor args)
       (hoist-expression constructor
         (lambda (funs1 vars1 constructor)
           (hoist-expressions args
             (lambda (funs2 vars2 args)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (make-NewExpression loc constructor args))))))]
      [($ PostfixExpression loc expr op)
       (hoist-expression expr
         (lambda (funs vars expr)
           (k funs vars (make-PostfixExpression loc expr op))))]
      [($ PrefixExpression loc op expr)
       (hoist-expression expr
         (lambda (funs vars expr)
           (k funs vars (make-PrefixExpression loc op expr))))]
      [($ InfixExpression loc left op right)
       (hoist-expression left
         (lambda (funs1 vars1 left)
           (hoist-expression right
             (lambda (funs2 vars2 right)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (make-InfixExpression loc left op right))))))]
      [($ ConditionalExpression loc test consequent alternate)
       (hoist-expression test
         (lambda (funs1 vars1 test)
           (hoist-expression consequent
             (lambda (funs2 vars2 consequent)
               (hoist-expression alternate
                 (lambda (funs3 vars3 alternate)
                   (k (append funs1 funs2 funs3)
                      (append vars1 vars2 vars3)
                      (make-ConditionalExpression loc test consequent alternate))))))))]
      [($ AssignmentExpression loc left op right)
       (hoist-expression left
         (lambda (funs1 vars1 left)
           (hoist-expression right
             (lambda (funs2 vars2 right)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (make-AssignmentExpression loc left op right))))))]
      ;; A function expression is a new scope. Its functions and variables are
      ;; stored on the hoisted node instead of being passed to k.
      [($ FunctionExpression loc name args body)
       (k null null (hoist-source-elements body
                      (lambda (funs vars body)
                        (make-FunctionExpression/hoisted loc name args body funs (unique-vars funs vars)))))]
      ;; The binding initializers are hoisted too, so function expressions
      ;; inside them are converted like any others. Whatever the initializers
      ;; report goes to the enclosing scope. The let body's own functions and
      ;; variables stay on the LetExpression/hoisted node.
      [($ LetExpression loc bindings0 body)
       (map-k hoist-let-binding bindings0
         (lambda (funs1 vars1 bindings)
           (k funs1 vars1 (hoist-source-element body
                            (lambda (funs vars body)
                              (make-LetExpression/hoisted loc bindings body funs (unique-vars funs vars)))))))]
      [($ CallExpression loc method args)
       (hoist-expression method
         (lambda (funs1 vars1 method)
           (hoist-expressions args
             (lambda (funs2 vars2 args)
               (k (append funs1 funs2)
                  (append vars1 vars2)
                  (make-CallExpression loc method args))))))]
      [($ ParenExpression loc expr)
       (hoist-expression expr
         (lambda (funs vars expr)
           (k funs vars (make-ParenExpression loc expr))))]
      [_ (k null null expr)]))

  ;; hoist-let-binding : VariableDeclaration ->k VariableDeclaration
  ;; Hoists the initializer of one let binding, which may be absent (#f), and
  ;; rebuilds the binding as a VariableDeclaration with the same location and
  ;; identifier.
  (define (hoist-let-binding binding k)
    (match binding
      [($ VariableDeclaration loc id init0)
       (hoist-optional-expression init0
         (lambda (funs vars init)
           (k funs vars (make-VariableDeclaration loc id init))))]))

  ;; Export the continuation contract constructor so clients can build their
  ;; own continuation contracts.
  (provide/contract
    [continuation/c (-> (union flat-contract? predicate/c) contract?)])

  ;; CPS exports. Each [f (dom ->k rng)] clause exports f as a function of a
  ;; dom value and a continuation. The continuation is called with the hoisted
  ;; functions, the hoisted variables, and a value satisfying rng.
  (provide/contract/k
    [hoist-source-elements ((listof SourceElement?) . ->k . (listof Statement?))]
    [hoist-expressions ((listof Expression?) . ->k . (listof Expression?))]
    [hoist-optional-expressions ((listof (optional/c Expression?)) . ->k . (listof (optional/c Expression?)))]
    [hoist-source-element (SourceElement? . ->k . Statement?)]
    ;; hoist-statement hands its continuation a list of replacement statements
    ;; (for example, a VariableStatement expands to several), not a single
    ;; Statement.
    [hoist-statement (Statement? . ->k . (listof Statement?))]
    [hoist-expression (Expression? . ->k . Expression?)]
    [hoist-substatement (SubStatement? . ->k . (listof Statement?))]
    [hoist-substatements ((listof SubStatement?) . ->k . (listof Statement?))]
    ;; A declaration without an initializer yields #f instead of an expression.
    [hoist-variable-declaration (VariableDeclaration? . ->k . (optional/c Expression?))]
    [hoist-variable-declarations ((listof VariableDeclaration?) . ->k . (listof Expression?))]
    [hoist-case-clause (CaseClause? . ->k . CaseClause?)]
    [hoist-case-clauses ((listof CaseClause?) . ->k . (listof CaseClause?))]
    [hoist-catch-clause (CatchClause? . ->k . CatchClause?)]
    [hoist-catch-clauses ((listof CatchClause?) . ->k . (listof CatchClause?))])

  ;; Direct-style entry points. hoist-script returns three values: the
  ;; hoisted top-level functions, the top-level variables, and the residual
  ;; statements.
  (provide/contract
    [hoist-function-declaration
     (-> FunctionDeclaration? FunctionDeclaration/hoisted?)]
    [hoist-script
     (-> (listof SourceElement?)
         (values (listof FunctionDeclaration/hoisted?)
                 (listof Identifier?)
                 (listof Statement?)))])

  ;; Contracted exports of the hoisted node structs. In a provide/contract
  ;; struct clause, the field names must match the fields declared by
  ;; define-struct, because the exported selectors (e.g.
  ;; FunctionDeclaration/hoisted-variables) are derived from them. The last
  ;; field is therefore `variables`, not `vars`.
  (provide/contract
    (struct (FunctionDeclaration/hoisted FunctionDeclaration) ([location (optional/c region?)]
                                                               [name Identifier?]
                                                               [args (listof Identifier?)]
                                                               [body (listof Statement?)]
                                                               [functions (listof FunctionDeclaration/hoisted?)]
                                                               [variables (listof Identifier?)]))
    (struct (FunctionExpression/hoisted FunctionExpression) ([location (optional/c region?)]
                                                             [name (optional/c Identifier?)]
                                                             [args (listof Identifier?)]
                                                             [body (listof Statement?)]
                                                             [functions (listof FunctionDeclaration/hoisted?)]
                                                             [variables (listof Identifier?)]))
    (struct (LetExpression/hoisted LetExpression) ([location (optional/c region?)]
                                                   [bindings (listof VariableDeclaration?)]
                                                   [body Statement?]
                                                   [functions (listof FunctionDeclaration/hoisted?)]
                                                   [variables (listof Identifier?)]))))