matrix.ss
#lang scheme

#|  matrix.ss: Matrix operations and datastructures.
    Copyright (C) 2008 Will M. Farr <farr@mit.edu>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
|#

(require "matrix-base.ss")

(provide (all-from-out "matrix-base.ss"))

;; Exported operations, each wrapped in a contract that is checked when the
;; operation is called from outside this module.  The dimension-checking
;; combinators used below (matrix-same-dimensions/c, matrix-mul-compatible/c,
;; matrix-mul-result/c, matrix-of-rows/c, matrix-of-dimensions/c,
;; vector-of-length/c) and matrix? are not defined in this file; presumably
;; they come from matrix-base.ss -- confirm their exact meaning there.
(provide/contract
 
 ;; Element-wise sum.  M2 and the result must both satisfy
 ;; (matrix-same-dimensions/c m1).
 (matrix-add (->d ((m1 matrix?)
                   (m2 (matrix-same-dimensions/c m1)))
                  ()
                  (_ (matrix-same-dimensions/c m1))))
 ;; Element-wise difference M1 - M2; same dimension checks as matrix-add.
 (matrix-sub (->d ((m1 matrix?)
                   (m2 (matrix-same-dimensions/c m1)))
                  ()
                  (_ (matrix-same-dimensions/c m1))))
 ;; Multiplies every entry of M by the number S; the result must satisfy
 ;; (matrix-same-dimensions/c m).
 (matrix-scale (->d ((m matrix?)
                     (s number?))
                    ()
                    (_ (matrix-same-dimensions/c m))))
 ;; Matrix product M1 * M2.  M2 must satisfy (matrix-mul-compatible/c m1)
 ;; (presumably: (matrix-cols m1) = (matrix-rows m2) -- confirm), and the
 ;; result is checked against (matrix-mul-result/c m1 m2).
 (matrix-mul (->d ((m1 matrix?)
                   (m2 (matrix-mul-compatible/c m1)))
                  ()
                  (_ (matrix-mul-result/c m1 m2))))
 ;; M times the column vector V.  V must be a vector of numbers whose length
 ;; is (matrix-cols m); the result is a vector of numbers whose length is
 ;; (matrix-rows m).
 (matrix-vector-mul (->d ((m matrix?)
                          (v (and/c (vectorof number?)
                                    (vector-of-length/c (matrix-cols m)))))
                         ()
                         (_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-rows m))))))
 ;; The row vector V times M.  M must satisfy
 ;; (matrix-of-rows/c (vector-length v)); the result is a vector of numbers
 ;; whose length is (matrix-cols m).
 (vector-matrix-mul (->d ((v (vectorof number?))
                          (m (matrix-of-rows/c (vector-length v))))
                         ()
                         (_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-cols m))))))
 ;; Component-wise sum.  V2 must be a vector of numbers as long as V1; the
 ;; result is only checked for its length, not its element type.
 (vector-add (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ (vector-of-length/c (vector-length v1)))))
 ;; Component-wise difference V1 - V2; same checks as vector-add.
 (vector-sub (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ (vector-of-length/c (vector-length v1)))))
 ;; Multiplies every component of V by the number S; the result must have
 ;; V's length.
 (vector-scale (->d ((v (vectorof number?))
                     (s number?))
                    ()
                    (_ (vector-of-length/c (vector-length v)))))
 ;; Dot product of two equal-length numeric vectors; returns a number.
 (vector-dot (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ number?)))
 ;; N x N identity matrix for a natural number N.
 (matrix-identity (->d ((n natural-number/c))
                       ()
                       (_ (matrix-of-dimensions/c n n))))
 ;; Sum of the diagonal.  Any matrix is accepted here; non-square matrices
 ;; are not rejected by this contract.
 (matrix-trace (-> matrix? number?)))

;; matrix-add : matrix matrix -> matrix
;; Element-wise sum of M1 and M2.  The result is built with for/matrix using
;; M1's row and column counts; the provide/contract clause requires M2 to
;; satisfy (matrix-same-dimensions/c m1).
(define (matrix-add m1 m2)
  (define rows (matrix-rows m1))
  (define cols (matrix-cols m1))
  (for/matrix rows cols
    ((a (in-matrix m1))
     (b (in-matrix m2)))
    (+ a b)))

;; matrix-sub : matrix matrix -> matrix
;; Element-wise difference M1 - M2, built with for/matrix using M1's row and
;; column counts.
(define (matrix-sub m1 m2)
  (let ((nrows (matrix-rows m1))
        (ncols (matrix-cols m1)))
    (for/matrix nrows ncols
      ((lhs (in-matrix m1))
       (rhs (in-matrix m2)))
      (- lhs rhs))))

;; matrix-scale : matrix number -> matrix
;; Returns a matrix of M's dimensions in which every entry is (* s entry).
(define (matrix-scale m s)
  (define (scale entry)
    (* s entry))
  (for/matrix (matrix-rows m) (matrix-cols m)
    ((entry (in-matrix m)))
    (scale entry)))

;; matrix-mul : matrix matrix -> matrix
;; Matrix product.  The result has (matrix-rows m1) rows and (matrix-cols m2)
;; columns.  Entry (row, col) is the sum, for k from 0 below (matrix-cols m1),
;; of (matrix-ref m1 row k) * (matrix-ref m2 k col).  The sum is accumulated
;; left to right, starting from exact 0.
(define (matrix-mul m1 m2)
  (define inner (matrix-cols m1))
  ;; Dot product of row ROW of M1 with column COL of M2.
  (define (dot-row-col row col)
    (for/fold ((acc 0))
              ((k (in-range inner)))
      (+ acc (* (matrix-ref m1 row k) (matrix-ref m2 k col)))))
  (let ((rows (matrix-rows m1))
        (cols (matrix-cols m2)))
    (for*/matrix rows cols
      ((row (in-range rows))
       (col (in-range cols)))
      (dot-row-col row col))))

;; matrix-vector-mul : matrix vector -> vector
;; M times the column vector V.  Component i of the result, for i from 0
;; below (matrix-rows m), is the sum over j of
;; (matrix-ref m i j) * (vector-ref v j).
(define (matrix-vector-mul m v)
  (define len (vector-length v))
  (define nrows (matrix-rows m))
  ;; Dot product of row I of M with V.
  (define (row-dot i)
    (for/fold ((total 0))
              ((j (in-range len)))
      (+ total (* (matrix-ref m i j) (vector-ref v j)))))
  (for/vector nrows
    ((i (in-range nrows)))
    (row-dot i)))

;; vector-matrix-mul : vector matrix -> vector
;; The row vector V times M.  Component j of the result, for j from 0 below
;; (matrix-cols m), is the sum over i of
;; (matrix-ref m i j) * (vector-ref v i).
(define (vector-matrix-mul v m)
  (let* ((ncols (matrix-cols m))
         (len (vector-length v))
         ;; Dot product of V with column J of M.
         (col-dot
          (lambda (j)
            (for/fold ((total 0))
                      ((i (in-range len)))
              (+ total (* (matrix-ref m i j) (vector-ref v i)))))))
    (for/vector ncols
      ((j (in-range ncols)))
      (col-dot j))))

;; vector-add : vector vector -> vector
;; Component-wise sum of V1 and V2.  The result is built with for/vector
;; using V1's length.
(define (vector-add v1 v2)
  (define n (vector-length v1))
  (for/vector n
    ((a (in-vector v1))
     (b (in-vector v2)))
    (+ a b)))

;; vector-sub : vector vector -> vector
;; Component-wise difference V1 - V2.  The result is built with for/vector
;; using V1's length.
(define (vector-sub v1 v2)
  (let ((len (vector-length v1)))
    (for/vector len
      ((minuend (in-vector v1))
       (subtrahend (in-vector v2)))
      (- minuend subtrahend))))

;; vector-scale : vector number -> vector
;; Returns a vector of V's length whose components are (* component s).
(define (vector-scale v s)
  (define len (vector-length v))
  (define (scale-component c)
    (* c s))
  (for/vector len
    ((c (in-vector v)))
    (scale-component c)))

;; vector-dot : vector vector -> number
;; Inner product of V1 and V2: the sum of (* (vector-ref v1 i) (vector-ref v2 i)),
;; accumulated left to right starting from exact 0.  Only the first
;; (min (vector-length v1) (vector-length v2)) components participate, so an
;; empty vector yields 0.
(define (vector-dot v1 v2)
  (define n (min (vector-length v1) (vector-length v2)))
  (let loop ((i 0) (total 0))
    (if (= i n)
        total
        (loop (add1 i)
              (+ total (* (vector-ref v1 i) (vector-ref v2 i)))))))

;; matrix-identity : natural -> matrix
;; The N x N identity matrix: exact 1 where the row index equals the column
;; index, exact 0 everywhere else.
(define (matrix-identity n)
  (define (kronecker-delta row col)
    (if (= row col) 1 0))
  (for*/matrix n n
    ((row (in-range n))
     (col (in-range n)))
    (kronecker-delta row col)))

;; matrix-trace : matrix -> number
;; Sum of the diagonal entries (matrix-ref m i i), accumulated from exact 0.
;; For a non-square matrix the sum covers only the first
;; (min rows cols) diagonal positions.
(define (matrix-trace m)
  (define diag-len (min (matrix-rows m) (matrix-cols m)))
  (let loop ((i 0) (acc 0))
    (if (< i diag-len)
        (loop (+ i 1) (+ acc (matrix-ref m i i)))
        acc)))