matrix.ss
#lang scheme

#|  matrix.ss: Matrix operations and data structures.
    Copyright (C) 2008 Will M. Farr <farr@mit.edu>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
|#

;; A rows-by-cols matrix whose elements are stored in row-major order in the
;; vector elts.  Iterating a matrix directly, e.g. (for ((x m)) ...), yields
;; the first rows*cols entries of elts in that row-major order.
(define-struct matrix
  (rows cols elts)
  #:transparent
  #:property prop:sequence
  (lambda (mat)
    (let ((elts (matrix-elts mat))
          (count (* (matrix-rows mat) (matrix-cols mat))))
      (make-do-sequence
       (lambda ()
         (values (lambda (pos) (vector-ref elts pos))   ; position -> element
                 (lambda (pos) (+ pos 1))              ; next position
                 0                                     ; start at offset 0
                 (lambda (pos) (< pos count))          ; stop after rows*cols
                 (lambda (v) #t)
                 (lambda (pos v) #t)))))))

;; Flat contract accepting exactly the matrices with r rows and c columns.
(define (matrix-of-dimensions/c r c)
  (define (right-shape? obj)
    (and (matrix? obj)
         (= r (matrix-rows obj))
         (= c (matrix-cols obj))))
  (flat-named-contract (format "<~a by ~a matrix>" r c) right-shape?))

;; Flat contract accepting matrices with exactly r rows (any column count).
(define (matrix-of-rows/c r)
  (define (has-r-rows? obj)
    (and (matrix? obj)
         (= r (matrix-rows obj))))
  (flat-named-contract (format "<~a by <any> matrix>" r) has-r-rows?))

;; Flat contract accepting matrices with exactly c columns (any row count).
(define (matrix-of-cols/c c)
  (define (has-c-cols? obj)
    (and (matrix? obj)
         (= c (matrix-cols obj))))
  (flat-named-contract (format "<<any> by ~a matrix>" c) has-c-cols?))

;; Flat contract accepting vectors of exactly l elements.
(define (vector-of-length/c l)
  (define (right-length? obj)
    (and (vector? obj)
         (= l (vector-length obj))))
  (flat-named-contract (format "<vector of length ~a>" l) right-length?))

;; Contract for a matrix with the same shape as m (rows and columns).
(define (matrix-same-dimensions/c m)
  (let ((r (matrix-rows m))
        (c (matrix-cols m)))
    (matrix-of-dimensions/c r c)))

;; Contract for a right-hand factor of m: it must have as many rows as m has
;; columns.
(define (matrix-mul-compatible/c m)
  (let ((inner-dim (matrix-cols m)))
    (matrix-of-rows/c inner-dim)))

;; Contract for the product m1 * m2: m1's row count by m2's column count.
(define (matrix-mul-result/c m1 m2)
  (let ((r (matrix-rows m1))
        (c (matrix-cols m2)))
    (matrix-of-dimensions/c r c)))

;; Flat contract accepting proper lists of exactly l elements.
(define (list-of-length/c l)
  (define (right-length? obj)
    (and (list? obj)
         (= l (length obj))))
  (flat-named-contract (format "<list-of-length ~a>" l) right-length?))

;; Exports without contracts: the in-matrix sequence form, the
;; vector/matrix comprehension macros, and the shape-checking contract
;; constructors used by the provide/contract specification.
(provide in-matrix for/vector for*/vector for/matrix for*/matrix 
         matrix-of-dimensions/c matrix-of-rows/c matrix-of-cols/c
         vector-of-length/c 
         matrix-same-dimensions/c
         matrix-mul-compatible/c
         matrix-mul-result/c
         list-of-length/c)

;; Contracted exports.  The dependent contracts (->d) relate argument and
;; result shapes, e.g. matrix-ref bounds i and j by m's dimensions and
;; matrix-mul requires m2 to have as many rows as m1 has columns.
;;
;; NOTE(review): the struct contract checks each field independently, so it
;; does not enforce (vector-length elts) = rows * cols, which the row-major
;; indexing in matrix-ref / in-matrix assumes.  Also, the dependent contract
;; expressions call matrix-rows / matrix-cols / vector-length on the raw
;; arguments -- confirm how ->d orders these against the non-dependent
;; checks, since a bad argument may surface as an accessor error rather than
;; a contract violation.
(provide/contract
 (struct matrix
   ((rows natural-number/c)
    (cols natural-number/c)
    (elts (vectorof number?))))
 ;; Row-major constructor: exactly rows*cols numbers follow the dimensions.
 (matrix* (->d ((rows natural-number/c)
                (cols natural-number/c))
               ()
               #:rest nums (and/c (listof number?) (list-of-length/c (* rows cols)))
               (_ matrix?)))
 ;; Element access and update, with 0-based indices bounded by the dimensions.
 (matrix-ref (->d ((m matrix?)
                   (i (and/c natural-number/c 
                             (</c (matrix-rows m))))
                   (j (and/c natural-number/c 
                             (</c (matrix-cols m)))))
                  ()
                  (_ number?)))
 (matrix-set! (->d ((m matrix?)
                    (i (and/c natural-number/c 
                              (</c (matrix-rows m))))
                    (j (and/c natural-number/c 
                              (</c (matrix-cols m))))
                    (x number?))
                   ()
                   any))
 ;; Elementwise operations: operands and result share one shape.
 (matrix-add (->d ((m1 matrix?)
                   (m2 (matrix-same-dimensions/c m1)))
                  ()
                  (_ (matrix-same-dimensions/c m1))))
 (matrix-sub (->d ((m1 matrix?)
                   (m2 (matrix-same-dimensions/c m1)))
                  ()
                  (_ (matrix-same-dimensions/c m1))))
 (matrix-scale (->d ((m matrix?)
                     (s number?))
                    ()
                    (_ (matrix-same-dimensions/c m))))
 ;; Products: inner dimensions must agree.
 (matrix-mul (->d ((m1 matrix?)
                   (m2 (matrix-mul-compatible/c m1)))
                  ()
                  (_ (matrix-mul-result/c m1 m2))))
 (matrix-vector-mul (->d ((m matrix?)
                          (v (and/c (vectorof number?)
                                    (vector-of-length/c (matrix-cols m)))))
                         ()
                         (_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-rows m))))))
 (vector-matrix-mul (->d ((v (vectorof number?))
                          (m (matrix-of-rows/c (vector-length v))))
                         ()
                         (_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-cols m))))))
 ;; Vector operations: results are only checked for length here.
 (vector-add (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ (vector-of-length/c (vector-length v1)))))
 (vector-sub (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ (vector-of-length/c (vector-length v1)))))
 (vector-scale (->d ((v (vectorof number?))
                     (s number?))
                    ()
                    (_ (vector-of-length/c (vector-length v)))))
 (vector-dot (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ number?)))
 ;; Shape-changing constructors.
 (matrix-transpose (->d ((m matrix?))
                        ()
                        (_ (matrix-of-dimensions/c (matrix-cols m) (matrix-rows m)))))
 (matrix-identity (->d ((n natural-number/c))
                       ()
                       (_ (matrix-of-dimensions/c n n)))))

;; Build an m-by-n matrix from its elements given in row-major order, e.g.
;; (matrix* 2 2 a b c d) has rows (a b) and (c d).
(define (matrix* m n . nums)
  (make-matrix m n (apply vector nums)))

;; Element at row i, column j.  Storage is row-major, so (i, j) lives at
;; offset cols*i + j of the element vector.
(define (matrix-ref m i j)
  (let* ((elts (matrix-elts m))
         (offset (+ (* (matrix-cols m) i) j)))
    (vector-ref elts offset)))

;; Destructively store x at row i, column j (row-major offset cols*i + j).
(define (matrix-set! m i j x)
  (let* ((elts (matrix-elts m))
         (offset (+ (* (matrix-cols m) i) j)))
    (vector-set! elts offset x)))

;; Procedural sequence over the first rows*cols entries of m's element
;; vector, in row-major order.  `in-matrix` evaluates to this procedure when
;; it is used as an ordinary expression rather than directly in a for clause.
(define (*in-matrix m)
  (define (element-at pos)
    (vector-ref (matrix-elts m) pos))
  (define (in-bounds? pos)
    (< pos (* (matrix-rows m) (matrix-cols m))))
  (make-do-sequence
   (lambda ()
     (values element-at
             add1
             0
             in-bounds?
             (lambda (elt) #t)
             (lambda (pos elt) #t)))))

;; in-matrix: sequence over a matrix's elements in row-major order.
;;
;; Used as an ordinary expression it evaluates to *in-matrix, which yields
;; one value (the element) per step.  Directly inside a for clause two
;; forms are expanded inline:
;;   (x (in-matrix m))         -- iterates the element vector directly;
;;   ((i j x) (in-matrix m))   -- additionally binds row index i and column
;;                                index j of each element x.
(define-sequence-syntax in-matrix
  (lambda () (syntax *in-matrix))
  (lambda (stx)
    (syntax-case stx ()
      (((id) (_ matrix-expr))
       (syntax/loc stx
         ((id) (in-vector (matrix-elts matrix-expr)))))
      (((i-id j-id elt-id) (_ matrix-expr))
       (syntax/loc stx
         ((i-id j-id elt-id)
          (:do-in
           ;; Outer bindings: evaluate the matrix expression once.
           (((m) matrix-expr))
           ;; Outer check.
           #t
           ;; Loop ids: current row/column plus the matrix dimensions.
           ((i-id 0) (j-id 0) (rows (matrix-rows m)) (cols (matrix-cols m)))
           ;; Pos-guard: stop once past the last row.  The column test is
           ;; needed for a matrix with rows > 0 but zero columns; without it
           ;; the inner binding below would read element (0, 0) of an empty
           ;; vector.  For cols > 0 the loop args keep j-id < cols anyway.
           (and (< i-id rows) (< j-id cols))
           ;; Inner ids: successor indices and the current element.
           (((ip1) (add1 i-id)) ((jp1) (add1 j-id)) ((elt-id) (matrix-ref m i-id j-id)))
           ;; Pre-guard.
           #t
           ;; Post-guard.
           #t
           ;; Loop args: advance the column, wrapping to the next row.
           ((if (>= jp1 cols) ip1 i-id) (if (>= jp1 cols) 0 jp1) rows cols))))))))

;; (for/vector length-expr (for-clause ...) body ...+)
;;
;; Like `for`, but stores the value of the body for each iteration in
;; successive slots of a fresh vector created with (make-vector length-expr),
;; and returns that vector.  Slots that are never filled keep make-vector's
;; default fill; producing more than length-expr values is an error from
;; vector-set!.
;;
;; The next free slot is threaded through the fold as its accumulator, so it
;; advances exactly once per evaluated body.  This stays correct when the
;; clauses contain #:when filters or nested clause groups; a parallel
;; (in-naturals) counter would skip slots or reuse them in those cases.
;; The original form is passed to for/fold/derived so that syntax errors in
;; the clauses are reported against the user's for/vector form.
(define-syntax (for/vector stx)
  (syntax-case stx ()
    ((_ length-expr (for-clause ...) body0 body ...)
     (with-syntax ((original stx))
       (syntax/loc stx
         (let ((vec (make-vector length-expr)))
           (for/fold/derived original ((next-slot 0))
             (for-clause ...)
             (let ()
               (vector-set! vec next-slot (let () body0 body ...))
               (add1 next-slot)))
           vec))))))

;; (for*/vector length-expr (for-clause ...) body ...+)
;;
;; Like `for*` (all clauses nested), but stores the value of the body for
;; each iteration in successive slots of a fresh vector created with
;; (make-vector length-expr), and returns that vector.  Unfilled slots keep
;; make-vector's default fill; producing more than length-expr values is an
;; error from vector-set!.
;;
;; The next free slot is threaded through the fold as its accumulator.  The
;; original form is passed to for*/fold/derived so that syntax errors in the
;; clauses are reported against the user's for*/vector form.
(define-syntax (for*/vector stx)
  (syntax-case stx ()
    ((_ length-expr (for-clause ...) body0 body ...)
     (with-syntax ((original stx))
       (syntax/loc stx
         (let ((vec (make-vector length-expr)))
           (for*/fold/derived original ((next-slot 0))
             (for-clause ...)
             (let ()
               (vector-set! vec next-slot (let () body0 body ...))
               (add1 next-slot)))
           vec))))))

;; (for/matrix rows-expr cols-expr (for-clause ...) body)
;; Builds a rows-by-cols matrix whose elements, in row-major order, are the
;; successive values of body, collected with for/vector (parallel clauses).
(define-syntax (for/matrix stx)
  (syntax-case stx ()
    ((_ rows-expr cols-expr (for-clause ...) body)
     (syntax/loc stx
       (let* ((nrows rows-expr)
              (ncols cols-expr)
              (elts (for/vector (* nrows ncols) (for-clause ...) body)))
         (make-matrix nrows ncols elts))))))

;; (for*/matrix rows-expr cols-expr (for-clause ...) body)
;; Builds a rows-by-cols matrix whose elements, in row-major order, are the
;; successive values of body, collected with for*/vector (nested clauses).
(define-syntax (for*/matrix stx)
  (syntax-case stx ()
    ((_ rows-expr cols-expr (for-clause ...) body)
     (syntax/loc stx
       (let* ((nrows rows-expr)
              (ncols cols-expr)
              (elts (for*/vector (* nrows ncols) (for-clause ...) body)))
         (make-matrix nrows ncols elts))))))

;; Elementwise sum of two matrices of the same shape.
(define (matrix-add m1 m2)
  (let ((rows (matrix-rows m1))
        (cols (matrix-cols m1)))
    (for/matrix rows cols
      ((a (in-matrix m1))
       (b (in-matrix m2)))
      (+ a b))))

;; Elementwise difference m1 - m2 of two matrices of the same shape.
(define (matrix-sub m1 m2)
  (let ((rows (matrix-rows m1))
        (cols (matrix-cols m1)))
    (for/matrix rows cols
      ((a (in-matrix m1))
       (b (in-matrix m2)))
      (- a b))))

;; Multiply every element of m by the scalar s.
(define (matrix-scale m s)
  (let ((rows (matrix-rows m))
        (cols (matrix-cols m)))
    (for/matrix rows cols
      ((elt (in-matrix m)))
      (* s elt))))

;; Matrix product m1 * m2: entry (i, j) is the sum over k of
;; m1(i, k) * m2(k, j), accumulated left to right starting from exact 0.
(define (matrix-mul m1 m2)
  (let ((out-rows (matrix-rows m1))
        (out-cols (matrix-cols m2)))
    (define (entry i j)
      (for/fold ((acc 0))
                ((k (in-range (matrix-cols m1))))
        (+ acc (* (matrix-ref m1 i k) (matrix-ref m2 k j)))))
    (for*/matrix out-rows out-cols
      ((i (in-range out-rows))
       (j (in-range out-cols)))
      (entry i j))))

;; Product m * v with v treated as a column vector: component i is the dot
;; product of row i of m with v.
(define (matrix-vector-mul m v)
  (let ((nrows (matrix-rows m)))
    (define (row-dot row)
      (for/fold ((acc 0))
                ((col (in-range (vector-length v))))
        (+ acc (* (matrix-ref m row col) (vector-ref v col)))))
    (for/vector nrows
      ((row (in-range nrows)))
      (row-dot row))))

;; Product v * m with v treated as a row vector: component j is the dot
;; product of v with column j of m.
(define (vector-matrix-mul v m)
  (let ((ncols (matrix-cols m)))
    (define (col-dot col)
      (for/fold ((acc 0))
                ((row (in-range (vector-length v))))
        (+ acc (* (matrix-ref m row col) (vector-ref v row)))))
    (for/vector ncols
      ((col (in-range ncols)))
      (col-dot col))))

;; Elementwise sum of two vectors; the result has v1's length.
(define (vector-add v1 v2)
  (let ((n (vector-length v1)))
    (for/vector n
      ((a (in-vector v1))
       (b (in-vector v2)))
      (+ a b))))

;; Elementwise difference v1 - v2; the result has v1's length.
(define (vector-sub v1 v2)
  (let ((n (vector-length v1)))
    (for/vector n
      ((a (in-vector v1))
       (b (in-vector v2)))
      (- a b))))

;; Multiply every component of v by the scalar s.
(define (vector-scale v s)
  (let ((n (vector-length v)))
    (for/vector n
      ((component (in-vector v)))
      (* component s))))

;; Inner product: the sum over k of v1[k] * v2[k], accumulated left to right
;; from exact 0.  Like a parallel `for` over both vectors, it stops at the end
;; of the shorter one.
(define (vector-dot v1 v2)
  (let ((n (min (vector-length v1) (vector-length v2))))
    (let loop ((k 0) (acc 0))
      (if (< k n)
          (loop (add1 k) (+ acc (* (vector-ref v1 k) (vector-ref v2 k))))
          acc))))

;; Transpose: the result is cols-by-rows and its entry (i, j) is entry (j, i)
;; of m.
(define (matrix-transpose m)
  (let ((nrows (matrix-rows m))
        (ncols (matrix-cols m)))
    (for*/matrix ncols nrows
      ((out-row (in-range ncols))
       (out-col (in-range nrows)))
      (matrix-ref m out-col out-row))))

;; The n-by-n identity matrix: exact 1 on the diagonal, exact 0 elsewhere.
(define (matrix-identity n)
  (for*/matrix n n
    ((row (in-range n))
     (col (in-range n)))
    (cond ((= row col) 1)
          (else 0))))