(module ho-contracts mzscheme
  (require "../reduction-semantics.ss"
           (lib "mred.ss" "mred")
           (lib "framework.ss" "framework")
           (lib "class.ss")
           (lib "match.ss")
           (lib "list.ss"))
  ;; Display settings for the reduction-semantics tools: small font,
  ;; and stop after 10 reduction steps.
  (initial-font-size 7)
  (reduction-steps-cutoff 10)
  ;; Alternative (larger font, single step):
  ;(initial-font-size 36) (reduction-steps-cutoff 1)
  ;; lang : grammar of the object language.
  ;;   p       programs: a list of definitions d plus a main expression e
  ;;   d       (valrec x : contract-expr = rhs-expr)
  ;;   e       expressions, including contract forms (-->, contract,
  ;;           flatp, pred, dom, rng, blame)
  ;;   x       variables (any symbol except the listed keywords)
  ;;   p-ctxt  program evaluation contexts, e-ctxt expression contexts
  ;;   v       values
  ;;   aop/rop arithmetic and relational operators
  ;; NOTE(review): this form looks damaged in this copy.
  ;;   - No grammar-building head (e.g. `language') follows `define lang'.
  ;;   - The p-ctxt alternatives are not closed and nest inside each
  ;;     other (each should end with its main-expression part).
  ;;   - e-ctxt has no hole alternative.
  ;;   - v is never closed and has no number/empty/true/false alternatives,
  ;;     though the reductions below treat numbers as values.
  ;; Restore it from the original source before relying on it.
  (define lang
     (p ((d ...) e))
     (d (valrec x : e = e))
     (e (lambda (x) e)
        (e e)
        (let ((x e) ...) e)
        (fix x e)
        (aop e e)
        (rop e e)
        (cons e e)
        (hd e)
        (tl e)
        (mt e)
        (if e e e)
        (--> e e)
        (contract e)
        (flatp e)
        (pred e)
        (dom e)
        (rng e)
        (blame e))
     (x (variable-except valrec lambda let fix aop rop cons empty hd tl mt if true false --> contract flatp pred dom rng blame))
     ;; Evaluate definitions left to right: first a contract, then a
     ;; right-hand side, then (once all are values) the main expression.
     (p-ctxt (((valrec x : v = v) ...
               (valrec x : e-ctxt = e)
               d ...)
             (((valrec x : v = v) ...
               (valrec x : v = e-ctxt)
               d ...)
             (((valrec x : v = v) ...)
     ;; Left-to-right, call-by-value expression contexts.
     (e-ctxt (e-ctxt e)
             (v e-ctxt)
             (let ((x v) ... (x e-ctxt) (x e) ...) e)
             (aop e-ctxt e)
             (aop v e-ctxt)
             (rop e-ctxt e)
             (rop v e-ctxt)
             (cons e-ctxt e)
             (cons v e-ctxt)
             (hd e-ctxt)
             (tl e-ctxt)
             (mt e-ctxt)
             (if e-ctxt e e)
             (--> v e-ctxt)
             (--> e-ctxt e)
             (contract e-ctxt)
             (flatp e-ctxt)
             (pred e-ctxt)
             (dom e-ctxt)
             (rng e-ctxt)
             (blame e-ctxt)
     (v (cons v v)
        (lambda (x) e)
        (--> v v)
        (contract v)
        (ob v (--> v v) x x)
     (aop + - * /)
     (rop = >=)))
  ;; ho-contracts-subst : substitution function for the object language.
  ;; The reductions below call it as (ho-contracts-subst var value body).
  ;; Each clause pairs a `match' pattern with:
  ;;   - the variables the form binds (all-vars),
  ;;   - how to rebuild the form from its pieces (build),
  ;;   - which bound variables scope over each subterm (subterm / subterms).
  ;; Atoms are declared (constant) or (variable).
  ;; NOTE(review): no module in this file's visible require list provides
  ;; `subst'. It is presumably the form from the reduction-semantics
  ;; collection's subst.ss; confirm it is in scope.
  (define ho-contracts-subst
    (subst
     [`(let ([,a-vars ,rhs-exps] ...) ,body)
      (all-vars a-vars)
      (build (lambda (vars body . rhss)
               `(let (,@(map (lambda (var rhs) `[,var ,rhs]) vars rhss))
                  ,body)))
      ;; The let-bound variables scope over the body only, not over the
      ;; right-hand sides.
      (subterm a-vars body)
      (subterms '() rhs-exps)]
     [`(lambda (,var) ,body)
      (all-vars (list var))
      (build (lambda (vars body) `(lambda (,(car vars)) ,body)))
      (subterm (list var) body)]
     [`(fix ,variable ,e)
      (all-vars (list variable))
      (build (lambda (vars body) `(fix ,(car vars) ,body)))
      (subterm (list variable) e)]
     [(? number?) (constant)]
     ;; Binary forms: cons, every aop (+ - * /), every rop (= >=) and the
     ;; arrow-contract constructor -->. `*' and `>=' were missing, so
     ;; substituting into them matched no clause. `>' is kept for
     ;; compatibility even though lang's rop does not list it.
     [`(,(and op (? (lambda (x) (memq x '(cons + - * / = >= > -->))))) ,e1 ,e2)
      (all-vars '())
      (build (lambda (vars e1 e2) `(,op ,e1 ,e2)))
      (subterm '() e1)
      (subterm '() e2)]
     [`empty (constant)]
     ;; Unary forms. The flat-contract test is spelled `flatp' in lang
     ;; (it was listed here as `flat').
     [`(,(and op (? (lambda (x) (memq x '(hd tl mt contract flatp pred dom rng blame))))) ,e1)
      (all-vars '())
      (build (lambda (vars e1) `(,op ,e1)))
      (subterm '() e1)]
     [`(if ,e1 ,e2 ,e3)
      (all-vars '())
      (build (lambda (vars e1 e2 e3) `(if ,e1 ,e2 ,e3)))
      (subterm '() e1)
      (subterm '() e2)
      (subterm '() e3)]
     ;; Application.
     [`(,e1 ,e2)
      (all-vars '())
      (build (lambda (vars e1 e2) `(,e1 ,e2)))
      (subterm '() e1)
      (subterm '() e2)]
     [`true (constant)]
     [`false (constant)]
     [(? string?) (constant)]
     [(? symbol?) (variable)]))
  ;; reductions : the one-step reduction rules for lang.
  ;; Most rules find a redex inside a program context (name p p-ctxt) and
  ;; plug the result back into that context. The error rules instead
  ;; replace the whole program with an (error ...) term.
  ;; NOTE(review): the paren count shows a form is missing between
  ;; `(define reductions' and the first rule (presumably a `list' head
  ;; collecting the rules). Restore it from the original source.
  (define reductions
     ;; Division; dividing by zero aborts with (error /).
     (reduction lang  
                (in-hole (name p p-ctxt) (/ number_n number_m)) 
                (if (= (term number_m) 0)
                    (term (error /))
                    (plug (term p) (/ (term number_n) (term number_m)))))
     ;; The other arithmetic operators are computed with Scheme's own.
     (reduction lang  
                (in-hole (name p p-ctxt) (* number_n number_m))
                (plug (term p) (* (term number_n) (term number_m))))
     (reduction lang  
                (in-hole (name p p-ctxt) (+ number_n number_m))
                (plug (term p) (+ (term number_n) (term number_m))))
     (reduction lang  
                (in-hole (name p p-ctxt) (- number_n number_m))
                (plug (term p) (- (term number_n) (term number_m))))
     ;; Relational operators produce the object-language booleans.
     (reduction lang  
                (in-hole (name p p-ctxt) (>= number_n number_m))
                (plug (term p) (if (>= (term number_n) (term number_m)) 'true 'false)))
     (reduction lang  
                (in-hole (name p p-ctxt) (= number_n number_m))
                (plug (term p) (if (= (term number_n) (term number_m)) 'true 'false)))
     ;; Call-by-value beta reduction.
     (reduction lang  
                (in-hole (name p p-ctxt) ((lambda (variable_x) e_body) v_arg))
                (plug (term p) (ho-contracts-subst (term variable_x) 
                                                   (term v_arg) 
                                                   (term e_body))))
     ;; let: substitute each v_i for variable_i in e_body.
     ;; NOTE(review): plug is given four arguments here and the closing
     ;; parens do not balance. The call to a multi-variable substitution
     ;; function around (term e_body) ... (term (v_i ...)) appears to be
     ;; missing. Restore it before use.
     (reduction lang  
                (in-hole (name p p-ctxt) 
                         (let ((variable_i v_i) ...) e_body))
                (plug (term p) 
                       (term e_body)
                       (term (variable_i ...))
                       (term (v_i ...)))))
     ;; fix: unroll once by substituting the whole fix term for x.
     (reduction lang  
                (in-hole (name p p-ctxt) (name tot (fix (name x variable) (name body e))))
                (plug (term p) (ho-contracts-subst (term x) (term tot) (term body))))
     ;; Top-level variable reference.
     ;; When every definition's contract and right-hand side are
     ;; evaluated, a reference to VAR in the main expression is replaced
     ;; by VAR's right-hand side. `var' is named twice, so (presumably)
     ;; only a reference to that defined name matches.
     ;; NOTE(review): `value' is not a nonterminal of lang as defined
     ;; above (lang calls its values `v'). If unknown symbols are literal
     ;; patterns, this rule can never fire; confirm.
     (reduction lang  
                ((name defns
                       ((valrec (name bvar variable) : (name bctc value) = (name brhs value)) ...
                        (valrec (name var variable) : value = (name rhs value))
                        (valrec variable : value = value) ...))
                 (in-hole (name p e-ctxt) (name var variable)))
                (term (defns ,(plug (term p) (term rhs)))))
     ;; Conditionals.
     (reduction lang
                (in-hole (name p p-ctxt) (if true e_then e))
                (plug (term p) (term e_then)))
     (reduction lang  
                (in-hole (name p p-ctxt) (if false e e_else))
                (plug (term p) (term e_else)))
     ;; List operations; hd/tl of empty abort the program.
     (reduction lang  
                (in-hole (name p p-ctxt) (hd (cons v_fst v)))
                (plug (term p) (term v_fst)))
     (reduction lang  
                (in-hole (name p p-ctxt) (hd empty))
                (term (error hd)))
     (reduction lang  
                (in-hole (name p p-ctxt) (tl (cons v v_rst)))
                (plug (term p) (term v_rst)))
     (reduction lang  
                (in-hole (name p p-ctxt) (tl empty))
                (term (error tl)))
     (reduction lang  
                (in-hole (name p p-ctxt) (mt empty))
                (plug (term p) 'true))
     (reduction lang  
                (in-hole (name p p-ctxt) (mt (cons v v)))
                (plug (term p) 'false))
     ;; Contract inspection.
     ;; (contract v) is a flat contract and (--> v v) an arrow contract.
     ;; Asking for a projection the contract does not have aborts.
     (reduction lang  
                (in-hole (name p p-ctxt) (flatp (contract v)))
                (plug (term p) 'true))
     (reduction lang  
                (in-hole (name p p-ctxt) (flatp (--> v v)))
                (plug (term p) 'false))
     (reduction lang  
                (in-hole (name p p-ctxt) (pred (contract v_pred)))
                (plug (term p) (term v_pred)))
     (reduction lang  
                (in-hole (name p p-ctxt) (pred (--> v v)))
                (term (error pred)))
     (reduction lang  
                (in-hole (name p p-ctxt) (dom (--> v_dm v)))
                (plug (term p) (term v_dm)))
     (reduction lang  
                (in-hole (name p p-ctxt) (dom (contract v)))
                (term (error dom)))
     (reduction lang  
                (in-hole (name p p-ctxt) (rng (--> v v_rg)))
                (plug (term p) (term v_rg)))
     (reduction lang  
                (in-hole (name p p-ctxt) (rng (contract v)))
                (term (error rng)))
     ;; blame aborts with an error naming the blamed party.
     ;; NOTE(review): the pattern requires a `variable', but wrapbar's
     ;; party arguments in flat-case / ho-case are the strings "p" and
     ;; "n". Confirm this rule can fire on those terms.
     (reduction lang  
                (in-hole (name p p-ctxt) (blame (name x variable)))
                (term (error x)))))
  ;; pp : program output-port any text -> void
  ;; Pretty-printer passed to `traces' below. Renders program V by
  ;; making PORT the current output port while pp-prog runs.
  ;; W is accepted but not used. SPEC is handed on to pp-prog, which
  ;; passes it to the insert-* helpers as the target editor.
  (define (pp v port w spec)
    (parameterize ((current-output-port port))
      (pp-prog v spec)))
  ;; pp-prog : program text -> void
  ;; PROG has the shape ((defn ...) main-expr). Prints each definition,
  ;; then the main expression starting at column 0, then a newline.
  (define (pp-prog prog spec)
    (let loop ([defns (car prog)])
      (unless (null? defns)
        (pp-defn (car defns) spec)
        (loop (cdr defns))))
    (pp-expr (cadr prog) 0 spec)
    (newline))
  ;; pp-defn : defn text -> void
  ;; Prints one definition (valrec VAR : CTC = EXP) as
  ;;   val rec VAR : <ctc> = <exp>
  ;; followed by a newline. CTC and EXP go through pp-expr at column 0.
  (define (pp-defn defn spec)
    (define (show e) (pp-expr e 0 spec))
    (let* ([var (second defn)]
           [ctc (fourth defn)]
           [exp (sixth defn)])
      (display "val rec ")
      (display var)
      (display " : ")
      (show ctc)
      (display " = ")
      (show exp)
      (newline)))
  ;; dp/ct : any -> number
  ;; Displays X (rendered with format "~a") on the current output port
  ;; and returns the number of characters written. pp-expr and comb use
  ;; the count to track the current column. The original definition was
  ;; left unclosed and never returned the count.
  (define (dp/ct x)
    (let* ([str (format "~a" x)]
           [ct (string-length str)])
      (display str)
      ct))
  ;; pp-expr : sexp number text -> (union #f number)
  ;; Pretty-prints expression X.
  ;;   NL-COL  indentation column for any new line it starts.
  ;;   TEXT    editor that receives the styled pieces (keywords,
  ;;           variables, lambda, compile terms, the wrapbar snip) via the
  ;;           insert-* helpers. Plain pieces are written with dp/ct.
  ;; Returns #f if it started a new line and a
  ;; number if it didn't. The number indicates
  ;; how many columns were printed
  ;; NOTE(review): this definition is damaged in this copy. Missing pieces:
  ;;   - the dispatch form around the first clause (presumably a `cond'
  ;;     whose else branch holds the `match');
  ;;   - the `cond' heads inside the let* bindings;
  ;;   - the `break-lines?' binding name and the `middle' binding used by
  ;;     the operator case;
  ;;   - several match-clause patterns and closing parens.
  ;; Restore it from the original source.
  (define (pp-expr x nl-col text)
      [(equal? x wrapbar)
       (insert-wrapbar text)]
       (match x
         ;; lambda x. body, with the body on a new line indented by 2
         [`(lambda (,v) ,e)
          (insert-lambda text)
          (insert-variable text v)
          (dp/ct ". ")
          (next-line (+ nl-col 2))
          (pp-expr e (+ nl-col 2) text)
         ;; let x = rhs, with further bindings aligned under the first
         [`(let ((,vs ,rhss) ...) ,body)
          (insert-bold text "let")
          (dp/ct " ")
          (insert-variable text (car vs))
          (dp/ct " = ")
          (pp-expr (car rhss) (+ nl-col (string-length (format "let "))) text)
          (for-each (lambda (v rhs)
                      (next-line (+ nl-col (string-length (format "let "))))
                      (insert-variable text v)
                      (dp/ct " = ")
                      (pp-expr rhs (+ nl-col (string-length (format "let "))) text))
                    (cdr vs)
                    (cdr rhss))
          (next-line (+ nl-col 2))
          (pp-expr body (+ nl-col 2) text)
         ;; fix v. body. The bound variable is printed with dp/ct,
         ;; not insert-variable.
         [`(fix ,v ,e)
          (insert-bold text "fix")
          (dp/ct " ")
          (dp/ct v)
          (dp/ct ". ")
          (next-line (+ nl-col 2))
          (pp-expr e (+ nl-col 2) text)
         ;; if / then / else, with then and else on their own lines
         [`(if ,e1 ,e2 ,e3)
          (insert-bold text "if")
          (dp/ct " ")
          (pp-expr e1 (+ nl-col 3) text)
          (next-line (+ nl-col 2))
          (insert-bold text "then")
          (dp/ct " ")
          (pp-expr e2 (+ nl-col 2 5) text)
          (next-line (+ nl-col 2))
          (insert-bold text "else")
          (dp/ct " ")
          (pp-expr e3 (+ nl-col 2 5) text)
         ;; Application. The argument goes on a new line when the function
         ;; part started a new line or took 10 or more columns.
         ;; Non-simple? subexpressions are parenthesized.
         [`(,e1 ,e2)
          (let* ([fst-res
                    [(simple? e1)
                     (pp-expr e1 nl-col text)]
                     (comb (dp/ct "(")
                           (pp-expr e1 (+ nl-col 1) text)
                           (dp/ct ")"))])]
                  (or (not fst-res)
                      (>= fst-res 10))]
                 [_ (when break-lines?
                      (next-line nl-col))]
                    [(simple? e2)
                      (dp/ct " ")
                      (pp-expr e2 
                               (if break-lines?
                                   (+ nl-col 1)
                                   (+ fst-res nl-col 1))
                     (comb (dp/ct " (")
                           (pp-expr e2 
                                    (if break-lines?
                                        (+ nl-col 2)
                                        (+ nl-col fst-res 2))
                           (dp/ct ")"))])])
         ;; Binary operator (infix). cons prints as "::". The clause
         ;; inserting character 174 in the symbol font presumably handles
         ;; --> (its pattern line is missing). Other operators print their
         ;; name.
         [`(,biop ,e1 ,e2)
          (let* ([fst-res
                    [(simple? e1)
                     (pp-expr e1 nl-col text)]
                      (dp/ct "(")
                      (pp-expr e1 (+ nl-col 1) text)
                      (dp/ct ")"))])]
                 [spc1 (if fst-res
                           (dp/ct " ")
                           (begin (next-line nl-col)
                  (case biop
                    [(cons) (dp/ct "::")]
                     (insert-symbol text (string (integer->char 174)))
                    [else (dp/ct biop)])]
                 [spc2 (if fst-res
                           (dp/ct " ")
                           (begin (next-line nl-col)
                    [(simple? e2) 
                     (pp-expr e2 
                              (if fst-res
                                  (+ nl-col fst-res spc1 middle spc2)
                      (dp/ct "(")
                      (pp-expr e2 
                               (if fst-res
                                   (+ nl-col fst-res spc1 middle spc2 1)
                                   (+ nl-col 1))
                      (dp/ct ")"))])])
            (comb fst-res
          ;; NOTE(review): the patterns for the next four clauses are
          ;; missing. From their bodies they presumably print compile-v1,
          ;; compile-v2, compile-e and the hole ("[]").
          (insert-compile text "V" "1")]
          (insert-compile text "V" "2")]
          (insert-compile text "e" #f)]
          (dp/ct "[]")]
         ;; symbols listed in `keywords' are printed in bold
         [(? (lambda (x) 
               (and (symbol? x)
                    (memq x keywords))))
          (insert-bold text (symbol->string x))]
         ;; any other symbol is a variable
         [(? symbol?)
          (insert-variable text x)]
          ;; fallback (pattern missing): write any other datum with ~s
          (dp/ct (format "~s" x))])]))
  ;; keywords : symbols that pp-expr prints in bold rather than as
  ;; (green) variables.
  (define keywords
    (list 'contract 'pred 'dom 'rng 'flatp 'error 'blame 'true 'false))

  ;; insert-lambda : text -> number
  ;; Appends "l " to TEXT in the symbol font family (where "l" is
  ;; presumably drawn as a Greek lambda) and returns insert-symbol's
  ;; result.
  (define (insert-lambda text)
    (let ([glyph "l "])
      (insert-symbol text glyph)))

  ;; insert-compile : text string (union string #f) -> number
  ;; Appends the notation C_e(ARG_SUBSCRIPT) to TEXT:
  ;;   - a bold "C" in the script family,
  ;;   - a subscript "e",
  ;;   - "(" ARG, plus an optional subscript, then ")".
  ;; Returns the approximate width in columns.
  ;; When SUBSCRIPT is #f nothing is inserted for it and it counts as 0
  ;; columns. Previously the `if' had no else branch, so `+' received
  ;; void and failed (e.g. for the compile-e case).
  (define (insert-compile text arg subscript)
    (let ([sd (make-object style-delta% 'change-family 'script)]
          [b-sd (make-object style-delta% 'change-family 'script)])
      (send b-sd set-delta 'change-bold)
      (+ (insert/style text "C" b-sd)
         (insert/snip text (make-object sub-snip% "e"))
         (insert/style text "(" sd)
         (insert/style text arg #f)
         (if subscript
             (insert/snip text (make-object sub-snip% subscript))
             0)
         (insert/style text ")" sd))))

  ;; insert-wrapbar : text -> number
  ;; Appends a wrap-bar% snip at the end of TEXT and returns its width in
  ;; columns. pp-expr returns this value as its own column count.
  ;; The original define was left unclosed with no result. The width is
  ;; taken to be the length of the "wrap" label that wrap-bar% draws.
  (define (insert-wrapbar text)
    (send text insert (make-object wrap-bar%)
          (send text last-position)
          (send text last-position))
    (string-length "wrap"))
  ;; insert-symbol : text string -> number
  ;; Appends STR to TEXT in the 'symbol font family and returns
  ;; insert/style's result (the string's length).
  (define (insert-symbol text str)
    (let ([symbol-delta (make-object style-delta% 'change-family 'symbol)])
      (insert/style text str symbol-delta)))
  ;; insert-bold : text string -> number
  ;; Appends STR to TEXT in bold and returns its length in characters.
  (define (insert-bold text str)
    (let ([bold (make-object style-delta% 'change-bold)])
      (insert/style text str bold)
      (string-length str)))
  ;; insert-variable : text symbol -> number
  ;; Appends SYM's name to TEXT in forest green and returns the name's
  ;; length in characters.
  (define (insert-variable text sym)
    (let* ([green (make-object style-delta%)]
           [name (symbol->string sym)])
      (send green set-delta-foreground "forest green")
      (insert/style text name green)
      (string-length name)))
  ;; wrap-bar% : a snip that shows the word "wrap" with a horizontal line
  ;; drawn at y+1, across the top of the text. insert-wrapbar uses it to
  ;; display the wrapbar term.
  (define wrap-bar%
    (class snip%
      (inherit get-style)
      ;; Size the snip like "wrap" set in this snip's style font, with no
      ;; extra space on either side.
      (define/override (get-extent dc x y wb hb db ab lspace rspace)
        (set-box/f lspace 0)
        (set-box/f rspace 0)
        (let-values ([(width height descent ascent)
                      (send dc get-text-extent "wrap" (send (get-style) get-font))])
          (set-box/f wb width)
          (set-box/f hb height)
          (set-box/f db descent)
          (set-box/f ab ascent)))
      ;; Draw the label, then a line spanning its width just below the top.
      (define/override (draw dc x y left top right bottom dx dy draw-caret?)
        (let-values ([(width height descent ascent) (send dc get-text-extent "wrap")])
          (send dc draw-text "wrap" x y)
          (send dc draw-line x (+ y 1) (+ x width -1) (+ y 1))))
      (super-instantiate ())))

  ;; sub-snip% : a snip that draws STR lowered by a third of its text
  ;; height, giving a subscript. The drop is added to the reported height
  ;; and descent so the line leaves room for it.
  (define sub-snip%
    (class snip%
      (init-field str)
      (inherit get-style)
      (define/override (get-extent dc x y wb hb db ab lspace rspace)
        (set-box/f lspace 0)
        (set-box/f rspace 0)
        (let-values ([(width height descent ascent)
                      (send dc get-text-extent str (send (get-style) get-font))])
          (let* ([drop (floor (/ height 3))]
                 [total (+ height drop)])
            (set-box/f wb width)
            (set-box/f hb total)
            ;; everything below the text's baseline, including the drop
            (set-box/f db (- total (- height descent)))
            (set-box/f ab ascent))))
      (define/override (draw dc x y left top right bottom dx dy draw-caret?)
        (let-values ([(width height descent ascent) (send dc get-text-extent str)])
          (send dc draw-text str x (+ y (floor (/ height 3))))))
      (super-instantiate ())))
  ;; set-box/f : any any -> void
  ;; Stores V into B when B is a box. Any other B (e.g. #f for an extent
  ;; the caller did not ask for) is ignored.
  (define (set-box/f b v)
    (if (box? b)
        (set-box! b v)
        (void)))
  ;; insert/snip : text snip -> number
  ;; Appends SNIP at the end of TEXT and returns an approximation to the
  ;; width of what was inserted, in columns.
  (define (insert/snip text snip)
    (send text insert snip (send text last-position) (send text last-position))
    ;; NOTE(review): the original result expression is missing (the define
    ;; was left unclosed). Every snip inserted through this helper in this
    ;; file is a sub-snip% holding a one-character string ("e", "1", "2"),
    ;; so one column is assumed. Confirm.
    1)
  ;; insert/style : text string (union style-delta% #f) -> number
  ;; Appends STR at the end of TEXT. When SD is a style delta, applies it
  ;; to exactly the range just inserted.
  ;; Returns the number of characters in the string
  ;; (an approximation to the width of what was inserted).
  (define (insert/style text str sd)
    (let ([pos (send text last-position)])
      (send text insert str pos pos)
      (when sd
        ;; change-style takes (delta start end). The delta and start
        ;; position were previously missing, so a number was passed as
        ;; the delta.
        (send text change-style sd pos (send text last-position)))
      (string-length str)))
  ;; comb : (union #f number) ... -> (union #f number)
  ;; Sums its arguments, unless one of them is #f (meaning a new line was
  ;; started), in which case it returns #f.
  ;; The #f branch was missing: the sum went into the "then" position, so
  ;; all-number input returned void and any #f made `+' fail.
  (define (comb . x)
    (if (memq #f x)
        #f
        (apply + x)))
  ;; simple? : any -> boolean
  ;; True when EXP can be printed without surrounding parentheses: any
  ;; non-pair, or the wrapbar term (printed as a single snip).
  (define (simple? exp)
    (cond
      [(not (pair? exp)) #t]
      [else (equal? exp wrapbar)]))
  ;; next-line : number -> void
  ;; Writes a newline, then N spaces of indentation, all through dp/ct.
  (define (next-line n)
    (dp/ct "\n")
    (do ([remaining n (- remaining 1)])
        ((zero? remaining))
      (dp/ct " ")))
  ;; wrapbar : the contract-wrapping function used by the example
  ;; programs. It takes a contract ct, a value x and two party labels p
  ;; and n.
  ;;   - Flat contract: checks x with (pred ct) and blames p on failure.
  ;;   - Arrow contract: returns a function whose argument y is wrapped
  ;;     with (dom ct), with the parties swapped (n then p). The result
  ;;     is (presumably) wrapped with (rng ct).
  ;; NOTE(review): this form looks damaged.
  ;;   - It is written as a Scheme lambda, yet it refers to flatp, pred,
  ;;     dom, rng, blame and wrap, which are object-language forms with no
  ;;     Scheme definitions in view.
  ;;   - flat-case / ho-case splice it into programs with `,wrapbar', and
  ;;     pp-expr compares terms to it with equal?. So it was presumably a
  ;;     quoted term, likely something like (fix wrap (lambda (ct) ...)).
  ;;   - The inner `if' has no else branch (presumably x).
  ;;   - The arrow-contract body and the closing parens are cut off.
  ;; Restore it from the original source.
  (define wrapbar
      (lambda (ct)
        (lambda (x)
          (lambda (p)
            (lambda (n)
              (if (flatp ct)
                  (if ((pred ct) x)
                      (blame p))
                  (let ((d (dom ct))
                        (r (rng ct)))
                    (lambda (y)
                      ((((wrap r)
                         (x ((((wrap d) y) n) p)))
  ;; flat-case : example program with no definitions. Its main expression
  ;; applies wrapbar to the flat contract (contract compile-v2), then to
  ;; compile-v1, then to the party labels "p" and "n".
  (define flat-case
    `(() ; no definitions
      ((((,wrapbar (contract compile-v2)) compile-v1) "p") "n")))
  ;; ho-case : example program with no definitions. Its main expression
  ;; applies wrapbar to the arrow contract (--> compile-v1 compile-v2),
  ;; then to the function (lambda (x) compile-e), then to the party
  ;; labels "p" and "n".
  (define ho-case
    `(() ; no definitions
      ((((,wrapbar (--> compile-v1 compile-v2)) (lambda (x) compile-e)) "p") "n")))
  ;; Run `traces' on each example program in turn (flat contract first,
  ;; then higher-order), rendering terms with pp.
  (for-each (lambda (prog) (traces lang reductions prog pp))
            (list flat-case ho-case))