src/compiler/transform/anormalize.ss
#lang s-exp "../lang.ss"

(require "elim-anon.ss")
(require "../toplevel.ss")
(require "../env.ss")

;; prefix used when naming the temporaries introduced by a-normalization
(define temp-begin "temp")
;; primitives that take functions as arguments; these are excluded from
;;    the first-order primitive environment below
(define higher-order-prims '(andmap argmax argmin build-list build-string compose
                             filter foldl foldr map memf ormap quicksort sort))
;; first-order-prims: env
;; the toplevel environment extended with quote (as "primitive"),
;;    with every name in higher-order-prims removed
(define first-order-prims
  (local [(define with-quote (env-extend-constant toplevel-env 'quote "primitive"))
          (define (drop-prim name an-env) (env-remove an-env name))]
    (foldl drop-prim with-quote higher-order-prims)))

;; linfo: the result of a-normalizing an expression (or a list of expressions)
;;    return: the rewritten expression (a list of expressions for the list-based helpers)
;;    raise:  (define ...) forms the return depends on; whoever receives this linfo
;;            must place them in an enclosing local (or raise them further)
;;    gensym: the gensym counter to use for the next temporary
(define-struct linfo (return raise gensym))

;; get-struct-defs: (listof s-expr) -> (listof s-expr)
;; takes a list of toplevel statements (a program)
;; returns all struct definitions appearing at toplevel
;; get-struct-defs: (listof s-expr) -> (listof s-expr)
;; keeps only the toplevel statements of a program that are define-struct forms,
;;    in their original order
(define (get-struct-defs program)
  (local [(define (struct-def? stmt)
            (and (cons? stmt)
                 (equal? 'define-struct (first stmt))))]
    (filter struct-def? program)))

;; get-struct-procs: s-expr -> (listof symbol)
;; consumes a struct definition in abstract syntax
;; returns a list of procs generated by defining that struct
;; get-struct-procs: s-expr -> (listof symbol)
;; given (define-struct name (field ...)), returns the names this file treats
;;    as generated by it: make-name, name?, then name-field for each field in order
(define (get-struct-procs struct-def)
  (local [(define name (symbol->string (second struct-def)))
          (define (accessor field)
            (string->symbol (string-append name "-" (symbol->string field))))]
    (cons (string->symbol (string-append "make-" name))
          (cons (string->symbol (string-append name "?"))
                (map accessor (third struct-def))))))

;; generate-prims: (listof s-expr) -> env
;; consumes a list of toplevel statements (a program)
;; returns an environment containing all first-order primitives for that program
;;    these are the predefined first-order primitives and struct primitives
;; generate-prims: (listof s-expr) -> env
;; builds the first-order primitive environment for a program:
;;    first-order-prims extended (as "struct-prim") with every procedure
;;    generated by the program's toplevel define-struct forms
(define (generate-prims program)
  (local [(define all-struct-procs
            (foldl (lambda (procs so-far) (append so-far procs))
                   empty
                   (map get-struct-procs (get-struct-defs program))))]
    (foldl (lambda (proc-name an-env)
             (env-extend-constant an-env proc-name "struct-prim"))
           first-order-prims
           all-struct-procs)))

;; primitive?: any/c env -> boolean
;; returns true if the input is defined in the environment or an atomic datum
;;    false otherwise (cons or a symbol not representing a first-order primitive)
;; primitive?: any/c env -> boolean
;; a cons is never primitive; a symbol is primitive exactly when prims contains it;
;;    any other datum (number, string, ...) is atomic and counts as primitive
(define (primitive? dat prims)
  (cond
    [(cons? dat) #f]
    [(symbol? dat) (env-contains? prims dat)]
    [else #t]))

;; get-temp-symbol: number -> symbol
;; takes a gensym counter and returns a symbol for temporary binding using that gensym
;; get-temp-symbol: number -> symbol
;; the name of the temporary for a gensym counter: temp-begin followed by the counter
(define (get-temp-symbol gensym)
  (local [(define suffix (number->string gensym))]
    (string->symbol (string-append temp-begin suffix))))

;; fold-anormal-help: (listof s-expr) number -> linfo
;; folds anormal-help across a list of symbolic expressions
;; fold-anormal-help: (listof s-expr) env number -> linfo
;; runs anormal-help on each element of a list of symbolic expressions, left to right,
;;    threading the gensym counter from one element to the next
;; returns linfo where
;;    return is the list of normalized elements, in input order,
;;    raise is every element's raised definitions, concatenated in input order,
;;    gensym is the counter after the last element
(define (fold-anormal-help expr prims gensym)
  (foldl (lambda (element acc)
           (local [(define element-info
                     (anormal-help element prims (linfo-gensym acc)))]
             (make-linfo (append (linfo-return acc)
                                 (list (linfo-return element-info)))
                         (append (linfo-raise acc)
                                 (linfo-raise element-info))
                         (linfo-gensym element-info))))
         (make-linfo empty empty gensym)
         expr))

;; anormal-help: s-expr env number -> linfo
;; consumes a symbolic expression that is the output of ready-anormalize and a gensym counter
;; returns a symantically equivalent program in a-normal form
;; anormal-help: s-expr env number -> linfo
;; consumes a symbolic expression that is the output of ready-anormalize,
;;    the environment of first-order primitives, and a gensym counter
;; returns a semantically equivalent expression in a-normal form as linfo where
;;    return is the rewritten expression,
;;    raise is a list of (define tempN ...) forms that must be placed in an
;;       enclosing local around the return expression,
;;    gensym is the updated gensym counter
(define (anormal-help expr prims gensym)
  (cond
    [(cons? expr)
     (cond
       ;; (define name body): temporaries needed by body are kept inside the
       ;; definition by wrapping body in a local
       [(equal? (first expr) 'define)
        (local [(define arg-info (anormal-help (third expr) prims gensym))]
          (make-linfo (list 'define
                            (second expr)
                            (if (empty? (linfo-raise arg-info))
                                (linfo-return arg-info)
                                (list 'local
                                      (linfo-raise arg-info)
                                      (linfo-return arg-info))))
                      empty
                      (linfo-gensym arg-info)))]
       ;; (local [defs ...] body): temporaries raised by body are appended
       ;; after the local's own (normalized) definitions
       [(equal? (first expr) 'local)
        (local [(define definitions (make-anormal (second expr) prims gensym))
                (define arg-info (anormal-help (third expr)
                                               prims
                                               (linfo-gensym definitions)))]
          (make-linfo (list 'local
                            (append (linfo-return definitions)
                                    (linfo-raise arg-info))
                            (linfo-return arg-info))
                      empty
                      (linfo-gensym arg-info)))]
       ;; (cond [test answer] ...): each test and each answer is normalized as a
       ;; single expression via make-anormal, so the temporaries it needs are
       ;; defined inside it rather than raised out of the cond.
       ;; The test must be wrapped in a list: make-anormal folds over a list of
       ;; expressions, and passing the bare test would either fail (symbol or
       ;; constant tests) or normalize the test's sub-expressions piecewise.
       [(equal? (first expr) 'cond)
        (local [(define anormal-cases
                  (foldl (lambda (clause rest-cases)
                           (local [(define condition
                                     (if (equal? (first clause) 'else)
                                         (make-linfo (list 'else)
                                                     empty
                                                     (linfo-gensym rest-cases))
                                         (make-anormal (list (first clause))
                                                       prims
                                                       (linfo-gensym rest-cases))))
                                   (define body
                                     (make-anormal (list (second clause))
                                                   prims
                                                   (linfo-gensym condition)))]
                             (make-linfo (cons (list (first (linfo-return condition))
                                                     (first (linfo-return body)))
                                               (linfo-return rest-cases))
                                         empty
                                         (linfo-gensym body))))
                         (make-linfo empty empty gensym)
                         (rest expr)))]
          (make-linfo (cons 'cond (reverse (linfo-return anormal-cases)))
                      empty
                      (linfo-gensym anormal-cases)))]
       ;; (if test then else): the test's temporaries are raised to the caller;
       ;; each branch is normalized as a self-contained expression
       [(equal? (first expr) 'if)
        (local [(define condition (anormal-help (second expr) prims gensym))
                (define then-clause (make-anormal (list (third expr))
                                                  prims
                                                  (linfo-gensym condition)))
                (define else-clause (make-anormal (list (fourth expr))
                                                  prims
                                                  (linfo-gensym then-clause)))]
          (make-linfo (list 'if
                            (linfo-return condition)
                            (first (linfo-return then-clause))
                            (first (linfo-return else-clause)))
                      (linfo-raise condition)
                      (linfo-gensym else-clause)))]
       ;; (and ...) / (or ...): every operand is normalized as a self-contained
       ;; expression, so nothing is raised out of the operator
       [(or (equal? (first expr) 'and)
            (equal? (first expr) 'or))
        (local [(define options (make-anormal (rest expr) prims gensym))]
          (make-linfo (cons (first expr) (linfo-return options))
                      empty
                      (linfo-gensym options)))]
       [(equal? (first expr) 'quote) (make-linfo expr empty gensym)]
       [(equal? (first expr) 'define-struct) (make-linfo expr empty gensym)]
       ;; application: normalize every element, then bind each element that is a
       ;; compound expression not headed by a first-order primitive to a fresh
       ;; temporary. Raised definitions are ordered as: those from nested elements
       ;; first, then the new temporaries left to right.
       [else
        (local [(define arg-info (fold-anormal-help expr prims gensym))
                (define anormal-expr
                  (foldl (lambda (an-expr rest-args)
                           (if (and (cons? an-expr)
                                    (not (primitive? (first an-expr) prims)))
                               (make-linfo (cons (get-temp-symbol (linfo-gensym rest-args))
                                                 (linfo-return rest-args))
                                           (cons (list 'define
                                                       (get-temp-symbol
                                                        (linfo-gensym rest-args))
                                                       an-expr)
                                                 (linfo-raise rest-args))
                                           (add1 (linfo-gensym rest-args)))
                               (make-linfo (cons an-expr (linfo-return rest-args))
                                           (linfo-raise rest-args)
                                           (linfo-gensym rest-args))))
                         (make-linfo empty empty (linfo-gensym arg-info))
                         (linfo-return arg-info)))]
          ;; the fold above accumulated in reverse; restore left-to-right order
          (make-linfo (reverse (linfo-return anormal-expr))
                      (append (linfo-raise arg-info)
                              (reverse (linfo-raise anormal-expr)))
                      (linfo-gensym anormal-expr)))])]
    [else (make-linfo expr empty gensym)]))

;; make-anormal: (listof s-expr) env number -> linfo
;; consumes a list of symbolic expressions, an environment of primitives, and a gesym counter
;; returns linfo with the return being the completely anormalized expression,
;;    the raise being empty, and the gensym being the current gensym counter
;; make-anormal: (listof s-expr) env number -> linfo
;; a-normalizes each expression of a list independently, left to right, threading
;;    the gensym counter. When an expression needs temporaries, it is wrapped in its
;;    own (local [temporaries] expression) so nothing escapes to the caller.
;; returns linfo whose return is the list of normalized expressions in input order,
;;    whose raise is always empty, and whose gensym is the updated gensym counter
(define (make-anormal expr prims gensym)
  (local [;; close-over-raise: linfo -> s-expr
          ;; makes one normalized result self-contained
          (define (close-over-raise info)
            (if (empty? (linfo-raise info))
                (linfo-return info)
                (list 'local (linfo-raise info) (linfo-return info))))
          ;; step: s-expr linfo -> linfo
          ;; normalizes one expression and pushes it onto a reversed accumulator
          (define (step an-expr acc)
            (local [(define normalized (anormal-help an-expr prims (linfo-gensym acc)))]
              (make-linfo (cons (close-over-raise normalized) (linfo-return acc))
                          empty
                          (linfo-gensym normalized))))
          (define backwards (foldl step (make-linfo empty empty gensym) expr))]
    (make-linfo (reverse (linfo-return backwards))
                empty
                (linfo-gensym backwards))))

;; anormalize: (listof s-expr) -> (listof s-expr)
;; takes a program in abstract syntax and rewrites it in anormal form
;; anormalize: (listof s-expr) -> (listof s-expr)
;; rewrites a program in abstract syntax into a-normal form: the program is first
;;    passed through ready-anormalize, then each toplevel statement is normalized
;;    with the gensym counter starting at 0
(define (anormalize program)
  (local [(define readied (ready-anormalize program))
          (define prims (generate-prims readied))
          (define result (make-anormal readied prims 0))]
    (linfo-return result)))

(provide anormalize)