#lang scheme
(require "matrix-base.ss")
(provide (all-from-out "matrix-base.ss"))
;; Public interface.  Every export is wrapped in a contract: the dimension
;; requirements on the arguments are checked on entry and the shape of each
;; result is checked on exit, so the definitions below can assume that their
;; operands are dimensionally compatible.
(provide/contract
 ;; m1 + m2, element-wise; m2 and the result must have m1's dimensions.
 (matrix-add (->d ([m1 matrix?]
                   [m2 (matrix-same-dimensions/c m1)])
                  ()
                  [_ (matrix-same-dimensions/c m1)]))
 ;; m1 - m2, element-wise; same dimension rules as matrix-add.
 (matrix-sub (->d ([m1 matrix?]
                   [m2 (matrix-same-dimensions/c m1)])
                  ()
                  [_ (matrix-same-dimensions/c m1)]))
 ;; Multiply every element of m by the number s.
 (matrix-scale (->d ([m matrix?]
                     [s number?])
                    ()
                    [_ (matrix-same-dimensions/c m)]))
 ;; Matrix product m1 * m2; m2 must satisfy matrix-mul-compatible/c for m1.
 (matrix-mul (->d ([m1 matrix?]
                   [m2 (matrix-mul-compatible/c m1)])
                  ()
                  [_ (matrix-mul-result/c m1 m2)]))
 ;; m * v with v on the right: |v| = cols(m), |result| = rows(m).
 (matrix-vector-mul (->d ([m matrix?]
                          [v (and/c (vectorof number?)
                                    (vector-of-length/c (matrix-cols m)))])
                         ()
                         [_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-rows m)))]))
 ;; v * m with v on the left: rows(m) = |v|, |result| = cols(m).
 (vector-matrix-mul (->d ([v (vectorof number?)]
                          [m (matrix-of-rows/c (vector-length v))])
                         ()
                         [_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-cols m)))]))
 ;; The element-wise vector operations promise a numeric vector of the
 ;; operand's length -- the same result shape the *-mul operations promise.
 (vector-add (->d ([v1 (vectorof number?)]
                   [v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))])
                  ()
                  [_ (and/c (vectorof number?)
                            (vector-of-length/c (vector-length v1)))]))
 (vector-sub (->d ([v1 (vectorof number?)]
                   [v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))])
                  ()
                  [_ (and/c (vectorof number?)
                            (vector-of-length/c (vector-length v1)))]))
 (vector-scale (->d ([v (vectorof number?)]
                     [s number?])
                    ()
                    [_ (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v)))]))
 ;; Dot product of two equal-length numeric vectors.
 (vector-dot (->d ([v1 (vectorof number?)]
                   [v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))])
                  ()
                  [_ number?]))
 ;; The n x n identity matrix.
 (matrix-identity (->d ([n natural-number/c])
                       ()
                       [_ (matrix-of-dimensions/c n n)]))
 ;; Sum of the leading diagonal; any matrix is accepted, not only square ones.
 (matrix-trace (-> matrix? number?)))
;; matrix-add : matrix matrix -> matrix
;; Element-wise sum.  The result has m1's dimensions; each element is the sum
;; of the elements that in-matrix yields in lock-step from m1 and m2.
(define (matrix-add m1 m2)
  (define rows (matrix-rows m1))
  (define cols (matrix-cols m1))
  (for/matrix rows cols
    ([a (in-matrix m1)]
     [b (in-matrix m2)])
    (+ a b)))
;; matrix-sub : matrix matrix -> matrix
;; Element-wise difference m1 - m2, built with m1's dimensions.
(define (matrix-sub m1 m2)
  (let ([rows (matrix-rows m1)]
        [cols (matrix-cols m1)])
    (for/matrix rows cols
      ([a (in-matrix m1)]
       [b (in-matrix m2)])
      (- a b))))
;; matrix-scale : matrix number -> matrix
;; Returns a new matrix of m's dimensions with every element multiplied by s.
(define (matrix-scale m s)
  (let ([rows (matrix-rows m)]
        [cols (matrix-cols m)])
    (for/matrix rows cols
      ([elem (in-matrix m)])
      (* s elem))))
;; matrix-mul : matrix matrix -> matrix
;; Matrix product.  The result is rows(m1) x cols(m2); element (r, c) is the
;; sum over k < cols(m1) of m1[r,k] * m2[k,c].
(define (matrix-mul m1 m2)
  (define rows (matrix-rows m1))
  (define cols (matrix-cols m2))
  ;; Shared inner dimension; constant for every output cell.
  (define inner (matrix-cols m1))
  (for*/matrix rows cols
    ([r (in-range rows)]
     [c (in-range cols)])
    (for/fold ([acc 0])
              ([k (in-range inner)])
      (+ acc (* (matrix-ref m1 r k) (matrix-ref m2 k c))))))
;; matrix-vector-mul : matrix vector -> vector
;; Computes m * v: entry r of the result is the sum over c < |v| of
;; m[r,c] * v[c].  The result has one entry per row of m.
(define (matrix-vector-mul m v)
  (define nrows (matrix-rows m))
  (define width (vector-length v))
  (for/vector nrows
    ([row (in-range nrows)])
    (for/fold ([acc 0])
              ([col (in-range width)])
      (+ acc (* (matrix-ref m row col) (vector-ref v col))))))
;; vector-matrix-mul : vector matrix -> vector
;; Computes v * m: entry c of the result is the sum over r < |v| of
;; m[r,c] * v[r].  The result has one entry per column of m.
(define (vector-matrix-mul v m)
  (define ncols (matrix-cols m))
  (define height (vector-length v))
  (for/vector ncols
    ([col (in-range ncols)])
    (for/fold ([acc 0])
              ([row (in-range height)])
      (+ acc (* (matrix-ref m row col) (vector-ref v row))))))
;; vector-add : vector vector -> vector
;; Element-wise sum of v1 and v2; the result has v1's length.
(define (vector-add v1 v2)
  (let ([len (vector-length v1)])
    (for/vector len
      ([a (in-vector v1)]
       [b (in-vector v2)])
      (+ a b))))
;; vector-sub : vector vector -> vector
;; Element-wise difference v1 - v2; the result has v1's length.
(define (vector-sub v1 v2)
  (define len (vector-length v1))
  (for/vector len
    ([a (in-vector v1)]
     [b (in-vector v2)])
    (- a b)))
;; vector-scale : vector number -> vector
;; Returns a new vector of v's length with every entry multiplied by s.
(define (vector-scale v s)
  (define len (vector-length v))
  (for/vector len
    ([elem (in-vector v)])
    (* elem s)))
;; vector-dot : vector vector -> number
;; Dot product: the sum of v1[i] * v2[i], accumulated from index 0 upward.
;; Iteration stops at the shorter vector's length (the exported contract
;; already requires equal lengths); empty input yields exact 0.
(define (vector-dot v1 v2)
  (define len (min (vector-length v1) (vector-length v2)))
  (let loop ([i 0] [acc 0])
    (if (< i len)
        (loop (add1 i) (+ acc (* (vector-ref v1 i) (vector-ref v2 i))))
        acc)))
;; matrix-identity : natural -> matrix
;; The n x n matrix with exact 1 on the main diagonal and exact 0 elsewhere.
(define (matrix-identity n)
  (for*/matrix n n
    ([row (in-range n)]
     [col (in-range n)])
    (cond [(= row col) 1]
          [else 0])))
;; matrix-trace : matrix -> number
;; Sum of m[i,i] for i below min(rows, cols).  For a square matrix this is
;; the trace; for a non-square matrix it is the sum of the leading diagonal.
(define (matrix-trace m)
  (define diag-len (min (matrix-rows m) (matrix-cols m)))
  (let loop ([i 0] [acc 0])
    (if (= i diag-len)
        acc
        (loop (add1 i) (+ acc (matrix-ref m i i))))))