subst-test.rkt
#lang racket/base
(require redex/reduction-semantics 
         "subst.rkt"
         racket/set)

(define (an-x? x) (memq x '(a b c x y z z2 z2 q)))

(test-equal (fvs an-x? (term (+ x a b))) (set 'x 'a 'b))
(test-equal (fvs an-x? (term (lambda (x) (+ x y)))) (set 'y))

(define-language L)
(define-metafunction L
  [(subst (any_x any_b) ... any_body)
   ,(subst/proc an-x? (term (any_x ...)) (term (any_b ...)) (term any_body))])

(test-equal (term (subst (x y) x)) (term y))
(test-equal (term (subst (x y) z)) (term z))
(test-equal (term (subst (x y) (x (y z)))) (term (y (y z))))
(test-equal (term (subst (x y) ((lambda (x) x) ((lambda (y1) y1) (lambda (x) z)))))
            (term ((lambda (x) x) ((lambda (y1) y1) (lambda (x) z)))))
(test-equal (term (subst (x y) (if0 (+ 1 x) x x)))
            (term (if0 (+ 1 y) y y)))
(test-equal (term (subst (x (lambda (z) y)) (lambda (y) x)))
            (term (lambda (y1) (lambda (z) y))))
(test-equal (term (subst (x 1) (lambda (y) x)))
            (term (lambda (y) 1)))
(test-equal (term (subst (x y) (lambda (y) x)))
            (term (lambda (y1) y)))
(test-equal (term (subst (x (lambda (y) y)) (lambda (z) (z2 z))))
            (term (lambda (z) (z2 z))))
(test-equal (term (subst (x (lambda (z) z)) (lambda (z) (z1 z))))
            (term (lambda (z) (z1 z))))
(test-equal (term (subst (x z) (lambda (z) (z1 z))))
            (term (lambda (z2) (z1 z2))))
(test-equal (term (subst (x3 5) (lambda (x2) x2)))
            (term (lambda (x2) x2)))
(test-equal (term (subst (z *) (lambda (z x) 1)))
            (term (lambda (z x) 1)))
(test-equal (term (subst (q (lambda (x) z)) (lambda (z x) q)))
            (term (lambda (z1 x) (lambda (x) z))))
(test-equal (term (subst (x 1) (lambda (x x) x)))
            (term (lambda (x x) x)))
(test-equal (term (subst (x (y z)) (lambda (z) (z (x y)))))
            (term (lambda (z1) (z1 ((y z) y)))))
(test-results)