;; deriv-lang.ss
#|  deriv-lang.ss: MzScheme language for automatic differentiation.
    Copyright (C) 2007 Will M. Farr <farr@mit.edu>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
|#

(module deriv-lang mzscheme
  (require (lib "plt-match.ss"))
  
  ;; Exports the differentiation operators D and partial, together with
  ;; everything from mzscheme except the numeric operations listed below.
  ;; Each excluded name is re-exported (via rename) as the corresponding
  ;; my-* procedure defined in this module, so client code that writes
  ;; +, sin, <, ... gets the versions that also accept deriv structures.
  (provide D partial
           (all-from-except mzscheme
                            + - / *
                            = < > <= >=
                            zero? positive? negative? odd? even?
                            abs exp log
                            sin cos tan
                            asin acos atan
                            sqrt expt)
           (rename my-+ +) (rename my-- -) (rename my-* *) (rename my-/ /)
           (rename my-= =) (rename my-< <) (rename my-> >)
           (rename my-<= <=) (rename my->= >=)
           (rename my-zero? zero?) (rename my-positive? positive?)
           (rename my-negative? negative?) (rename my-odd? odd?)
           (rename my-even? even?)
           (rename my-abs abs) (rename my-exp exp) (rename my-log log)
           (rename my-sin sin) (rename my-cos cos) (rename my-tan tan)
           (rename my-asin asin) (rename my-acos acos) (rename my-atan atan)
           (rename my-sqrt sqrt) (rename my-expt expt))
  
  ;; A deriv pairs a value with a derivative part:
  ;;   tag - number identifying which differentiation this perturbation
  ;;         belongs to (tags are compared with = and <)
  ;;   x   - the value part
  ;;   dx  - the derivative part carried along with x
  ;; The trailing #f is the inspector argument of define-struct.
  (define-struct deriv (tag x dx) #f)
  
  ;; Tags are handed out through a channel so that generating tag ids is
  ;; thread-safe: a single producer thread owns the counter.
  (define tag-channel (make-channel))
  
  ;; Producer thread: offers 0, 1, 2, ... on tag-channel, one value per
  ;; channel-put, forever.
  (thread
   (lambda ()
     (do ((n 0 (+ n 1)))
         (#f)
       (channel-put tag-channel n))))
  
  ;; next-tag : -> number
  ;; Takes the next value offered on tag-channel and returns it.
  (define next-tag
    (lambda ()
      (channel-get tag-channel)))
  
  
  ;; extract-derivative : number any -> any
  ;; Returns the derivative part of x that belongs to the perturbation
  ;; labelled by tag.
  ;;   - A plain number carries no perturbation at all, so its derivative
  ;;     is 0.  This is what lets D and partial differentiate procedures
  ;;     whose result does not depend on the differentiated argument
  ;;     (e.g. constant procedures) instead of failing.
  ;;   - A deriv with the same tag yields its dx field.
  ;;   - A deriv with a smaller tag is rebuilt under that tag, with the
  ;;     derivative extracted recursively from both its x and dx fields.
  ;; Raises a mismatch error when x is neither a number nor a deriv, and
  ;; when x is a deriv whose tag is larger than the requested tag.
  (define (extract-derivative tag x)
    (cond
      ((deriv? x)
       (let ((x-tag (deriv-tag x)))
         (cond
           ((= x-tag tag)
            (deriv-dx x))
           ((< x-tag tag)
            (make-deriv x-tag
                        (extract-derivative tag (deriv-x x))
                        (extract-derivative tag (deriv-dx x))))
           (else
            (raise-mismatch-error 'extract-derivative "tag not found!" tag)))))
      ;; No perturbation present: the value is constant with respect to tag.
      ((number? x) 0)
      (else
       (raise-mismatch-error 'extract-derivative "not a deriv struct: " x))))
  
  ;; D : (any -> any) -> (any -> any)
  ;; Returns a procedure that, given a point x, obtains a fresh tag, calls f
  ;; on x wrapped in a deriv with derivative part 1, and extracts from the
  ;; result the derivative belonging to that tag.
  (define (D f)
    (lambda (x)
      (let ((tag (next-tag)))
        (extract-derivative tag (f (make-deriv tag x 1))))))
  
  ;; tag-ith : list number number -> list
  ;; Returns a copy of lst in which the element at zero-based position i is
  ;; replaced by a deriv carrying tag, that element as value, and derivative
  ;; part 1.  Raises a mismatch error if the list ends before position i.
  (define (tag-ith lst tag i)
    (if (null? lst)
        (raise-mismatch-error 'tag-ith "list too short" lst)
        (if (= i 0)
            (cons (make-deriv tag (car lst) 1) (cdr lst))
            (cons (car lst)
                  (tag-ith (cdr lst) tag (- i 1))))))
  
  ;; partial : number -> (procedure -> procedure)
  ;; ((partial i) f) returns a procedure taking the same arguments as f.
  ;; It obtains a fresh tag, wraps the i-th argument (zero-based) in a
  ;; deriv with derivative part 1, applies f, and extracts the derivative
  ;; belonging to that tag from the result.
  (define (partial i)
    (lambda (f)
      (lambda args
        (let ((tag (next-tag)))
          (extract-derivative tag (apply f (tag-ith args tag i)))))))
  
  ;; (lift1 ndf f df) expands to a one-argument procedure.
  ;;   ndf - applied directly when the argument is not a deriv
  ;;   f   - applied to the value part of a deriv argument
  ;;   df  - applied to the value part; its result is multiplied (with mul)
  ;;         by the argument's derivative part
  ;; For a deriv argument the result is a deriv carrying the same tag.
  (define-syntax lift1
    (syntax-rules ()
      ((lift1 ndf f df)
       (lambda (arg)
         (if (deriv? arg)
             (let ((x (deriv-x arg)))
               (make-deriv (deriv-tag arg)
                           (f x)
                           (mul (df x) (deriv-dx arg))))
             (ndf arg))))))
  
  ;; (lift2 ndf f dfdx dfdy) expands to a two-argument procedure.
  ;;   ndf  - applied when neither argument is a deriv
  ;;   f    - applied when at least one argument is a deriv, to build the
  ;;          value part of the result
  ;;   dfdx - called with the same arguments as f; its result is multiplied
  ;;          (with mul) by the first argument's derivative part
  ;;   dfdy - likewise for the second argument's derivative part
  ;; When both arguments are derivs the tags decide the shape of the result:
  ;; equal tags sum both derivative terms; otherwise the result carries the
  ;; smaller tag and the argument with the larger tag is passed on whole.
  (define-syntax lift2
    (syntax-rules ()
      ((lift2 ndf f dfdx dfdy)
       (lambda (xx yy)
         (match xx
           ((struct deriv (tagx x dx))
            (match yy
              ((struct deriv (tagy y dy))
               (cond
                 ;; Same perturbation on both sides: sum both contributions.
                 ((= tagx tagy)
                  (make-deriv tagx (f x y) (add (mul (dfdx x y) dx)
                                                (mul (dfdy x y) dy))))
                 ;; xx has the smaller tag: unwrap xx only, keep yy intact.
                 ((< tagx tagy)
                  (make-deriv tagx (f x yy) (mul (dfdx x yy) dx)))
                 ;; yy has the smaller tag: unwrap yy only, keep xx intact.
                 (else
                  (make-deriv tagy (f xx y) (mul (dfdy xx y) dy)))))
              ;; Only the first argument is a deriv.
              (y
               (make-deriv tagx (f x y) (mul (dfdx x y) dx)))))
           (x
            (match yy
              ;; Only the second argument is a deriv.
              ((struct deriv (tagy y dy))
               (make-deriv tagy (f x y) (mul (dfdy x y) dy)))
              ;; Neither argument is a deriv: use the plain operation.
              (y
               (ndf x y)))))))))
  
  ;; my-+ : any ... -> any
  ;; N-ary sum built from the binary add.  With no arguments the result is
  ;; 0, with one argument that argument is returned unchanged, and with more
  ;; the arguments are combined pairwise from the left.
  (define (my-+ . terms)
    (if (null? terms)
        0
        (let loop ((sum (car terms))
                   (rest (cdr terms)))
          (if (null? rest)
              sum
              (loop (add sum (car rest)) (cdr rest))))))
  
  ;; my-- : any any ... -> any
  ;; N-ary subtraction built from negate and the binary sub.
  ;;   one argument   - its negation
  ;;   two arguments  - their difference
  ;;   more arguments - the difference taken pairwise from the left,
  ;;                    ((x - y) - z) - ..., matching the left-associative
  ;;                    behaviour of the standard - that this procedure
  ;;                    replaces.  (Subtracting the sum of the rest instead
  ;;                    gives different floating-point rounding.)
  (define my--
    (case-lambda
      ((x) (negate x))
      ((x y) (sub x y))
      ((x y . zs)
       (apply my-- (sub x y) zs))))
  
  ;; my-* : any ... -> any
  ;; N-ary product built from the binary mul.  With no arguments the result
  ;; is 1, with one argument that argument is returned unchanged, and with
  ;; more the arguments are combined pairwise from the left.
  (define (my-* . factors)
    (if (null? factors)
        1
        (let loop ((product (car factors))
                   (rest (cdr factors)))
          (if (null? rest)
              product
              (loop (mul product (car rest)) (cdr rest))))))
  
  ;; my-/ : any any ... -> any
  ;; N-ary division built from invert and the binary div.
  ;;   one argument   - its reciprocal
  ;;   two arguments  - their quotient
  ;;   more arguments - the quotient taken pairwise from the left,
  ;;                    ((x / y) / z) / ..., matching the left-associative
  ;;                    behaviour of the standard / that this procedure
  ;;                    replaces.  (Dividing by the product of the rest
  ;;                    instead can overflow or round differently.)
  (define my-/
    (case-lambda
      ((x) (invert x))
      ((x y) (div x y))
      ((x y . zs)
       (apply my-/ (div x y) zs))))
  
  ;; add : binary addition lifted with lift2.
  ;; Both partial derivatives are the constant 1.
  (define add
    (let ((const-one (lambda (a b) 1)))
      (lift2 + add const-one const-one)))
  
  ;; negate : unary minus lifted with lift1.
  ;; The derivative is the constant -1.
  (define negate
    (let ((const-minus-one (lambda (v) -1)))
      (lift1 - negate const-minus-one)))
  
  ;; sub : binary subtraction lifted with lift2.
  ;; Partial derivatives: 1 for the first argument, -1 for the second.
  (define sub
    (let ((d/first (lambda (a b) 1))
          (d/second (lambda (a b) -1)))
      (lift2 - sub d/first d/second)))
  
  ;; mul : binary multiplication lifted with lift2.
  ;; Partial derivatives: the other factor in each case.
  (define mul
    (let ((d/first (lambda (a b) b))
          (d/second (lambda (a b) a)))
      (lift2 * mul d/first d/second)))
  
  ;; invert : reciprocal lifted with lift1.
  ;; The derivative at v is computed as the reciprocal of -(v * v).
  (define invert
    (let ((d/arg (lambda (v)
                   (invert (negate (mul v v))))))
      (lift1 / invert d/arg)))
  
  ;; div : binary division lifted with lift2.
  ;; Partial derivatives of a / b: 1/b for the numerator and
  ;; -(a / (b * b)) for the denominator.
  (define div
    (let ((d/numer (lambda (a b) (invert b)))
          (d/denom (lambda (a b) (negate (div a (mul b b))))))
      (lift2 / div d/numer d/denom)))
  
  ;; (lift-comp ndcomp comp) expands to a two-argument comparison.
  ;;   ndcomp - applied when neither argument is a deriv
  ;;   comp   - applied again after replacing a deriv argument by its value
  ;;            part (the first argument is unwrapped before the second)
  ;; Only value parts take part in the comparison; derivative parts are
  ;; ignored.
  (define-syntax lift-comp
    (syntax-rules ()
      ((lift-comp ndcomp comp)
       (lambda (xx yy)
         (cond
           ((deriv? xx) (comp (deriv-x xx) yy))
           ((deriv? yy) (comp xx (deriv-x yy)))
           (else (ndcomp xx yy)))))))
  
  ;; =2 : binary = lifted with lift-comp.
  (define =2 (lift-comp = =2))
  ;; my-= : any any any ... -> boolean
  ;; Chained equality: true when every adjacent pair of arguments
  ;; satisfies =2.  Stops at the first pair that fails.
  (define (my-= x y . zs)
    (and (=2 x y)
         (or (null? zs)
             (apply my-= y zs))))
  
  ;; <2 : binary < lifted with lift-comp.
  (define <2 (lift-comp < <2))
  ;; my-< : any any any ... -> boolean
  ;; Chained less-than: true when every adjacent pair of arguments
  ;; satisfies <2.  Stops at the first pair that fails.
  (define (my-< x y . zs)
    (and (<2 x y)
         (or (null? zs)
             (apply my-< y zs))))
  
  ;; >2 : binary > lifted with lift-comp.
  (define >2 (lift-comp > >2))
  ;; my-> : any any any ... -> boolean
  ;; Chained greater-than: true when every adjacent pair of arguments
  ;; satisfies >2.  Stops at the first pair that fails.
  (define (my-> x y . zs)
    (and (>2 x y)
         (or (null? zs)
             (apply my-> y zs))))
  
  ;; <=2 : binary <= lifted with lift-comp.
  (define <=2 (lift-comp <= <=2))
  ;; my-<= : any any any ... -> boolean
  ;; Chained less-or-equal: true when every adjacent pair of arguments
  ;; satisfies <=2.  Stops at the first pair that fails.
  (define (my-<= x y . zs)
    (and (<=2 x y)
         (or (null? zs)
             (apply my-<= y zs))))
  
  ;; >=2 : binary >= lifted with lift-comp.
  (define >=2 (lift-comp >= >=2))
  ;; my->= : any any any ... -> boolean
  ;; Chained greater-or-equal: true when every adjacent pair of arguments
  ;; satisfies >=2.  Stops at the first pair that fails.
  (define (my->= x y . zs)
    (and (>=2 x y)
         (or (null? zs)
             (apply my->= y zs))))
  
  ;; (lift-pred ndp p) expands to a one-argument predicate.
  ;;   ndp - applied when the argument is not a deriv
  ;;   p   - applied to the value part when the argument is a deriv
  (define-syntax lift-pred
    (syntax-rules ()
      ((lift-pred ndp p)
       (lambda (v)
         (if (deriv? v)
             (p (deriv-x v))
             (ndp v))))))
  
  ;; Numeric predicates lifted with lift-pred: each one tests the value part
  ;; of a deriv argument and otherwise defers to the mzscheme predicate.
  (define my-zero?
    (lift-pred zero? my-zero?))
  (define my-positive?
    (lift-pred positive? my-positive?))
  (define my-negative?
    (lift-pred negative? my-negative?))
  (define my-odd?
    (lift-pred odd? my-odd?))
  (define my-even?
    (lift-pred even? my-even?))
  
  ;; my-abs : absolute value lifted with lift1.
  ;; The derivative is -1 for negative values and 1 for positive ones; at
  ;; zero a mismatch error is raised because no derivative exists there.
  (define my-abs
    (let ((sign-of
           (lambda (v)
             (cond
               ((my-zero? v)
                (raise-mismatch-error 'abs "no derivative at zero" v))
               ((my-negative? v) -1)
               (else 1)))))
      (lift1 abs my-abs sign-of)))
  
  ;; my-exp : exponential; it is its own derivative.
  (define my-exp
    (lift1 exp my-exp my-exp))
  ;; my-log : logarithm; the derivative is the reciprocal (invert).
  (define my-log
    (lift1 log my-log invert))
  ;; my-sin : sine; the derivative is my-cos.
  (define my-sin
    (lift1 sin my-sin my-cos))
  ;; my-cos : cosine lifted with lift1; the derivative is the negated sine.
  (define my-cos
    (let ((minus-sin (lambda (v) (negate (my-sin v)))))
      (lift1 cos my-cos minus-sin)))
  ;; my-tan : tangent lifted with lift1; the derivative is the reciprocal of
  ;; the squared cosine.
  (define my-tan
    (let ((inverse-cos-squared
           (lambda (v)
             (let ((c (my-cos v)))
               (invert (mul c c))))))
      (lift1 tan my-tan inverse-cos-squared)))
  ;; my-asin : arcsine lifted with lift1; the derivative at v is
  ;; 1 / sqrt(1 - v*v).
  (define my-asin
    (let ((d/arg (lambda (v)
                   (invert (my-sqrt (sub 1 (mul v v)))))))
      (lift1 asin my-asin d/arg)))
  ;; my-acos : arccosine lifted with lift1; the derivative at v is
  ;; -(1 / sqrt(1 - v*v)).
  (define my-acos
    (let ((d/arg (lambda (v)
                   (negate (invert (my-sqrt (sub 1 (mul v v))))))))
      (lift1 acos my-acos d/arg)))
  
  ;; atan1 : one-argument arctangent lifted with lift1; the derivative at v
  ;; is 1 / (1 + v*v).
  (define atan1
    (let ((d/arg (lambda (v)
                   (invert (add 1 (mul v v))))))
      (lift1 atan atan1 d/arg)))
  ;; atan2 : two-argument arctangent (y first, then x) lifted with lift2.
  ;; Partial derivatives: x / (x*x + y*y) with respect to y and
  ;; -(y / (x*x + y*y)) with respect to x.
  (define atan2
    (let* ((norm-squared (lambda (y x) (add (mul x x) (mul y y))))
           (d/dy (lambda (y x) (div x (norm-squared y x))))
           (d/dx (lambda (y x) (negate (div y (norm-squared y x))))))
      (lift2 atan atan2 d/dy d/dx)))
  ;; my-atan : dispatches on argument count, like the mzscheme atan it
  ;; replaces: two arguments (y x) go to atan2, one argument goes to atan1.
  (define my-atan
    (case-lambda
      ((y x) (atan2 y x))
      ((v) (atan1 v))))
  
  ;; my-sqrt : square root lifted with lift1; the derivative at v is
  ;; 1/2 divided by sqrt(v).
  (define my-sqrt
    (let ((d/arg (lambda (v)
                   (div 1/2 (my-sqrt v)))))
      (lift1 sqrt my-sqrt d/arg)))
  
  ;; my-expt : exponentiation lifted with lift2.
  ;; Partial derivatives of base^power: power * base^(power - 1) with
  ;; respect to the base and log(base) * base^power with respect to the
  ;; power.
  ;; The last parenthesis on the final line closes the enclosing module form.
  (define my-expt
    (let ((d/base (lambda (base power)
                    (mul power (my-expt base (sub power 1)))))
          (d/power (lambda (base power)
                     (mul (my-log base) (my-expt base power)))))
      (lift2 expt my-expt d/base d/power))))