(module reduction-semantics mzscheme
  (require "private/"
           (lib "")
           (lib "")
           (lib ""))
  (require-for-syntax (lib "")
                      (lib "" "syntax")
                      (lib "" "syntax"))

  ;; type red = (make-red compiled-pat (bindings -> any) (union string #f))
  ;;   contractum : compiled pattern; reduce/internal matches it against the term
  ;;   reduct     : procedure applied to the bindings of each successful match; its
  ;;                result is the reduced term
  ;;   name       : presumably the rule's name, a string or #f
  ;; NOTE(review): the previous type comment had unbalanced parens; confirm the
  ;; bindings representation (it was described as (listof (cons sym tst))).
  (define-struct red (contractum reduct name))

  ;; (compatible-closure red lang nt)
  ;; Lifts the reduction relation `red' into the contexts described by the pattern
  ;; `(cross nt)' over `lang'.  The pattern is run through rewrite-side-conditions and
  ;; handed, quasiquoted, to do-context-closure at run time.  `nt' must be an
  ;; identifier naming a non-terminal; anything else is a syntax error.
  (define-syntax (compatible-closure stx)
    (syntax-case stx ()
      [(_ red lang nt)
       (if (not (identifier? (syntax nt)))
           (raise-syntax-error 'compatible-closure
                               "expected a non-terminal as last argument"
                               stx
                               (syntax nt))
           (with-syntax ([side-conditions-rewritten
                          (rewrite-side-conditions 'compatible-closure
                                                   (syntax (cross nt)))])
             (syntax (do-context-closure red lang `side-conditions-rewritten 'compatible-closure))))]))
  ;; (context-closure red lang pattern)
  ;; Like compatible-closure, but the caller supplies the context pattern directly
  ;; instead of naming a non-terminal.  The pattern is run through
  ;; rewrite-side-conditions and handed, quasiquoted, to do-context-closure.
  ;; (The expansion template and closing parens were missing; restored to mirror
  ;; compatible-closure.)
  (define-syntax (context-closure stx)
    (syntax-case stx ()
      [(_ red lang pattern)
       (with-syntax ([side-conditions-rewritten (rewrite-side-conditions 'context-closure (syntax pattern))])
         (syntax (do-context-closure red lang `side-conditions-rewritten 'context-closure)))]))
  ;; do-context-closure : reduction-relation lang pattern symbol -> reduction-relation
  ;; Run-time half of compatible-closure / context-closure.  Returns a relation with
  ;; one proc per proc of `red'.  Each new proc:
  ;;   - decomposes the sub-term as `(in-hole ctxt exp)' for every way `pat' matches;
  ;;   - runs the original proc on `exp';
  ;;   - plugs each result back into `ctxt' before passing it on to `extend'.
  ;; The rule names of `red' are carried over unchanged.
  ;; Raises an error tagged with `name' when `red' or `lang' has the wrong type.
  (define (do-context-closure red lang pat name)
    (unless (reduction-relation? red)
      (error name "expected <reduction-relation> as first argument, got ~e" red))
    (unless (compiled-lang? lang)
      (error name "expected <lang> as second argument, got ~e" lang))
    ;; compile-pattern takes the language first, as at every other call site
    (let ([cp (compile-pattern
               lang
               `(in-hole (name ctxt ,pat)
                         (name exp any)))])
      (make-reduction-relation
       (map
        (λ (f)
          (λ (main-exp exp extend acc)
            ;; thread the accumulator through every decomposition of `exp'
            (let loop ([ms (or (match-pattern cp exp) '())]
                       [acc acc])
              (cond
                [(null? ms) acc]
                [else
                 (let* ([mtch (car ms)]
                        [bindings (mtch-bindings mtch)])
                   (loop (cdr ms)
                         (f main-exp
                            (lookup-binding bindings 'exp)
                            (λ (x) (extend (plug (lookup-binding bindings 'ctxt) x)))
                            acc)))]))))
        (reduction-relation-procs red))
       (reduction-relation-rule-names red))))
  ;; build-metafunction : lang (listof pattern) (listof (bindings -> any))
  ;;                      ((any -> any) -> (any -> any)) symbol -> (any -> any)
  ;; Compiles every clause pattern against `lang' and returns `wrap' applied to a
  ;; procedure that tries the clauses in order.  At the first clause whose pattern
  ;; matches:
  ;;   - if it matches in exactly one way, its rhs procedure is called with the
  ;;     bindings of that match;
  ;;   - if it matches in more than one way, that is an error.
  ;; It is also an error when no clause matches.  Errors are tagged with `name'.
  (define (build-metafunction lang patterns rhss wrap name)
    (let ([compiled-patterns (map (λ (pat) (compile-pattern lang pat)) patterns)])
      ;; the callers in this file pass a `wrap' that rebinds the procedure under
      ;; the metafunction's name
      (wrap
       (λ (exp)
         (let loop ([patterns compiled-patterns]
                    [rhss rhss]
                    [num 0])
           (cond
             [(null? patterns) (error name "no clauses matched for ~s" exp)]
             [else (let ([pattern (car patterns)]
                         [rhs (car rhss)])
                     (let ([mtchs (match-pattern pattern exp)])
                       (cond
                         [(not mtchs) (loop (cdr patterns)
                                            (cdr rhss)
                                            (+ num 1))]
                         [(not (null? (cdr mtchs)))
                          (error name "clause ~a matched ~s two different ways" num exp)]
                         [else
                          (rhs (mtch-bindings (car mtchs)))])))]))))))
  ;; do-test-match : lang pattern -> (any -> result-of-match-pattern)
  ;; Run-time half of test-match.  It checks that `lang' is a compiled language and
  ;; compiles `pat' once.  It then returns a procedure that matches a term against
  ;; the compiled pattern and returns match-pattern's result unchanged.  Other code
  ;; in this file treats that result as #f or a list of matches.
  (define (do-test-match lang pat)
    (if (compiled-lang? lang)
        (let ([compiled (compile-pattern lang pat)])
          (λ (candidate) (match-pattern compiled candidate)))
        (error 'test-match "expected first argument to be a language, got ~e" lang)))

  ;; `-->' and `fresh' are keywords that only reduction-relation interprets; any
  ;; other use expands to a syntax error.  Passing `stx' as the offending form makes
  ;; the error report the location of the misuse instead of no location at all.
  (define-syntax (--> stx) (raise-syntax-error '--> "used outside of reduction-relation" stx))
  (define-syntax (fresh stx) (raise-syntax-error 'fresh "used outside of reduction-relation" stx))

  ;; A reduction relation: a list of procs plus the names of its rules.
  ;; procs : (listof (exp exp (any -> any) (listof any) -> (listof any)))
  ;;   Each proc receives:
  ;;     - the whole term;
  ;;     - the sub-term to reduce;
  ;;     - a procedure that turns a result into a result for the whole term;
  ;;     - the accumulator of earlier results.
  ;;   It returns that accumulator extended with its own results.  See
  ;;   do-context-closure and apply-reduction-relation/tag-with-names.
  ;; rule-names : (listof sym)
  (define-struct reduction-relation (procs rule-names))
  ;; apply-reduction-relation/tag-with-names : reduction-relation any -> (listof any)
  ;; Runs every proc of `p' on `v', passing `v' as both the whole term and the
  ;; sub-term, with `values' as the rebuilding procedure.  One accumulator is
  ;; threaded through all the procs.  The leaf procs built by do-leaf-match
  ;; contribute (list rule-name result) entries.
  ;; (The `cond'/`else' around the loop body were missing; restored.)
  (define (apply-reduction-relation/tag-with-names p v)
    (let loop ([procs (reduction-relation-procs p)]
               [acc '()])
      (cond
        [(null? procs) acc]
        [else
         (loop (cdr procs)
               ((car procs) v v values acc))])))
  ;; apply-reduction-relation : reduction-relation any -> (listof any)
  ;; The results of apply-reduction-relation/tag-with-names with the tags dropped:
  ;; only the second element of each entry is kept.
  (define (apply-reduction-relation p v)
    (let ([tagged (apply-reduction-relation/tag-with-names p v)])
      (map (λ (entry) (cadr entry)) tagged)))
  ;; -reduction-relation is the reduction-relation macro; it is exported under that
  ;; name further down.  Following the define-syntax-set convention, it is
  ;; implemented by -reduction-relation/proc.  The definitions that follow, through
  ;; raise-syntax-errors, are its compile-time helpers.
  (define-syntax-set (-reduction-relation)
    ;; -reduction-relation/proc : syntax -> syntax
    ;; Splits the arguments that follow the language expression at the `where'
    ;; keyword: everything before it is a rule, everything after it is a shortcut.
    ;; The original form, the language expression and both lists go to
    ;; reduction-relation/helper.
    ;; (The call head was missing; the three trailing arguments match the helper's
    ;; (stx lang-exp rules shortcuts) signature, so it is restored here.)
    (define (-reduction-relation/proc stx)
      (syntax-case stx ()
        [(_ lang args ...)
         (with-syntax ([(rules ...) (before-where (syntax (args ...)))]
                       [(shortcuts ...) (after-where (syntax (args ...)))])
           (reduction-relation/helper
            stx
            (syntax lang)
            (syntax->list (syntax (rules ...)))
            (syntax->list (syntax (shortcuts ...)))))]))
    ;; before-where : syntax -> (listof syntax)
    ;; Returns the elements of the syntax list `stx' that precede the first `where'
    ;; identifier (compared by symbol name only), or all of them when there is no
    ;; `where'.
    ;; (The `cond'/`else' around the loop body were missing; the unused `fst'
    ;; binding is dropped.)
    (define (before-where stx)
      (let loop ([lst (syntax->list stx)])
        (cond
          [(null? lst) null]
          [else
           (syntax-case* (car lst) (where) (λ (x y) (eq? (syntax-e x) (syntax-e y)))
             [where null]
             [else (cons (car lst) (loop (cdr lst)))])])))
    ;; after-where : syntax -> (listof syntax)
    ;; Returns the elements of the syntax list `stx' that follow the first `where'
    ;; identifier (compared by symbol name only), or the empty list when there is
    ;; no `where'.
    ;; (The `cond'/`else' around the loop body were missing; the unused `fst'
    ;; binding is dropped.)
    (define (after-where stx)
      (let loop ([lst (syntax->list stx)])
        (cond
          [(null? lst) null]
          [else
           (syntax-case* (car lst) (where) (λ (x y) (eq? (syntax-e x) (syntax-e y)))
             [where (cdr lst)]
             [else (loop (cdr lst))])])))
    ;; reduction-relation/helper : syntax syntax (listof syntax) (listof syntax) -> syntax
    ;; Steps:
    ;;   - files every shortcut under its right-hand arrow and every rule under its
    ;;     arrow in the identifier table `ht';
    ;;   - raises a syntax error unless at least one rule uses `-->';
    ;;   - builds the expansion from get-choices, starting at `-->';
    ;;   - gathers the rule names in `name-ht'.
    ;; NOTE(review): this definition looks damaged and should be repaired before use:
    ;;   - the list arguments of the two for-each calls (presumably `shortcuts' and
    ;;     `rules') are missing;
    ;;   - the head of the syntax template that should wrap the (let ...) and
    ;;     '(rule-names ...) forms is missing (presumably a
    ;;     make-reduction-relation call).
    (define (reduction-relation/helper stx lang-exp rules shortcuts)
      (let ([ht (make-module-identifier-mapping)])
        (for-each (λ (shortcut)
                    (syntax-case shortcut ()
                      [((lhs-arrow lhs-from lhs-to)
                        (rhs-arrow rhs-from rhs-to))
                       (table-cons! ht (syntax rhs-arrow) shortcut)]))
        (for-each (λ (rule)
                    (syntax-case rule ()
                      [(arrow . rst)
                       (table-cons! ht (syntax arrow) rule)]))
        (unless (module-identifier-mapping-get ht (syntax -->) (λ () #f))
          (raise-syntax-error 'reduction-relation "no --> rules" stx))
        ;; name-ht is filled by get-choices; its keys become the rule names
        (let ([name-ht (make-hash-table)])
          (with-syntax ([lang-exp lang-exp]
                        [(top-level ...) (get-choices stx ht (syntax lang-x) (syntax -->) name-ht)]
                        [(rule-names ...) (hash-table-map name-ht (λ (k v) k))])
              (let ([lang-x lang-exp])
                (list top-level ...))
              '(rule-names ...)))))))
#|    ;; relation-tree = 
    ;;   leaf
    ;;  (make-node id[frm] pat[frm] id[to] pat[to] (listof relation-tree))
    ;; NOTE(review): the block comment opened at the start of this region has no
    ;; visible matching close, so as written every following line is commented out.
    ;; No visible code references the node/leaf structs below, which suggests the
    ;; close originally followed them; confirm and restore.
    ;; NOTE(review): node is documented with five components but declares four fields.
    (define-struct node (frm-id frm-pat to-id to-pat))
    (define-struct leaf (frm-pat to-pat))
    ;; get-choices : stx[original-syntax-object] bm lang identifier ht[sym->syntax] -> (listof relation-tree)
    ;; Appears to translate, via get-tree, every entry stored under the arrow `id'
    ;; in `bm'.  It raises a "found no rules" syntax error when nothing is stored
    ;; there.
    ;; NOTE(review): this definition is incomplete:
    ;;   - the `map' has no list argument;
    ;;   - the `bm id (λ () ...)' lines read like the arguments of a lookup in
    ;;     `bm' (presumably module-identifier-mapping-get) whose head and
    ;;     closing parens are missing.
    (define (get-choices stx bm lang id name-table)
       (map (λ (x) (get-tree stx bm lang x name-table)) 
             bm id
             (λ ()
               (raise-syntax-error 'reduction-relation 
                                   (format "found no rules for ~a" (syntax-object->datum id))
    ;; get-tree : translates one entry stored in `bm'.
    ;;   - A rule `(arrow from to extras ...)' is handed to do-leaf.
    ;;   - A shortcut `((lhs-arrow lhs-frm-id lhs-to-id) (rhs-arrow rhs-from rhs-to))'
    ;;     becomes a procedure that builds `rhs-to' with `lhs-to-id' and the pattern
    ;;     variables of `rhs-from' bound.  It is listed together with the trees for
    ;;     the rules of `lhs-arrow', obtained via get-choices.
    ;; NOTE(review): confirm before use:
    ;;   - do-leaf is declared with six parameters (stx lang name-table from to
    ;;     extras) but is called here with four;
    ;;   - rewrite-side-conditions is called with one argument, while the other
    ;;     calls in this file pass a symbol first;
    ;;   - the template that should wrap the (λ (bindings rhs-binder) ...) and
    ;;     (list child-procs ...) forms is missing (presumably a do-node-match call).
    (define (get-tree stx bm lang case-stx name-table)
      (syntax-case case-stx ()
        [(arrow from to extras ...) 
         (do-leaf stx 
                  (syntax from) 
                  (syntax to) 
                  (syntax->list (syntax (extras ...))))]
        [((lhs-arrow lhs-frm-id lhs-to-id) (rhs-arrow rhs-from rhs-to))
         (let-values ([(names names/ellipses) (extract-names (syntax rhs-from))])
           (with-syntax ([(names ...) names]
                         [(names/ellipses ...) names/ellipses]
                         [side-conditions-rewritten (rewrite-side-conditions
                                                     (rewrite-node-pat (syntax-e (syntax lhs-frm-id))
                                                                       (syntax-object->datum (syntax rhs-from))))]
                         [lang lang]
                         [(child-procs ...) (get-choices stx bm lang (syntax lhs-arrow) name-table)])
               (λ (bindings rhs-binder)
                 (term-let ([lhs-to-id rhs-binder]
                            [names/ellipses (lookup-binding bindings 'names)] ...)
                           (term rhs-to)))
               (list child-procs ...)))))]))
    ;; rewrite-node-pat : symbol sexp -> sexp
    ;; Replaces every occurrence of the symbol `id' in `term' with `(name ,id any)'.
    ;; The walk follows both car and cdr of pairs, so improper tails are rewritten
    ;; too.  Everything else is returned unchanged.
    ;; (The `cond' head was missing; restored.)
    (define (rewrite-node-pat id term)
      (let loop ([term term])
        (cond
          [(eq? id term) `(name ,id any)]
          [(pair? term) (cons (loop (car term))
                              (loop (cdr term)))]
          [else term])))

    ;; do-leaf : builds the expansion for one rule `(arrow from to extras ...)'.
    ;; process-extras supplies the rule's name, its fresh variables and its side
    ;; conditions.  When the right-hand side `to' is built:
    ;;   - the pattern variables of `from' are bound via term-let;
    ;;   - each fresh variable is bound to (variable-not-in main 'fresh-var), which
    ;;     avoids the symbols of the whole term `main'.
    ;; NOTE(review): parts of this definition appear to be missing:
    ;;   - the `if' below has only one branch; presumably `from' itself is the
    ;;     branch used when there are no side conditions;
    ;;   - the side-condition rewriting step is absent;
    ;;   - the `...' after the term-let bindings is absent;
    ;;   - do-leaf-match is declared with four parameters (lang name pat proc),
    ;;     but only the procedure argument is visible.
    (define (do-leaf stx lang name-table from to extras)
      (let-values ([(name fresh-vars side-conditions) (process-extras stx name-table extras)])
        (let-values ([(names names/ellipses) (extract-names from)])
          (with-syntax ([side-conditions-rewritten 
                          (if (null? side-conditions)
                              (with-syntax ([(sc ...) side-conditions]
                                            [from from])
                                (syntax (side-condition from (and sc ...))))))]
                        [to to]
                        [name name]
                        [lang lang]
                        [(names ...) names]
                        [(names/ellipses ...) names/ellipses]
                        [(fresh-vars ...) fresh-vars])
            (syntax (do-leaf-match
                     (λ (main bindings)
                       (term-let ([names/ellipses (lookup-binding bindings 'names)] 
                                  [fresh-vars (variable-not-in main 'fresh-vars)] 
                                 (term to)))))))))
    ;; process-extras : syntax hash-table (listof syntax) -> (values name fresh-vars side-conditions)
    ;; Walks the extras of one rule:
    ;;   - A name (identifier or string) is recorded in `name-table'.  It must be
    ;;     unique across rules and may appear at most once per rule.  Symbol names
    ;;     are converted to strings, and the name is #f when absent.
    ;;   - `(fresh var ...)' clauses accumulate identifiers; a `fresh' clause
    ;;     containing a non-identifier is a syntax error.
    ;;   - `(side-condition exp ...)' clauses accumulate expressions.
    ;;   - Anything else is an "unknown extra" syntax error.
    ;; NOTE(review): this definition is damaged and must be restored before use.
    ;; Missing pieces:
    ;;   - the cond/else around the loop body;
    ;;   - the pattern of the name clause;
    ;;   - the binding of `name-sym';
    ;;   - the else-branches of the string/symbol conversions;
    ;;   - the head of the bad-`fresh' check (it shows a λ and a list but no call).
    (define (process-extras stx name-table extras)
      (let ([the-name #f]
            [the-name-stx #f]
            [fresh-vars '()]
            [side-conditions '()])
        (let loop ([extras extras])
            [(null? extras) (values the-name fresh-vars side-conditions)]
             (syntax-case (car extras) (fresh)
                (or (identifier? (car extras))
                    (string? (syntax-e (car extras))))
                  (let* ([raw-name (syntax-e (car extras))]
                          (if (symbol? raw-name)
                              (string->symbol raw-name))])
                    (when (hash-table-get name-table name-sym #f)
                      (raise-syntax-errors 'reduction-relation 
                                           "same name on multiple rules"
                                           (list (hash-table-get name-table name-sym)
                                                 (syntax name))))
                    (hash-table-put! name-table name-sym (syntax name))
                    (when the-name
                      (raise-syntax-errors 'reduction-relation "expected only a single name" 
                                           (list the-name-stx (car extras))))
                    (set! the-name (if (symbol? raw-name)
                                       (symbol->string raw-name)
                    (set! the-name-stx (car extras))
                    (loop (cdr extras))))]
               [(fresh var ...)
                (andmap identifier? (syntax->list (syntax (var ...))))
                  (set! fresh-vars (append (syntax->list (syntax (var ...))) fresh-vars))
                  (loop (cdr extras)))]
               [(fresh exp ...)
                 (λ (x) (unless (identifier? x)
                          (raise-syntax-error 'reduction-relation "expected variables in a fresh clause" stx x)))
                 (syntax->list (syntax (exp ...))))]
               [(side-condition exp ...)
                  (set! side-conditions
                        (append (syntax->list (syntax (exp ...))) side-conditions))
                  (loop (cdr extras)))]
                (raise-syntax-error 'reduction-relation "unknown extra" stx (car extras))])]))))

    ;; table-cons! : module-identifier-mapping identifier any -> void
    ;; Pushes `hd' onto the front of the list stored under `key' in `ht'.  A key
    ;; with no entry yet is treated as holding the empty list.
    (define (table-cons! ht key hd)
      (let ([existing (module-identifier-mapping-get ht key (λ () '()))])
        (module-identifier-mapping-put! ht key (cons hd existing))))
    ;; raise-syntax-errors : symbol string syntax (listof syntax) -> (does not return)
    ;; Like raise-syntax-error, but blames several syntax objects at once: it raises
    ;; exn:fail:syntax with the message "sym: str" and all of `stxs' as its exprs.
    ;; `stx' is not used.
    ;; make-exn:fail:syntax takes (message continuation-marks exprs); the
    ;; continuation marks were missing, so the constructor got too few arguments.
    ;; NOTE(review): the calls visible in process-extras pass only three
    ;; arguments (no `stx'); confirm.
    ;; The final closing paren below also ends the enclosing define-syntax-set.
    (define (raise-syntax-errors sym str stx stxs)
      (raise (make-exn:fail:syntax (string->immutable-string (format "~a: ~a" sym str))
                                   (current-continuation-marks)
                                   (apply list-immutable stxs)))))
  ;; union-reduction-relations : reduction-relation reduction-relation reduction-relation ...
  ;;                             -> reduction-relation
  ;; Combines two or more relations into one.  The procs of all inputs are
  ;; appended and then reversed.  Rule names must be distinct across the inputs;
  ;; a repeated name is an error.
  ;; (The outer for-each over the relations and the make-reduction-relation call
  ;; were missing; restored.)
  (define (union-reduction-relations fst snd . rst)
    (let ([name-ht (make-hash-table)]
          [lst (list* fst snd rst)])
      (for-each
       (λ (red)
         (for-each (λ (name)
                     (when (hash-table-get name-ht name #f)
                       (error 'union-reduction-relations "multiple rules with the name ~s" name))
                     (hash-table-put! name-ht name #t))
                   (reduction-relation-rule-names red)))
       lst)
      (make-reduction-relation
       (reverse (apply append (map reduction-relation-procs lst)))
       (hash-table-map name-ht (λ (k v) k)))))
  ;; do-node-match : builds the reduction-relation proc for a shortcut.  It
  ;; matches `pat' against the sub-term.  For each match it looks up the
  ;; `lhs-frm-id' binding as `sub-exp' and runs every child proc.  `f' is composed
  ;; with `rhs-proc', so each child result is rebuilt through the shortcut's
  ;; right-hand side.
  ;; NOTE(review): this definition is truncated and must be restored before use.
  ;; Missing pieces:
  ;;   - the cond/else around both loop bodies;
  ;;   - the remaining arguments of the child-proc call (presumably `sub-exp'
  ;;     and `acc'; `sub-exp' is otherwise unused);
  ;;   - the else-branch of (if mtchs ...);
  ;;   - all closing parens.
  (define (do-node-match lang lhs-frm-id lhs-to-id pat rhs-proc child-procs)
    (let ([cp (compile-pattern lang pat)])
      (λ (main-exp exp f other-matches)
        (let ([mtchs (match-pattern cp exp)])
          (if mtchs
              (let o-loop ([mtchs mtchs]
                           [acc other-matches])
                  [(null? mtchs) acc]
                   (let ([sub-exp (lookup-binding (mtch-bindings (car mtchs)) lhs-frm-id)])
                     (let i-loop ([child-procs child-procs]
                                  [acc acc])
                         [(null? child-procs) (o-loop (cdr mtchs) acc)]
                         [else (i-loop (cdr child-procs)
                                       ((car child-procs) main-exp
                                                          (λ (x) (f (rhs-proc (mtch-bindings (car mtchs)) x)))
  ;; do-leaf-match : lang (union string #f) pattern (any bindings -> any) -> proc
  ;; Builds the reduction-relation proc for one rule.  It matches `pat' against the
  ;; sub-term.  For every match it conses (list name (f (proc main-exp bindings)))
  ;; onto the accumulator, so results are tagged with the rule's name.  With no
  ;; match, the accumulator is returned unchanged.
  ;; (The map/mt arguments, the no-match branch and the closing parens were
  ;; missing; restored.)
  (define (do-leaf-match lang name pat proc)
    (let ([cp (compile-pattern lang pat)])
      (λ (main-exp exp f other-matches)
        (let ([mtchs (match-pattern cp exp)])
          (if mtchs
              (map/mt (λ (mtch) (list name (f (proc main-exp (mtch-bindings mtch)))))
                      mtchs
                      other-matches)
              other-matches)))))
  ;; (test-match lang-exp pattern)
  ;;   evaluates to a procedure that matches a term against `pattern' in `lang-exp'.
  ;; (test-match lang-exp pattern expression)
  ;;   matches `expression' right away.
  ;; The pattern is run through rewrite-side-conditions and handed, quasiquoted,
  ;; to do-test-match.
  ;; (Both `syntax' template heads were missing; restored.)
  (define-syntax (test-match stx)
    (syntax-case stx ()
      [(_ lang-exp pattern)
       (with-syntax ([side-condition-rewritten (rewrite-side-conditions 'test-match (syntax pattern))])
         (syntax
          (do-test-match lang-exp `side-condition-rewritten)))]
      [(_ lang-exp pattern expression)
       (syntax
        ((test-match lang-exp pattern) expression))]))
  ;; (define-metafunction name lang-exp (lhs rhs) ...)
  ;; Defines `name' as a metafunction over `lang-exp':
  ;;   - each lhs pattern is run through rewrite-side-conditions;
  ;;   - each rhs becomes a procedure from match bindings to the term `rhs', with
  ;;     the pattern variables of its lhs bound.
  ;; build-metafunction combines them.  The result is stored under a fresh
  ;; temporary, and `name' is then introduced with term-define-fn (presumably so
  ;; that it can be used inside term templates).
  ;; (The `syntax' heads, the `begin' and the build-metafunction call were missing;
  ;; restored to match build-metafunction's (lang patterns rhss wrap name)
  ;; signature.)
  (define-syntax (define-metafunction stx)
    (syntax-case stx ()
      [(_ name lang-exp (lhs rhs) ...)
       (with-syntax ([(side-conditions-rewritten ...) (map (λ (x) (rewrite-side-conditions 'define-metafunction x))
                                                           (syntax->list (syntax (lhs ...))))]
                     [(rhs-fns ...)
                      (map (λ (lhs rhs)
                             (let-values ([(names names/ellipses) (extract-names lhs)])
                               (with-syntax ([(names ...) names]
                                             [(names/ellipses ...) names/ellipses]
                                             [rhs rhs])
                                 (syntax
                                  (λ (bindings) (term-let ([names/ellipses (lookup-binding bindings 'names)] ...)
                                                          (term-let-fn (#;(name name))
                                                                       (term rhs))))))))
                           (syntax->list (syntax (lhs ...)))
                           (syntax->list (syntax (rhs ...))))]
                     [(name2) (generate-temporaries (syntax (name)))])
         (syntax
          (begin
            (define name2
              (build-metafunction
               lang-exp
               (list `side-conditions-rewritten ...)
               (list rhs-fns ...)
               ;; rebinding through `name' gives the procedure that inferred name
               (λ (f) (let ([name (lambda (x) (f x))]) name))
               'name))
            (term-define-fn name name2))))]))
  ;; (metafunction lang-exp (lhs rhs) ...)
  ;; An anonymous metafunction over `lang-exp', built with build-metafunction.  Its
  ;; name comes from syntax-local-infer-name and defaults to `metafunction'.
  ;; Fix: each lhs pattern `x' is now passed to rewrite-side-conditions; it was
  ;; previously called with only the symbol.  The missing `syntax' heads and the
  ;; build-metafunction call are restored.
  (define-syntax (metafunction stx)
    (syntax-case stx ()
      [(_ lang-exp (lhs rhs) ...)
       (with-syntax ([(side-conditions-rewritten ...) (map (λ (x) (rewrite-side-conditions 'metafunction x))
                                                           (syntax->list (syntax (lhs ...))))]
                     [(rhs-fns ...)
                      (map (λ (lhs rhs)
                             (let-values ([(names names/ellipses) (extract-names lhs)])
                               (with-syntax ([(names ...) names]
                                             [(names/ellipses ...) names/ellipses]
                                             [rhs rhs])
                                 (syntax
                                  (λ (bindings) (term-let ([names/ellipses (lookup-binding bindings 'names)] ...)
                                                          (term rhs)))))))
                           (syntax->list (syntax (lhs ...)))
                           (syntax->list (syntax (rhs ...))))]
                     [name (or (syntax-local-infer-name stx) 'metafunction)])
         (syntax
          (build-metafunction
           lang-exp
           (list `side-conditions-rewritten ...)
           (list rhs-fns ...)
           ;; rebinding through `name' gives the procedure that inferred name
           (λ (f) (let ([name (lambda (x) (f x))]) name))
           'name)))]))
  ;; (language (name rhs ...) ...) : defines a language with one non-terminal per
  ;; clause:
  ;;   - each rhs is run through rewrite-side-conditions and wrapped with make-rhs;
  ;;   - each clause becomes (make-nt 'name ...);
  ;;   - the resulting list is passed to compile-language.
  ;; The expansion also collects `refs': the identifiers inside the rhs patterns
  ;; that name one of the non-terminals.  They are placed in the scope of
  ;; (let ([name 1] ...) ...), presumably so tools see those references as bound;
  ;; confirm.  The fallback clauses report a non-identifier name or a malformed
  ;; non-terminal.
  ;; NOTE(review): several lines look missing:
  ;;   - the cond of the refs loop and the else-branch of its `if';
  ;;   - the form that sequences the (let ...) with the compile-language call;
  ;;   - the heads (presumably for-each) of the two fallback clauses.
  ;; In the last clause, only the well-formed `(name rhs ...)' case visibly raises
  ;; "malformed non-terminal"; confirm the intended clauses.
  (define-syntax (language stx)
    (syntax-case stx ()
      [(_ (name rhs ...) ...)
       (andmap identifier? (syntax->list (syntax/loc stx (name ...))))
       (with-syntax ([((r-rhs ...) ...) (map (lambda (rhss) (map (λ (x) (rewrite-side-conditions 'language x)) (syntax->list rhss)))
                                             (syntax->list (syntax ((rhs ...) ...))))]
                     [(refs ...)
                      (let loop ([stx (syntax ((rhs ...) ...))])
                          [(identifier? stx)
                           (if (ormap (λ (x) (module-identifier=? x stx)) 
                                      (syntax->list (syntax (name ...))))
                               (list stx)
                          [(syntax? stx)
                           (loop (syntax-e stx))]
                          [(pair? stx)
                           (append (loop (car stx))
                                   (loop (cdr stx)))]
                          [else '()]))])
         (syntax/loc stx
               (let ([name 1] ...)
                 (begin (void) refs ...))
               (compile-language (list (make-nt 'name (list (make-rhs `r-rhs) ...)) ...))))))]
      [(_ (name rhs ...) ...)
        (lambda (name)
          (unless (identifier? name)
            (raise-syntax-error 'language "expected name" stx name)))
        (syntax->list (syntax (name ...))))]
      [(_ x ...)
        (lambda (x)
          (syntax-case x ()
            [(name rhs ...)
             (raise-syntax-error 'language "malformed non-terminal" stx x)]))
        (syntax->list (syntax (x ...))))]))
  ;; (extend-language lang (name rhs ...) ...) : a new language built from `lang'
  ;; by do-extend-language.  Each rhs is run through rewrite-side-conditions and
  ;; wrapped with make-rhs, and each clause becomes (make-nt 'name ...).  The two
  ;; fallback clauses report a non-identifier name or a malformed non-terminal.
  ;; NOTE(review): the fallback clauses look damaged.  Each has a lambda and a list
  ;; but no visible head (presumably for-each).  In the last clause, only the
  ;; well-formed `(name rhs ...)' case visibly raises "malformed non-terminal";
  ;; confirm the intended clauses.
  (define-syntax (extend-language stx)
    (syntax-case stx ()
      [(_ lang (name rhs ...) ...)
       (andmap identifier? (syntax->list (syntax/loc stx (name ...))))
       (with-syntax ([((r-rhs ...) ...) (map (lambda (rhss) (map (λ (x) (rewrite-side-conditions 'extend-language x)) (syntax->list rhss)))
                                             (syntax->list (syntax ((rhs ...) ...))))])
         (syntax/loc stx
           (do-extend-language lang (list (make-nt 'name (list (make-rhs `r-rhs) ...)) ...))))]
      [(_ lang (name rhs ...) ...)
        (lambda (name)
          (unless (identifier? name)
            (raise-syntax-error 'extend-language "expected name" stx name)))
        (syntax->list (syntax (name ...))))]
      [(_ lang x ...)
        (lambda (x)
          (syntax-case x ()
            [(name rhs ...)
             (raise-syntax-error 'extend-language "malformed non-terminal" stx x)]))
        (syntax->list (syntax (x ...))))]))
  ;; Productions equal to a member of this list mark a non-terminal in
  ;; extend-language as extending, rather than replacing, the old one.
  (define extend-nt-ellipses (list '....))
  ;; do-extend-language : compiled-lang (listof nt) -> compiled-lang
  ;; Compiles a new language that starts from the non-terminals of `old-lang'.
  ;;   - A non-terminal in `new-nts' whose productions include the `....' marker
  ;;     extends the old non-terminal of the same name: the old productions are
  ;;     kept and the new ones (minus the marker) are appended.  Extending a
  ;;     non-terminal the old language lacks is an error.
  ;;   - Any other new non-terminal replaces, or is added alongside, the old ones.
  ;; (The list arguments of both for-each calls, the `cond'/`else' and the
  ;; make-nt call were missing; restored.)
  (define (do-extend-language old-lang new-nts)
    (let ([old-nts (compiled-lang-lang old-lang)]
          [old-ht (make-hash-table)]
          [new-ht (make-hash-table)])
      (for-each (λ (nt)
                  (hash-table-put! old-ht (nt-name nt) nt)
                  (hash-table-put! new-ht (nt-name nt) nt))
                old-nts)
      (for-each (λ (nt)
                  (cond
                    [(ormap (λ (rhs) (member (rhs-pattern rhs) extend-nt-ellipses))
                            (nt-rhs nt))
                     (unless (hash-table-get old-ht (nt-name nt) #f)
                       (error 'extend-language "the language extends the ~s non-terminal, but that non-terminal is not in the old language"
                              (nt-name nt)))
                     (hash-table-put! new-ht
                                      (nt-name nt)
                                      (make-nt
                                       (nt-name nt)
                                       (append (nt-rhs (hash-table-get old-ht (nt-name nt)))
                                               (filter (λ (rhs) (not (member (rhs-pattern rhs) extend-nt-ellipses)))
                                                       (nt-rhs nt)))))]
                    [else
                     (hash-table-put! new-ht (nt-name nt) nt)]))
                new-nts)
      (compile-language (hash-table-map new-ht (λ (x y) y)))))
  ;; reduce : (listof red) exp -> (listof exp)
  ;; Every result of applying any red in `reductions' to `exp', without tags.
  (define (reduce reductions exp)
    (define (untagged red)
      (λ (mtch)
        (let ([bindings (mtch-bindings mtch)])
          ((red-reduct red) bindings))))
    (reduce/internal reductions exp untagged))
  ;; reduce/tag-with-reduction : (listof red) exp -> (listof (list red exp))
  ;; Like reduce, but pairs each result with the red that produced it.
  (define (reduce/tag-with-reduction reductions exp)
    (define (tagged red)
      (λ (mtch)
        (let ([result ((red-reduct red) (mtch-bindings mtch))])
          (list red result))))
    (reduce/internal reductions exp tagged))
  ;; reduce/internal : (listof red) exp (red -> match -> X) -> (listof X)
  ;; Matches each red's contractum against `exp' and maps (f red) over its matches.
  ;; Results are consed in front of those of earlier reds, so the output lists the
  ;; results of later reds first.
  ;; (The `cond' head was missing; restored.)
  (define (reduce/internal reductions exp f)
    (let loop ([reductions reductions]
               [acc null])
      (cond
        [(null? reductions) acc]
        [else (let* ([red (car reductions)]
                     [mtchs (match-pattern (red-contractum red) exp)])
                (if mtchs
                    (loop (cdr reductions)
                          (map/mt (f red) mtchs acc))
                    (loop (cdr reductions) acc)))])))
  ;; apply-reduction-relation* : reduction-relation any -> (listof any)
  ;; Reduces `exp' repeatedly and returns every irreducible term reachable from it,
  ;; each once (compared with equal?).  It is an error for one step to produce the
  ;; same term twice.
  ;; Fixes:
  ;;   - the missing `exp' / string-append arguments and the `cond' head are
  ;;     restored;
  ;;   - answers are keyed with equal? (they were eq?-keyed, so equal terms
  ;;     reached along different paths could be reported twice);
  ;;   - terms already expanded are skipped, so cycles terminate.
  (define (apply-reduction-relation* reductions exp)
    (let ([answers (make-hash-table 'equal)]
          ;; terms already expanded: shared successors are reduced once, and
          ;; cycles in the reduction graph no longer recurse forever
          [visited (make-hash-table 'equal)])
      (let loop ([exp exp])
        (unless (hash-table-get visited exp #f)
          (hash-table-put! visited exp #t)
          (let* ([nexts (apply-reduction-relation reductions exp)]
                 [uniq (mk-uniq nexts)])
            (unless (= (length uniq) (length nexts))
              (error 'reduce-all "term ~s produced non unique results:~a"
                     exp
                     (apply string-append
                            (map (λ (x) (format "\n~s" x)) nexts))))
            (cond
              [(null? uniq) (hash-table-put! answers exp #t)]
              [else (for-each loop uniq)]))))
      (hash-table-map answers (λ (x y) x))))
  ;; mk-uniq : (listof X) -> (listof X)
  ;; The distinct elements of `terms' under equal?, in hash-table order (no
  ;; particular order is promised).
  (define (mk-uniq terms)
    (let ([seen (make-hash-table 'equal)])
      (let record ([remaining terms])
        (unless (null? remaining)
          (hash-table-put! seen (car remaining) #t)
          (record (cdr remaining))))
      (hash-table-map seen (λ (key val) key))))
  ;; map/mt : (a -> b) (listof a) (listof b) -> (listof b)
  ;; Like map, except the results are consed onto `mt-l' instead of onto the empty
  ;; list: (map/mt f l mt-l) = (append (map f l) mt-l).
  ;; (The `cond' head was missing; restored.)
  (define (map/mt f l mt-l)
    (let loop ([l l])
      (cond
        [(null? l) mt-l]
        [else (cons (f (car l)) (loop (cdr l)))])))
  ;; re:gen-d : matches a string that ends in a run of digits preceded by a
  ;; non-digit; the digit run is the first capture group.
  (define re:gen-d (regexp ".*[^0-9]([0-9]+)$"))
  ;; variable-not-in : sexp symbol -> symbol
  ;; Builds a symbol from `var' and a numeric suffix.  The suffix is one more than
  ;; the largest numeric suffix (per re:gen-d) on any symbol in `sexp' that has
  ;; `var' as a prefix, or 1 when there is none.
  ;; (The `cond' head and the else-branch of the `if' were missing; restored.)
  (define (variable-not-in sexp var)
    (let* ([var-str (symbol->string var)]
           [nums (let loop ([sexp sexp]
                            [nums null])
                   (cond
                     [(pair? sexp) (loop (cdr sexp) (loop (car sexp) nums))]
                     [(symbol? sexp) (let* ([str (symbol->string sexp)]
                                            [match (regexp-match re:gen-d str)])
                                       (if (and match
                                                (is-prefix? var-str str))
                                           (cons (string->number (cadr match)) nums)
                                           nums))]
                     [else nums]))])
      (if (null? nums)
          (string->symbol (format "~a1" var))
          (string->symbol (format "~a~a" var (+ 1 (apply max nums)))))))
  ;; is-prefix? : string string -> boolean
  ;; #t exactly when `str1' is a prefix of `str2'.
  (define (is-prefix? str1 str2)
    (let ([len (string-length str1)])
      (and (<= len (string-length str2))
           (equal? str1 (substring str2 0 len)))))
  ;; Exports.  The reduction-relation macro is exported under its public name,
  ;; together with the --> and fresh keywords it recognizes.
  ;; NOTE(review): this first provide form is never closed in the visible source.
  (provide (rename -reduction-relation reduction-relation) 
           --> fresh ;; keywords for reduction-relation

  ;; NOTE(review): the following form looks damaged; confirm and restore:
  ;;   - it mixes plain exports with [name contract] entries, which suggests a
  ;;     missing provide/contract head;
  ;;   - the entry holding the tagged-results contract has lost its name
  ;;     (presumably apply-reduction-relation/tag-with-names, whose results
  ;;     are (list name term));
  ;;   - the union-reduction-relations contract has no result part.
  (provide test-match
           make-bindings bindings-table bindings?
           mtch? mtch-bindings mtch-context  mtch-hole
           make-rib rib? rib-name rib-exp
   [set-cache-size! (-> number? void?)]
   [apply-reduction-relation (-> reduction-relation? any/c (listof any/c))]
    (-> reduction-relation? any/c (listof (list/c (union false/c string?) any/c)))]
   [apply-reduction-relation* (-> reduction-relation? any/c (listof any/c))]
   [union-reduction-relations (->* (reduction-relation? reduction-relation?)
                                   (listof reduction-relation?)
   [lookup-binding (case-> 
                    (-> bindings? symbol? any)
                    (-> bindings? symbol? (-> any) any))]
   (variable-not-in (any/c symbol? . -> . symbol?))))