(module test-infix mzscheme
(require "infix.ss"
(lib "etc.ss")
(lib "list.ss")
(lib "plt-match.ss")
(lib "lex.ss" "parser-tools"))
(define (assert-sequence expected stx)
  ;; Tokenize STX with syntax->token-list and check that the token names,
  ;; in order, are exactly the symbols listed in EXPECTED.
  ;; Returns #t when they agree.  Otherwise raises a syntax error named
  ;; 'assert-token-type-sequence (blamed on STX) describing the first
  ;; discrepancy: extra tokens, missing tokens, or a name mismatch.
  (define (fail fmt . vals)
    (raise-syntax-error 'assert-token-type-sequence (apply format fmt vals) stx))
  (define (walk want got)
    (cond
      ;; Expectations exhausted: succeed only if the tokens are exhausted too.
      [(empty? want)
       (if (empty? got)
           #t
           (fail "more than expected: ~a" got))]
      [(empty? got)
       (fail "less than expected: ~a" want)]
      ;; Compare the lexer's token name against the expected type symbol.
      [(symbol=? (token-name (first got)) (first want))
       (walk (rest want) (rest got))]
      [else
       (fail "mismatch ~a ~a" (first want) (first got))]))
  (walk expected (syntax->token-list stx)))
(define (assert-structure expected stx)
  ;; Parse STX with parse-expression and check the resulting tree against
  ;; the shape description EXPECTED, which is built from:
  ;;   (atom)               -- an atom-node
  ;;   (app OP RAND ...)    -- an app-node whose operator matches OP and
  ;;                           whose operands match the RANDs pairwise
  ;;   (cmp OP LEFT RIGHT)  -- a cmp-node whose operator, left and right
  ;;                           subtrees match OP, LEFT and RIGHT
  ;; Returns #t when the whole tree matches.  Otherwise raises a syntax
  ;; error named 'assert-structure (blamed on STX) that reports the first
  ;; mismatching subtree, or a malformed EXPECTED description.
  (define (error name msg . args)
    (raise-syntax-error name (apply format msg args) stx))
  (let loop ([expected expected]
             [experimental (parse-expression (token-list->producer (syntax->token-list stx)))])
    (match expected
      [(list 'app op-1 rands-1 ...)
       (match experimental
         [(struct app-node (op-2 rands-2))
          ;; The arity check comes first so andmap never sees lists of
          ;; different lengths.  Recursive calls either return #t or raise.
          (if (and (= (length rands-1) (length rands-2))
                   (loop op-1 op-2)
                   (andmap loop rands-1 rands-2))
              #t
              (error 'assert-structure "mismatch 1 ~a ~a" expected experimental))]
         [_
          (error 'assert-structure "mismatch 2 ~a ~a" expected experimental)])]
      [(list 'cmp op l r)
       (match experimental
         [(struct cmp-node (op-expr l-expr r-expr))
          (if (and (loop op op-expr)
                   (loop l l-expr)
                   (loop r r-expr))
              #t
              (error 'assert-structure "mismatch 3 ~a ~a" expected experimental))]
         [_
          (error 'assert-structure "mismatch 4 ~a ~a" expected experimental)])]
      [(list 'atom)
       (match experimental
         [(struct atom-node (_)) #t]
         [_
          (error 'assert-structure "mismatch 5 ~a ~a" expected experimental)])]
      ;; Anything else is a malformed test specification.  Report it against
      ;; STX instead of letting match fail with no context.
      [_
       (error 'assert-structure "malformed expected structure ~a" expected)])))
(define (test-suite)
  ;; Runs the infix test cases.  Each case is a (description syntax) pair.
  ;; All token-sequence cases are checked with assert-sequence first, then
  ;; all parse-tree cases with assert-structure, each in the order listed.
  ;; Any failure raises a syntax error.  Otherwise the result of the last
  ;; structure check is returned.
  (define token-cases
    (list
     (list '(lparen atom plus atom rparen)
           (syntax (1 + 2)))
     (list '(lparen atom times atom rparen)
           (syntax (1 * 2)))
     (list '(lparen atom comma lparen rparen rparen)
           (syntax (foo, ())))
     (list '(lparen plus minus times divide rparen)
           (syntax (+ - * /)))
     (list '(lparen atom lparen atom lparen atom lparen atom rparen rparen rparen rparen)
           (syntax (f(x(y(z))))))
     (list '(lparen atom lparen atom comma atom rparen rparen)
           (syntax (hello(world, testing))))
     (list '(lparen atom cmp atom cmp atom cmp atom cmp atom cmp atom rparen)
           (syntax (a < b <= c > d >= e = f)))))
  (define tree-cases
    (list
     (list '(atom)
           (syntax (42)))
     (list '(app (atom))
           (syntax (f ())))
     (list '(app (atom)
                 (app (atom) (atom) (atom))
                 (atom))
           (syntax (3 * 4 + 5)))
     (list '(cmp (atom)
                 (cmp (atom) (atom) (atom))
                 (atom))
           (syntax (3 < 4 <= 5)))
     (list '(cmp (atom)
                 (app (atom) (atom))
                 (app (atom) (atom)))
           (syntax (f(x) < f(y))))
     (list '(cmp (atom)
                 (app (atom) (atom))
                 (app (atom) (atom)))
           (syntax (f(x) = f(y))))
     (list '(cmp (atom)
                 (cmp (atom) (app (atom) (atom)) (app (atom) (atom)))
                 (app (atom) (atom)))
           (syntax (f(a) < f(y) < f(z))))
     (list '(cmp (atom)
                 (cmp (atom) (app (atom) (atom)) (app (atom) (atom)))
                 (app (atom) (atom)))
           (syntax (f(1) <= f(2) < f(3))))))
  (for-each (lambda (c) (assert-sequence (car c) (cadr c))) token-cases)
  ;; Each check returns #t or raises, so andmap yields the last check's #t.
  (andmap (lambda (c) (assert-structure (car c) (cadr c))) tree-cases))
;; Run the suite when this module is instantiated.
(test-suite)
)