matrix-lang.ss
#lang scheme

#|  matrix-lang.ss: Language module for matrix and vector operations.
    Copyright (C) 2008 Will M. Farr <farr@mit.edu>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
|#

(require "matrix.ss")

;; Module interface: re-export the whole `scheme' language except the four
;; arithmetic operators, export this module's generic my-+, my--, my-* and
;; my-/ under the names +, -, * and /, and re-export everything from
;; matrix.ss.
(provide (except-out (all-from-out scheme) + - * /)
         (rename-out (my-+ +) (my-- -) (my-* *) (my-/ /))
         (all-from-out "matrix.ss"))

;; (define-number-vector-matrix (f x y)
;;   nn-expr nv-expr nm-expr vn-expr vv-expr vm-expr mn-expr mv-expr mm-expr)
;;
;; Defines a two-argument procedure f that classifies each argument with
;; number?, vector? or matrix? and evaluates exactly one of the nine body
;; expressions.  In each expression name, the first letter is x's kind and
;; the second is y's (n = number, v = vector, m = matrix).  For example,
;; vn-expr runs when x is a vector and y is a number.  The expressions may
;; refer to x and y.
;;
;; An argument of any other kind raises an error whose source symbol is f
;; (the defined procedure's name).  The messages use ~e so the offending
;; value is printed.  `error' forwards its extra arguments to `format', and
;; `format' rejects arguments that the format string has no directive for.
;;
;; NOTE(review): arguments are tested with vector? before matrix?.  If
;; matrix.ss represents matrices as vectors, a matrix would be dispatched
;; as a vector.  Confirm against matrix.ss.
(define-syntax define-number-vector-matrix
  (syntax-rules ()
    ((define-number-vector-matrix (f x y)
       nn-expr nv-expr nm-expr vn-expr vv-expr vm-expr mn-expr mv-expr mm-expr)
     (define (f x y)
       (cond
         ((number? x)
          (cond
            ((number? y) nn-expr)
            ((vector? y) nv-expr)
            ((matrix? y) nm-expr)
            (else (error 'f "second argument not a number, vector, or matrix: ~e" y))))
         ((vector? x)
          (cond
            ((vector? y) vv-expr)
            ((number? y) vn-expr)
            ((matrix? y) vm-expr)
            (else (error 'f "second argument not a number, vector, or matrix: ~e" y))))
         ((matrix? x)
          (cond
            ((matrix? y) mm-expr)
            ((number? y) mn-expr)
            ((vector? y) mv-expr)
            (else (error 'f "second argument not a number, vector, or matrix: ~e" y))))
         (else (error 'f "first argument not a number, vector, or matrix: ~e" x)))))))

;; Binary generic addition used by my-+.
;; number + number uses +, vector + vector uses vector-add, and
;; matrix + matrix uses matrix-add.  Every mixed combination is an error.
;; The error messages print both operands with ~e: `error' hands its extra
;; arguments to `format', so a message without directives would itself
;; fail with a format-arity error.
(define-number-vector-matrix (binary-add x y)
  (+ x y)
  (error '+ "cannot add number and vector: ~e and ~e" x y)
  (error '+ "cannot add number and matrix: ~e and ~e" x y)
  (error '+ "cannot add vector and number: ~e and ~e" x y)
  (vector-add x y)
  (error '+ "cannot add vector and matrix: ~e and ~e" x y)
  (error '+ "cannot add matrix and number: ~e and ~e" x y)
  (error '+ "cannot add matrix and vector: ~e and ~e" x y)
  (matrix-add x y))

;; Generic n-ary addition (exported as `+').
;; With no arguments it returns 0.  With one argument it returns that
;; argument unchanged.  Otherwise it folds the arguments left to right
;; through binary-add: ((a + b) + c) + ...
(define (my-+ . terms)
  (if (null? terms)
      0
      (let loop ((sum (car terms))
                 (pending (cdr terms)))
        (if (null? pending)
            sum
            (loop (binary-add sum (car pending)) (cdr pending))))))

;; Generic subtraction (exported as `-'); it needs at least one argument.
;; One argument is negated with unary-minus.  Two arguments go straight to
;; binary-minus.  With more, everything after the first is summed with
;; my-+ and that sum is subtracted from the first argument.
(define (my-- x . ys)
  (cond
    ((null? ys) (unary-minus x))
    ((null? (cdr ys)) (binary-minus x (car ys)))
    (else (binary-minus x (apply my-+ ys)))))

;; Negation used by one-argument my--.
;; A number is negated with -, and a vector or matrix is passed to
;; vector-scale / matrix-scale with a factor of -1.  Any other argument is
;; an error.  The message prints the argument with ~e, because `error'
;; forwards extra arguments to `format' and would otherwise fail with a
;; format-arity error.
(define (unary-minus x)
  (cond
    ((number? x) (- x))
    ((vector? x) (vector-scale x -1))
    ((matrix? x) (matrix-scale x -1))
    (else (error '- "argument must be number, vector, or matrix: ~e" x))))

;; Binary generic subtraction used by my--.
;; number - number uses -, vector - vector uses vector-sub, and
;; matrix - matrix uses matrix-sub.  Every mixed combination is an error.
;; The error messages print both operands with ~e: `error' hands its extra
;; arguments to `format', so a message without directives would itself
;; fail with a format-arity error.
(define-number-vector-matrix (binary-minus x y)
  (- x y)
  (error '- "cannot sub number and vector: ~e and ~e" x y)
  (error '- "cannot sub number and matrix: ~e and ~e" x y)
  (error '- "cannot sub vector and number: ~e and ~e" x y)
  (vector-sub x y)
  (error '- "cannot sub vector and matrix: ~e and ~e" x y)
  (error '- "cannot sub matrix and number: ~e and ~e" x y)
  (error '- "cannot sub matrix and vector: ~e and ~e" x y)
  (matrix-sub x y))

;; Generic n-ary multiplication (exported as `*').
;; With no arguments it returns 1.  With one argument it returns that
;; argument unchanged.  Otherwise the arguments are combined pairwise with
;; *2 strictly left to right, ((a * b) * c) * ..., because matrix products
;; do not commute.
(define (my-* . factors)
  (cond
    ((null? factors) 1)
    (else
     (foldl (lambda (next product) (*2 product next))
            (car factors)
            (cdr factors)))))

;; Binary generic multiplication used by my-*.
;; The nine expressions below are positional, in the order the
;; define-number-vector-matrix macro expects (x kind, y kind): nn nv nm vn
;; vv vm mn mv mm.  Their order therefore determines which operation runs.
;; Arguments of any other kind get the macro's default errors.
(define-number-vector-matrix (*2 x y)
  (* x y)                   ; number * number
  (vector-scale y x)        ; number * vector: vector passed first, number second
  (matrix-scale y x)        ; number * matrix: matrix passed first, number second
  (vector-scale x y)        ; vector * number
  (vector-dot x y)          ; vector * vector
  (vector-matrix-mul x y)   ; vector * matrix
  (matrix-scale x y)        ; matrix * number
  (matrix-vector-mul x y)   ; matrix * vector
  (matrix-mul x y))         ; matrix * matrix

;; Generic division (exported as `/'); it needs at least one argument.
;; One argument is handed to /1.  Two arguments go straight to /2.  With
;; more, the divisors are multiplied together with my-* and the first
;; argument is divided by that product.
(define (my-/ x . divisors)
  (if (null? divisors)
      (/1 x)
      (/2 x (if (null? (cdr divisors))
                (car divisors)
                (apply my-* divisors)))))

;; Reciprocal used by one-argument my-/; only numbers are accepted.
;; The error message prints the argument with ~e, because `error' forwards
;; extra arguments to `format' and would otherwise fail with a
;; format-arity error.
(define (/1 x)
  (cond
    ((number? x) (/ x))
    (else (error '/ "single argument must be a number: ~e" x))))

;; Binary generic division used by my-/.
;; number / number uses /.  A vector or matrix divided by a number is
;; scaled by (/ y), so a zero divisor raises /'s own division-by-zero
;; error.  Dividing a number by anything but a number, or dividing by a
;; vector or matrix, is an error.  The messages print both operands with
;; ~e: `error' forwards extra arguments to `format', so a message without
;; directives would itself fail with a format-arity error.
(define-number-vector-matrix (/2 x y)
  (/ x y)
  (error '/ "cannot divide number by vector: ~e and ~e" x y)
  (error '/ "cannot divide number by matrix: ~e and ~e" x y)
  (vector-scale x (/ y))
  (error '/ "cannot divide vector by vector: ~e and ~e" x y)
  (error '/ "cannot divide vector by matrix: ~e and ~e" x y)
  (matrix-scale x (/ y))
  (error '/ "cannot divide matrix by vector: ~e and ~e" x y)
  (error '/ "cannot divide matrix by matrix: ~e and ~e" x y))