New patches: [unrevert anonymous**20071205171317] < > { hunk ./all.ss 1 +#lang scheme/base + #| all.ss: Export all procedures from the plt-linalg library. Copyright (C) 2007 Will M. Farr hunk ./all.ss 22 |# -(module all mzscheme - (require "matrix.ss" "vector.ss") +(require "matrix.ss" "vector.ss") hunk ./all.ss 24 - (provide (all-from "matrix.ss") - (all-from "vector.ss"))) +(provide (all-from-out "matrix.ss") + (all-from-out "vector.ss")) hunk ./blas-lapack.ss 1 +#lang scheme/base + #| blas-lapack.ss: Library locations for BLAS/LAPACK. Copyright (C) 2007 Will M. Farr hunk ./blas-lapack.ss 24 #| Special thanks to Noel Welsh, who contributed the library-searching code below. |# -(module blas-lapack mzscheme - (require (lib "foreign.ss") - (lib "list.ss" "srfi" "1")) - - (unsafe!) - - (define-unsafer blas-lapack-unsafe!) - - (provide *blas* *lapack* - _cblas-order _cblas-transpose _cblas-uplo - _cblas-diag _cblas-side) - - ;; search-paths : (listof string) - (define search-paths - (case (system-type) - [(macosx) - (list - "/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Versions/Current" - "/System/Library/Frameworks/vecLib.framework/Versions/Current")] - [(unix) - (list - "/usr/lib" - "/home/pg/nhw/data/lib" ;;;; NB: NHW specific - )])) - - (define default-path "") - - ;; lib-blas : (listof string) - ;; - ;; Possible names for the BLAS library - (define lib-blas - (case (system-type) - [(macosx) '("libBLAS")] - [(unix) '("libcblas" "libgslcblas")])) - - ;; lib-lapack : (listof string) - ;; - ;; Possible names for the LAPACK library - (define lib-lapack - (case (system-type) - [(macosx) '("libLAPACK")] - [(unix) '("liblapack")])) - - (define (string-empty? s) - (= (string-length s) 0)) - - ;; base-paths : (listof (U path string)) - (define base-paths - (filter (lambda (path) - (and (not (string-empty? path)) - (directory-exists? path))) - (append search-paths (list default-path)))) - - (define (build-path* . paths-or-empty) - (apply build-path (filter (lambda (p-or-e) (not (string-empty? - p-or-e))) paths-or-empty))) - - (define (find-libraries-that-exist names search-paths) - (let ([found - (remove not - (append-map - (lambda (name) - (map (lambda (search-path) - (with-handlers - ([exn? (lambda (exn) #f)]) - (ffi-lib (build-path* search-path name)))) - search-paths)) - names))]) - (if (null? found) - (error - (format "Could not find any of ~a under paths ~a~n" names - search-paths)) - found))) - - (define *blas* (car (find-libraries-that-exist lib-blas base-paths))) - (define *lapack* (car (find-libraries-that-exist lib-lapack base-paths))) - - (define _cblas-order (_enum '(row-major = 101 col-major = 102))) - (define _cblas-transpose (_enum '(no-trans = 111 trans = 112 conj-trans = 113 atlas-conj = 114))) - (define _cblas-uplo (_enum '(upper = 121 lower = 122))) - (define _cblas-diag (_enum '(non-unit = 131 unit = 132))) - (define _cblas-side (_enum '(left = 141 right = 142))) - - (define-for-syntax (append-to-syntax-object stx . objs) - (let ((strings (map (lambda (obj) - (cond - ((symbol? obj) (symbol->string obj)) - ((syntax? obj) (symbol->string (syntax-object->datum obj))) - (else obj))) - objs))) - (datum->syntax-object stx - (string->symbol - (apply string-append strings))))) - - (define-syntax (define-blas stx) - (syntax-case stx () - ((define-blas name args ...) - (with-syntax ((_TAGvector (datum->syntax-object stx '_TAGvector)) - (TAGvector-length (datum->syntax-object stx 'TAGvector-length)) - (_type (datum->syntax-object stx '_type)) - (sname (append-to-syntax-object stx 's (syntax name))) - (cblas_sname (append-to-syntax-object stx 'cblas_s (syntax name))) - (dname (append-to-syntax-object stx 'd (syntax name))) - (cblas_dname (append-to-syntax-object stx 'cblas_d (syntax name)))) - (syntax/loc stx - (begin - (provide* (unsafe sname)) - (define sname - (let ((_TAGvector _f32vector) - (TAGvector-length f32vector-length) - (_type _float)) - (get-ffi-obj 'cblas_sname *blas* - (_fun args ...) - (lambda () - (lambda x - (error 'blas (string-append "function " - (symbol->string 'cblas_sname) - " not found in blas library."))))))) - (provide* (unsafe dname)) - (define dname - (let ((_TAGvector _f64vector) - (TAGvector-length f64vector-length) - (_type _double*)) - (get-ffi-obj 'cblas_dname *blas* - (_fun args ...) - (lambda () - (lambda x - (error 'blas (string-append "function " - (symbol->string 'cblas_dname) - " not found in blas library."))))))))))))) - - (define-blas dot _int _TAGvector _int _TAGvector _int -> _type) - (define-blas nrm2 _int _TAGvector _int -> _type) - (define-blas asum _int _TAGvector _int -> _type) - (define-blas swap _int _TAGvector _int _TAGvector _int -> _void) - (define-blas copy _int _TAGvector _int _TAGvector _int -> _void) - (define-blas axpy _int _type _TAGvector _int _TAGvector _int -> _void) - (define-blas scal _int _type _TAGvector _int -> _void) - (define-blas gemv _cblas-order _cblas-transpose _int _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) - (define-blas gbmv _cblas-order _cblas-transpose _int _int _int _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) - (define-blas trmv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _int _TAGvector _int -> _void) - (define-blas tbmv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _int _TAGvector _int _TAGvector _int -> _void) - (define-blas tpmv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _TAGvector _int -> _void) - (define-blas trsv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _int _TAGvector _int -> _void) - (define-blas tbsv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _int _TAGvector _int _TAGvector _int -> _void) - (define-blas tpsv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _TAGvector _int -> _void) - (define-blas symv _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) - (define-blas sbmv _cblas-order _cblas-uplo _int _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) - (define-blas spmv _cblas-order _cblas-uplo _int _type _TAGvector _TAGvector _int _type _TAGvector _int -> _void) - (define-blas ger _cblas-order _int _int _type _TAGvector _int _TAGvector _int _TAGvector _int -> _void) - (define-blas syr _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int -> _void) - (define-blas spr _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector -> _void) - (define-blas syr2 _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int _TAGvector _int -> _void) - (define-blas spr2 _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int _TAGvector -> _void) - (define-blas gemm _cblas-order _cblas-transpose _cblas-transpose _int _int _type _TAGvector _int - _TAGvector _int _type _TAGvector _int -> _void) - (define-blas symm _cblas-order _cblas-side _cblas-uplo _int _int _type _TAGvector _int - _TAGvector _int _type _TAGvector _int -> _void) - (define-blas syrk _cblas-order _cblas-uplo _cblas-transpose _int _int _type - _TAGvector _int _type _TAGvector _int -> _void) - (define-blas syr2k _cblas-order _cblas-uplo _cblas-transpose _int _int _type _TAGvector - _int _TAGvector _int _type _TAGvector _int -> _void) - (define-blas trmm _cblas-order _cblas-side _cblas-uplo _cblas-transpose _cblas-diag _int _int - _type _TAGvector _int _TAGvector _int -> _void) - (define-blas trsm _cblas-order _cblas-side _cblas-uplo _cblas-transpose _cblas-diag _int _int - _type _TAGvector _int _TAGvector _int -> _void)) +(require (lib "foreign.ss") + (except-in srfi/1/list remove filter)) + +(unsafe!) + +(define-unsafer blas-lapack-unsafe!) + +(provide *blas* *lapack* + _cblas-order _cblas-transpose _cblas-uplo + _cblas-diag _cblas-side) + +;; search-paths : (listof string) +(define search-paths + (case (system-type) + [(macosx) + (list + "/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Versions/Current" + "/System/Library/Frameworks/vecLib.framework/Versions/Current")] + [(unix) + (list + "/usr/lib" + "/home/pg/nhw/data/lib" ;;;; NB: NHW specific + )])) + +(define default-path "") + +;; lib-blas : (listof string) +;; +;; Possible names for the BLAS library +(define lib-blas + (case (system-type) + [(macosx) '("libBLAS")] + [(unix) '("libcblas" "libgslcblas")])) + +;; lib-lapack : (listof string) +;; +;; Possible names for the LAPACK library +(define lib-lapack + (case (system-type) + [(macosx) '("libLAPACK")] + [(unix) '("liblapack")])) + +(define (string-empty? s) + (= (string-length s) 0)) + +;; base-paths : (listof (U path string)) +(define base-paths + (filter (lambda (path) + (and (not (string-empty? path)) + (directory-exists? path))) + (append search-paths (list default-path)))) + +(define (build-path* . paths-or-empty) + (apply build-path (filter (lambda (p-or-e) (not (string-empty? + p-or-e))) paths-or-empty))) + +(define (find-libraries-that-exist names search-paths) + (let ([found + (remove not + (append-map + (lambda (name) + (map (lambda (search-path) + (with-handlers + ([exn? (lambda (exn) #f)]) + (ffi-lib (build-path* search-path name)))) + search-paths)) + names))]) + (if (null? found) + (error + (format "Could not find any of ~a under paths ~a~n" names + search-paths)) + found))) + +(define *blas* (car (find-libraries-that-exist lib-blas base-paths))) +(define *lapack* (car (find-libraries-that-exist lib-lapack base-paths))) + +(define _cblas-order (_enum '(row-major = 101 col-major = 102))) +(define _cblas-transpose (_enum '(no-trans = 111 trans = 112 conj-trans = 113 atlas-conj = 114))) +(define _cblas-uplo (_enum '(upper = 121 lower = 122))) +(define _cblas-diag (_enum '(non-unit = 131 unit = 132))) +(define _cblas-side (_enum '(left = 141 right = 142))) + +(define-for-syntax (append-to-syntax-object stx . objs) + (let ((strings (map (lambda (obj) + (cond + ((symbol? obj) (symbol->string obj)) + ((syntax? obj) (symbol->string (syntax->datum obj))) + (else obj))) + objs))) + (datum->syntax stx + (string->symbol + (apply string-append strings))))) + +(define-syntax (define-blas stx) + (syntax-case stx () + ((define-blas name args ...) + (with-syntax ((_TAGvector (datum->syntax-object stx '_TAGvector)) + (TAGvector-length (datum->syntax-object stx 'TAGvector-length)) + (_type (datum->syntax-object stx '_type)) + (sname (append-to-syntax-object stx 's (syntax name))) + (cblas_sname (append-to-syntax-object stx 'cblas_s (syntax name))) + (dname (append-to-syntax-object stx 'd (syntax name))) + (cblas_dname (append-to-syntax-object stx 'cblas_d (syntax name)))) + (syntax/loc stx + (begin + (provide* (unsafe sname)) + (define sname + (let ((_TAGvector _f32vector) + (TAGvector-length f32vector-length) + (_type _float)) + (get-ffi-obj 'cblas_sname *blas* + (_fun args ...) + (lambda () + (lambda x + (error 'blas (string-append "function " + (symbol->string 'cblas_sname) + " not found in blas library."))))))) + (provide* (unsafe dname)) + (define dname + (let ((_TAGvector _f64vector) + (TAGvector-length f64vector-length) + (_type _double*)) + (get-ffi-obj 'cblas_dname *blas* + (_fun args ...) + (lambda () + (lambda x + (error 'blas (string-append "function " + (symbol->string 'cblas_dname) + " not found in blas library."))))))))))))) + +(define-blas dot _int _TAGvector _int _TAGvector _int -> _type) +(define-blas nrm2 _int _TAGvector _int -> _type) +(define-blas asum _int _TAGvector _int -> _type) +(define-blas swap _int _TAGvector _int _TAGvector _int -> _void) +(define-blas copy _int _TAGvector _int _TAGvector _int -> _void) +(define-blas axpy _int _type _TAGvector _int _TAGvector _int -> _void) +(define-blas scal _int _type _TAGvector _int -> _void) +(define-blas gemv _cblas-order _cblas-transpose _int _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) +(define-blas gbmv _cblas-order _cblas-transpose _int _int _int _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) +(define-blas trmv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _int _TAGvector _int -> _void) +(define-blas tbmv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _int _TAGvector _int _TAGvector _int -> _void) +(define-blas tpmv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _TAGvector _int -> _void) +(define-blas trsv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _int _TAGvector _int -> _void) +(define-blas tbsv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _int _TAGvector _int _TAGvector _int -> _void) +(define-blas tpsv _cblas-order _cblas-uplo _cblas-transpose _cblas-diag _int _TAGvector _TAGvector _int -> _void) +(define-blas symv _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) +(define-blas sbmv _cblas-order _cblas-uplo _int _int _type _TAGvector _int _TAGvector _int _type _TAGvector _int -> _void) +(define-blas spmv _cblas-order _cblas-uplo _int _type _TAGvector _TAGvector _int _type _TAGvector _int -> _void) +(define-blas ger _cblas-order _int _int _type _TAGvector _int _TAGvector _int _TAGvector _int -> _void) +(define-blas syr _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int -> _void) +(define-blas spr _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector -> _void) +(define-blas syr2 _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int _TAGvector _int -> _void) +(define-blas spr2 _cblas-order _cblas-uplo _int _type _TAGvector _int _TAGvector _int _TAGvector -> _void) +(define-blas gemm _cblas-order _cblas-transpose _cblas-transpose _int _int _type _TAGvector _int + _TAGvector _int _type _TAGvector _int -> _void) +(define-blas symm _cblas-order _cblas-side _cblas-uplo _int _int _type _TAGvector _int + _TAGvector _int _type _TAGvector _int -> _void) +(define-blas syrk _cblas-order _cblas-uplo _cblas-transpose _int _int _type + _TAGvector _int _type _TAGvector _int -> _void) +(define-blas syr2k _cblas-order _cblas-uplo _cblas-transpose _int _int _type _TAGvector + _int _TAGvector _int _type _TAGvector _int -> _void) +(define-blas trmm _cblas-order _cblas-side _cblas-uplo _cblas-transpose _cblas-diag _int _int + _type _TAGvector _int _TAGvector _int -> _void) +(define-blas trsm _cblas-order _cblas-side _cblas-uplo _cblas-transpose _cblas-diag _int _int + _type _TAGvector _int _TAGvector _int -> _void) hunk ./matrix-test.ss 1 +#lang scheme/base + #| matrix-test.ss: Test suite for the matrix.ss module. Copyright (C) 2007 Will M. Farr hunk ./matrix-test.ss 22 |# -(module matrix-test mzscheme - (require (planet "test.ss" ("schematics" "schemeunit.plt" 2)) - (planet "text-ui.ss" ("schematics" "schemeunit.plt" 2)) - (lib "42.ss" "srfi") - (lib "4.ss" "srfi") - "matrix.ss") - - (define-simple-check (check-close? eps a b) - (< (abs (- a b)) (abs eps))) - - (define-simple-check (check-mnorm-close? eps m1 m2) - (< (abs (- (matrix-norm m1) - (matrix-norm m2))) - (abs eps))) - - (provide matrix-test-suite) - - (define matrix-test-suite - (test-suite - "matrix.ss test suite" - (test-case - "basic matrix operations" - (let ((m (matrix-ec 3 3 (:range i 9) (random))) - (i (random 3)) - (j (random 3)) - (elt (random))) - (matrix-set! m i j elt) - (check-equal? (matrix-ref m i j) elt))) - (test-case - "column-major order" - (let ((m (matrix-ec 3 3 (:range i 9) i))) - (check-equal? (matrix-ref m 0 0) 0.0) - (check-equal? (matrix-ref m 1 0) 1.0) - (check-equal? (matrix-ref m 0 1) 3.0)) - (let ((m (matrix 2 2 1 2 3 4))) - (check-equal? (matrix-ref m 1 0) 2.0) - (check-equal? (matrix-ref m 0 1) 3.0))) - (test-case - "add, subtract and scale matrix" - (let ((m1 (matrix-ec 3 3 (:range i 9) (random))) - (m2 (matrix-ec 3 3 (:range i 9) (random)))) - (check-mnorm-close? 1e-10 (matrix-add m1 m1) (matrix-scale m1 2)) - (check-mnorm-close? 1e-10 (matrix-scale m2 2) (matrix-scale m2 2)) - (check-mnorm-close? 1e-10 (matrix-sub m1 m2) (matrix-add m1 (matrix-scale m2 -1))))) - (test-case - "norm, inverse, and matrix-mul" - (let ((m (matrix-ec 11 11 (:range i 121) (random))) - (m2 (matrix-ec 11 11 (:range i 121) (random)))) - (check-close? 1e-10 - (matrix-norm (matrix-sub (matrix-identity 11) - (matrix-mul (matrix-inverse m) m))) - 0) - (check-close? 1e-10 - (matrix-norm (matrix-sub (matrix-identity 11) - (matrix-mul (matrix-inverse m2) m2))) - 0))) - (test-case - "transpose and vector-matrix multiplication" - (let ((m (matrix-ec 10 10 (:range i 100) (random))) - (v (list->f64vector (list-ec (:range i 10) (random))))) - (let ((mv (matrix-f64vector-mul m v)) - (vm (f64vector-matrix-mul v (matrix-transpose m)))) - (do-ec (:range i 10) - (check-close? 1e-10 (f64vector-ref mv i) (f64vector-ref vm i)))))) - (test-case - "matrix-solve" - (let* ((m (matrix-ec 10 10 (:range i 100) (random))) - (b (list->f64vector (list-ec (:range i 10) (random)))) - (x (matrix-solve m b)) - (mx (matrix-f64vector-mul m x))) +(require (planet "test.ss" ("schematics" "schemeunit.plt" 2)) + (planet "text-ui.ss" ("schematics" "schemeunit.plt" 2)) + (lib "42.ss" "srfi") + (lib "4.ss" "srfi") + "matrix.ss") + +(define-simple-check (check-close? eps a b) + (< (abs (- a b)) (abs eps))) + +(define-simple-check (check-mnorm-close? eps m1 m2) + (< (abs (- (matrix-norm m1) + (matrix-norm m2))) + (abs eps))) + +(provide matrix-test-suite) + +(define matrix-test-suite + (test-suite + "matrix.ss test suite" + (test-case + "basic matrix operations" + (let ((m (matrix-ec 3 3 (:range i 9) (random))) + (i (random 3)) + (j (random 3)) + (elt (random))) + (matrix-set! m i j elt) + (check-equal? (matrix-ref m i j) elt))) + (test-case + "column-major order" + (let ((m (matrix-ec 3 3 (:range i 9) i))) + (check-equal? (matrix-ref m 0 0) 0.0) + (check-equal? (matrix-ref m 1 0) 1.0) + (check-equal? (matrix-ref m 0 1) 3.0)) + (let ((m (matrix 2 2 1 2 3 4))) + (check-equal? (matrix-ref m 1 0) 2.0) + (check-equal? (matrix-ref m 0 1) 3.0))) + (test-case + "add, subtract and scale matrix" + (let ((m1 (matrix-ec 3 3 (:range i 9) (random))) + (m2 (matrix-ec 3 3 (:range i 9) (random)))) + (check-mnorm-close? 1e-10 (matrix-add m1 m1) (matrix-scale m1 2)) + (check-mnorm-close? 1e-10 (matrix-scale m2 2) (matrix-scale m2 2)) + (check-mnorm-close? 1e-10 (matrix-sub m1 m2) (matrix-add m1 (matrix-scale m2 -1))))) + (test-case + "norm, inverse, and matrix-mul" + (let ((m (matrix-ec 11 11 (:range i 121) (random))) + (m2 (matrix-ec 11 11 (:range i 121) (random)))) + (check-close? 1e-10 + (matrix-norm (matrix-sub (matrix-identity 11) + (matrix-mul (matrix-inverse m) m))) + 0) + (check-close? 1e-10 + (matrix-norm (matrix-sub (matrix-identity 11) + (matrix-mul (matrix-inverse m2) m2))) + 0))) + (test-case + "transpose and vector-matrix multiplication" + (let ((m (matrix-ec 10 10 (:range i 100) (random))) + (v (list->f64vector (list-ec (:range i 10) (random))))) + (let ((mv (matrix-f64vector-mul m v)) + (vm (f64vector-matrix-mul v (matrix-transpose m)))) (do-ec (:range i 10) hunk ./matrix-test.ss 84 - (check-close? 1e-10 (f64vector-ref b i) (f64vector-ref mx i)))))))) + (check-close? 1e-10 (f64vector-ref mv i) (f64vector-ref vm i)))))) + (test-case + "matrix-solve" + (let* ((m (matrix-ec 10 10 (:range i 100) (random))) + (b (list->f64vector (list-ec (:range i 10) (random)))) + (x (matrix-solve m b)) + (mx (matrix-f64vector-mul m x))) + (do-ec (:range i 10) + (check-close? 1e-10 (f64vector-ref b i) (f64vector-ref mx i))))))) hunk ./matrix.ss 1 +#lang scheme/base + #| matrix.ss: Matrices and matrix operations using BLAS and LAPACK Copyright (C) 2007 Will M. Farr hunk ./matrix.ss 21 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |# -(module matrix mzscheme - (require (lib "foreign.ss") - (lib "etc.ss") - (all-except (lib "contract.ss") ->) - (rename (lib "contract.ss") ->/c ->) - (all-except (lib "42.ss" "srfi") :) - (all-except (planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt")) :) - "blas-lapack.ss") - - (define (list/length/c n) + +(require (lib "foreign.ss") + (lib "etc.ss") + (except-in (lib "contract.ss") ->) + (rename-in (lib "contract.ss") (-> ->/c)) + (except-in (lib "42.ss" "srfi") :) + (except-in (planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt")) :) + "blas-lapack.ss") + +(define (list/length/c n) + (flat-named-contract + (format "list of length ~a" n) + (lambda (l) (= (length l) n)))) + +(define (matrix-multiplication-compatible/c m) + (flat-named-contract + (format "compatible for multiplication by a ~a by ~a matrix" (matrix-rows m) (matrix-cols m)) + (lambda (m2) + (= (matrix-cols m) (matrix-rows m2))))) + +(define (matrix-same-dimensions/c m) + (flat-named-contract + (format "~a by ~a matrix" (matrix-rows m) (matrix-cols m)) + (lambda (m2) + (and (= (matrix-rows m) (matrix-rows m2)) + (= (matrix-cols m) (matrix-cols m2)))))) + +(define (matrix-valid-row-index/c m) + (let ((r (matrix-rows m))) (flat-named-contract hunk ./matrix.ss 51 - (format "list of length ~a" n) - (lambda (l) (= (length l) n)))) - - (define (matrix-multiplication-compatible/c m) + (format "valid row index for a ~a by ~a matrix" r (matrix-cols m)) + (lambda (i) (and (>= i 0) + (< i r)))))) + +(define (matrix-valid-col-index/c m) + (let ((c (matrix-cols m))) (flat-named-contract hunk ./matrix.ss 58 - (format "compatible for multiplication by a ~a by ~a matrix" (matrix-rows m) (matrix-cols m)) - (lambda (m2) - (= (matrix-cols m) (matrix-rows m2))))) - - (define (matrix-same-dimensions/c m) + (format "valid column index for ~a by ~a matrix" (matrix-rows m) c) + (lambda (j) (and (>= j 0) + (< j c)))))) + +(define (matrix-col-vector-compatible/c m) + (let ((c (matrix-cols m))) (flat-named-contract hunk ./matrix.ss 65 - (format "~a by ~a matrix" (matrix-rows m) (matrix-cols m)) - (lambda (m2) - (and (= (matrix-rows m) (matrix-rows m2)) - (= (matrix-cols m) (matrix-cols m2)))))) - - (define (matrix-valid-row-index/c m) - (let ((r (matrix-rows m))) - (flat-named-contract - (format "valid row index for a ~a by ~a matrix" r (matrix-cols m)) - (lambda (i) (and (>= i 0) - (< i r)))))) - - (define (matrix-valid-col-index/c m) - (let ((c (matrix-cols m))) - (flat-named-contract - (format "valid column index for ~a by ~a matrix" (matrix-rows m) c) - (lambda (j) (and (>= j 0) - (< j c)))))) - - (define (matrix-col-vector-compatible/c m) - (let ((c (matrix-cols m))) - (flat-named-contract - (format "column vector of length ~a" c) - (lambda (v) (= (f64vector-length v) c))))) - - (define (matrix-row-vector-compatible/c m) - (let ((r (matrix-rows m))) - (flat-named-contract - (format "row vector of length ~a" r) - (lambda (v) (= (f64vector-length v) r))))) - - (define matrix-square/c + (format "column vector of length ~a" c) + (lambda (v) (= (f64vector-length v) c))))) + +(define (matrix-row-vector-compatible/c m) + (let ((r (matrix-rows m))) (flat-named-contract hunk ./matrix.ss 71 - "square matrix" - (lambda (m) (= (matrix-rows m) (matrix-cols m))))) - - (define-struct matrix - (ptr rows cols) #f) - - (provide matrix? matrix-multiplication-compatible/c matrix-valid-row-index/c - matrix-valid-col-index/c matrix-square/c matrix-same-dimensions/c - matrix-col-vector-compatible/c matrix-row-vector-compatible/c - matrix-ec :matrix - _matrix - struct:matrix) - - (provide/contract - (rename my-make-matrix make-matrix - (->/c natural-number/c natural-number/c number? matrix?)) - (rename my-matrix matrix - (->r ((i natural-number/c) - (j natural-number/c)) - elts (and/c (listof number?) - (list/length/c (* i j))) - matrix?)) - (matrix-rows (->/c matrix? natural-number/c)) - (matrix-cols (->/c matrix? natural-number/c)) - (matrix-ref (->r ((m matrix?) - (i (and/c natural-number/c - (matrix-valid-row-index/c m))) - (j (and/c natural-number/c - (matrix-valid-col-index/c m)))) - number?)) - (matrix-set! (->r ((m matrix?) - (i (and/c natural-number/c - (matrix-valid-row-index/c m))) - (j (and/c natural-number/c - (matrix-valid-col-index/c m))) - (x number?)) - any)) - (matrix-add (->r ((m1 matrix?) - (m2 (and/c matrix? - (matrix-same-dimensions/c m1)))) - matrix?)) - (matrix-sub (->r ((m1 matrix?) - (m2 (and/c matrix? - (matrix-same-dimensions/c m1)))) - matrix?)) - (matrix-scale (->/c matrix? number? matrix?)) - (matrix-mul (->r ((m1 matrix?) - (m2 (and/c matrix? - (matrix-multiplication-compatible/c m1)))) - matrix?)) - (matrix-f64vector-mul (->r ((m matrix?) - (v (and/c f64vector? - (matrix-col-vector-compatible/c m)))) - f64vector?)) - (f64vector-matrix-mul (->r ((v (and/c f64vector? - (matrix-row-vector-compatible/c m))) - (m matrix?)) - f64vector?)) - (matrix-inverse (->/c (and/c matrix? matrix-square/c) matrix?)) - (matrix-norm (->/c matrix? number?)) - (matrix-identity (->/c natural-number/c matrix?)) - (matrix-transpose (->/c matrix? matrix?)) - (matrix-solve (->r ((m matrix-square/c) - (v (and/c f64vector? - (let ((r (matrix-rows m))) - (flat-named-contract - (format "column vector of length ~a" r) - (lambda (v) (= (f64vector-length v) r))))))) - f64vector?)) - (matrix-solve-many (->r ((m1 matrix-square/c) - (m2 (and/c matrix? - (matrix-multiplication-compatible/c m1)))) - matrix?))) - - (unsafe!) - - (define my-make-matrix - (case-lambda - ((rows cols) - (let* ((n (* rows cols)) - (p (malloc n _double 'atomic))) - (memset p 0 n _double) - (make-matrix p rows cols))) - ((rows cols elt) - (let* ((m (my-make-matrix rows cols)) - (p (matrix-ptr m))) - (do-ec (:range i (* rows cols)) - (ptr-set! p _double* i elt)) - m)))) - - (define (my-matrix i j . elts) - (let* ((m (my-make-matrix i j)) - (p (matrix-ptr m))) - (do-ec (:parallel (:range k (* i j)) - (:list elt elts)) - (ptr-set! p _double* k elt)) - m)) - - (define _matrix* - (make-ctype _pointer matrix-ptr - (lambda (x) - (error '_matrix - "cannot convert C output to _matrix")))) - - (define-fun-syntax _matrix - (syntax-id-rules (i o io) - ((_matrix i) - _matrix*) - ((_matrix o rows cols) - (type: _pointer - pre: (let* ((n (* rows cols)) - (p (malloc n _double 'atomic))) - (memset p 0 n _double) - p) - post: (p => (make-matrix p rows cols)))) - ((_matrix io) - (type: _pointer - bind: m - pre: (m => (matrix-ptr m)) - post: m)) - (_matrix _matrix*))) - - (define (matrix-ptr-index m i j) - (+ i (* j (matrix-rows m)))) - - (define (matrix-ref m i j) - (ptr-ref (matrix-ptr m) _double (matrix-ptr-index m i j))) - - (define (matrix-set! m i j elt) - (ptr-set! (matrix-ptr m) _double* (matrix-ptr-index m i j) elt)) - - (define (matrix-length m) - (* (matrix-cols m) - (matrix-rows m))) - - (define matrix-copy - (get-ffi-obj 'cblas_dcopy *blas* - (_fun (m) :: - (_int = (matrix-length m)) - (m : _matrix) (_int = 1) - (m-out : (_matrix o (matrix-rows m) (matrix-cols m))) (_int = 1) -> - _void -> - m-out))) - - (define matrix-add - (get-ffi-obj 'cblas_daxpy *blas* - (_fun (m1 m2) :: - (_int = (matrix-length m1)) - (_double = 1.0) - (_matrix = m1) (_int = 1) - (m-out : _matrix = (matrix-copy m2)) (_int = 1) -> - _void -> - m-out))) - - (define matrix-sub - (get-ffi-obj 'cblas_daxpy *blas* - (_fun (m1 m2) :: - (_int = (matrix-length m1)) - (_double = -1.0) - (_matrix = m2) (_int = 1) - (m-out : _matrix = (matrix-copy m1)) (_int = 1) -> - _void -> - m-out))) - - (define matrix-scale - (get-ffi-obj 'cblas_dscal *blas* - (_fun (m s) :: - (_int = (matrix-length m)) - (_double* = s) - (m-out : _matrix = (matrix-copy m)) (_int = 1) -> - _void -> - m-out))) - - (define matrix-mul - (get-ffi-obj 'cblas_dgemm *blas* - (_fun (m1 m2) :: - (_cblas-order = 'col-major) - (_cblas-transpose = 'no-trans) - (_cblas-transpose = 'no-trans) - (_int = (matrix-rows m1)) - (_int = (matrix-cols m2)) - (_int = (matrix-rows m2)) - (_double = 1.0) - (_matrix = m1) - (_int = (matrix-rows m1)) - (_matrix = m2) - (_int = (matrix-rows m2)) - (_double = 0.0) - (m-out : (_matrix o (matrix-rows m1) (matrix-cols m2))) - (_int = (matrix-rows m1)) -> - _void -> - m-out))) - - (define matrix-f64vector-mul - (get-ffi-obj 'cblas_dgemv *blas* - (_fun (m v) :: - (_cblas-order = 'col-major) - (_cblas-transpose = 'no-trans) - (_int = (matrix-rows m)) - (_int = (matrix-cols m)) - (_double = 1.0) - (_matrix = m) - (_int = (matrix-rows m)) - (_f64vector = v) - (_int = 1) - (_double = 0.0) - (v-out : (_f64vector o (matrix-rows m))) - (_int = 1) -> - _void -> - v-out))) - - (define f64vector-matrix-mul - (get-ffi-obj 'cblas_dgemv *blas* - (_fun (v m) :: - (_cblas-order = 'col-major) - (_cblas-transpose = 'trans) - (_int = (matrix-rows m)) - (_int = (matrix-cols m)) - (_double = 1.0) - (_matrix = m) - (_int = (matrix-rows m)) - (_f64vector = v) - (_int = 1) - (_double = 0.0) - (v-out : (_f64vector o (matrix-cols m))) - (_int = 1) -> - _void -> - v-out))) - - (define matrix-lu-decomp - (get-ffi-obj 'dgetrf_ *lapack* - (_fun (m) :: - ((_ptr i _int) = (matrix-rows m)) - ((_ptr i _int) = (matrix-cols m)) - (m-out : _matrix = (matrix-copy m)) - ((_ptr i _int) = (matrix-rows m)) - (ipiv : (_u32vector o (matrix-rows m))) - (_ptr o _int) -> - _void -> - (values m-out ipiv)))) - - (define dgetri-lwork - (get-ffi-obj 'dgetri_ *lapack* - (_fun (m ipiv) :: - (n : (_ptr i _int) = (matrix-rows m)) - (_matrix = m) - ((_ptr i _int) = n) - (_u32vector = ipiv) - (lwork : (_ptr o _double)) - ((_ptr i _int) = -1) - (res : (_ptr o _int)) -> - _void -> - (values (inexact->exact (round lwork)) res)))) - - (define dgetri/lwork - (get-ffi-obj 'dgetri_ *lapack* - (_fun (m ipiv lwork) :: - (n : (_ptr i _int) = (matrix-rows m)) - (m-out : _matrix = (matrix-copy m)) - ((_ptr i _int) = n) - (_u32vector = ipiv) - (_f64vector o lwork) - ((_ptr i _int) = lwork) - (_ptr o _int) -> - _void -> - m-out))) - - (define (matrix-inverse m) - (let-values (((m-lu ipiv) - (matrix-lu-decomp m))) - (let-values (((lwork res) - (dgetri-lwork m-lu ipiv))) - (dgetri/lwork m-lu ipiv lwork)))) - - (define matrix-norm - (get-ffi-obj 'cblas_dnrm2 *blas* - (_fun (m) :: - (_int = (matrix-length m)) - (_matrix = m) - (_int = 1) -> - _double))) - - (define (matrix-transpose m) - (let ((r (matrix-rows m)) - (c (matrix-cols m))) - (matrix-ec c r (:range j r) (:range i c) (matrix-ref m j i)))) - - (define matrix-solve-many - (get-ffi-obj 'dgesv_ *lapack* - (_fun (m b) :: - ((_ptr i _int) = (matrix-rows m)) - ((_ptr i _int) = (matrix-cols b)) - (_matrix = (matrix-copy m)) - ((_ptr i _int) = (matrix-rows m)) - (_u32vector o (matrix-rows m)) - (x : _matrix = (matrix-copy b)) - ((_ptr i _int) = (matrix-rows b)) - (_ptr o _int) -> - _void -> - x))) - - (define matrix-solve - (get-ffi-obj 'dgesv_ *lapack* - (_fun (m v) :: - ((_ptr i _int) = (matrix-rows m)) - ((_ptr i _int) = 1) - (_matrix = (matrix-copy m)) - ((_ptr i _int) = (matrix-rows m)) - (_u32vector o (matrix-rows m)) - (x : _f64vector = (f64vector-of-length-ec (f64vector-length v) (:f64vector x v) x)) - ((_ptr i _int) = (f64vector-length v)) - (_ptr o _int) -> - _void -> - x))) - - (define (matrix-identity n) - (let ((m (my-make-matrix n n 0.0))) - (do-ec (:range i n) (matrix-set! m i i 1.0)) - m)) - - (define-syntax matrix-ec - (syntax-rules () - ((matrix-ec rrows ccols etc ...) - (let ((rows rrows) - (cols ccols)) - (apply my-matrix rows cols (list-ec etc ...)))))) - - (define-syntax :matrix - (syntax-rules (index) - ((:matrix cc var arg) - (:matrix cc var (index i j) arg)) - ((:matrix cc var (index i j) arg) - (:do cc - (let ((m arg) - (rows #f) - (cols #f)) - (set! rows (matrix-rows m)) - (set! cols (matrix-cols m))) - ((i 0) (j 0)) - (< j cols) - (let ((i+1 (+ i 1)) - (j+1 (+ j 1)) - (wrapping? #f) - (var (matrix-ref m i j))) - (set! wrapping? (>= i+1 rows))) - #t - ((if wrapping? 0 i+1) - (if wrapping? j+1 j)))))) - - (define ptr->matrix make-matrix) - (provide* (unsafe ptr->matrix)) - (define-unsafer matrix-unsafe!)) + (format "row vector of length ~a" r) + (lambda (v) (= (f64vector-length v) r))))) + +(define matrix-square/c + (flat-named-contract + "square matrix" + (lambda (m) (= (matrix-rows m) (matrix-cols m))))) + +(define-struct matrix + (ptr rows cols) #f) + +(provide matrix? matrix-multiplication-compatible/c matrix-valid-row-index/c + matrix-valid-col-index/c matrix-square/c matrix-same-dimensions/c + matrix-col-vector-compatible/c matrix-row-vector-compatible/c + matrix-ec :matrix + _matrix + struct:matrix + exn:singular-matrix?) + +(provide/contract + (rename my-make-matrix make-matrix + (->/c natural-number/c natural-number/c number? matrix?)) + (rename my-matrix matrix + (->r ((i natural-number/c) + (j natural-number/c)) + elts (and/c (listof number?) + (list/length/c (* i j))) + matrix?)) + (matrix-rows (->/c matrix? natural-number/c)) + (matrix-cols (->/c matrix? natural-number/c)) + (matrix-ref (->r ((m matrix?) + (i (and/c natural-number/c + (matrix-valid-row-index/c m))) + (j (and/c natural-number/c + (matrix-valid-col-index/c m)))) + number?)) + (matrix-set! (->r ((m matrix?) + (i (and/c natural-number/c + (matrix-valid-row-index/c m))) + (j (and/c natural-number/c + (matrix-valid-col-index/c m))) + (x number?)) + any)) + (matrix-add (->r ((m1 matrix?) + (m2 (and/c matrix? + (matrix-same-dimensions/c m1)))) + matrix?)) + (matrix-sub (->r ((m1 matrix?) + (m2 (and/c matrix? + (matrix-same-dimensions/c m1)))) + matrix?)) + (matrix-scale (->/c matrix? number? matrix?)) + (matrix-mul (->r ((m1 matrix?) + (m2 (and/c matrix? + (matrix-multiplication-compatible/c m1)))) + matrix?)) + (matrix-f64vector-mul (->r ((m matrix?) + (v (and/c f64vector? + (matrix-col-vector-compatible/c m)))) + f64vector?)) + (f64vector-matrix-mul (->r ((v (and/c f64vector? + (matrix-row-vector-compatible/c m))) + (m matrix?)) + f64vector?)) + (matrix-inverse (->/c (and/c matrix? matrix-square/c) matrix?)) + (matrix-norm (->/c matrix? number?)) + (matrix-identity (->/c natural-number/c matrix?)) + (matrix-transpose (->/c matrix? matrix?)) + (matrix-solve (->r ((m matrix-square/c) + (v (and/c f64vector? + (let ((r (matrix-rows m))) + (flat-named-contract + (format "column vector of length ~a" r) + (lambda (v) (= (f64vector-length v) r))))))) + f64vector?)) + (matrix-solve-many (->r ((m1 matrix-square/c) + (m2 (and/c matrix? + (matrix-multiplication-compatible/c m1)))) + matrix?))) + +(unsafe!) + +(define-struct (exn:singular-matrix exn) + (elt) #:inspector #f) + +(define my-make-matrix + (case-lambda + ((rows cols) + (let* ((n (* rows cols)) + (p (malloc n _double 'atomic))) + (memset p 0 n _double) + (make-matrix p rows cols))) + ((rows cols elt) + (let* ((m (my-make-matrix rows cols)) + (p (matrix-ptr m))) + (do-ec (:range i (* rows cols)) + (ptr-set! p _double* i elt)) + m)))) + +(define (my-matrix i j . elts) + (let* ((m (my-make-matrix i j)) + (p (matrix-ptr m))) + (do-ec (:parallel (:range k (* i j)) + (:list elt elts)) + (ptr-set! p _double* k elt)) + m)) + +(define _matrix* + (make-ctype _pointer matrix-ptr + (lambda (x) + (error '_matrix + "cannot convert C output to _matrix")))) + +(define-fun-syntax _matrix + (syntax-id-rules (i o io) + ((_matrix i) + _matrix*) + ((_matrix o rows cols) + (type: _pointer + pre: (let* ((n (* rows cols)) + (p (malloc n _double 'atomic))) + (memset p 0 n _double) + p) + post: (p => (make-matrix p rows cols)))) + ((_matrix io) + (type: _pointer + bind: m + pre: (m => (matrix-ptr m)) + post: m)) + (_matrix _matrix*))) + +(define (matrix-ptr-index m i j) + (+ i (* j (matrix-rows m)))) + +(define (matrix-ref m i j) + (ptr-ref (matrix-ptr m) _double (matrix-ptr-index m i j))) + +(define (matrix-set! m i j elt) + (ptr-set! (matrix-ptr m) _double* (matrix-ptr-index m i j) elt)) + +(define (matrix-length m) + (* (matrix-cols m) + (matrix-rows m))) + +(define matrix-copy + (get-ffi-obj 'cblas_dcopy *blas* + (_fun (m) :: + (_int = (matrix-length m)) + (m : _matrix) (_int = 1) + (m-out : (_matrix o (matrix-rows m) (matrix-cols m))) (_int = 1) -> + _void -> + m-out))) + +(define matrix-add + (get-ffi-obj 'cblas_daxpy *blas* + (_fun (m1 m2) :: + (_int = (matrix-length m1)) + (_double = 1.0) + (_matrix = m1) (_int = 1) + (m-out : _matrix = (matrix-copy m2)) (_int = 1) -> + _void -> + m-out))) + +(define matrix-sub + (get-ffi-obj 'cblas_daxpy *blas* + (_fun (m1 m2) :: + (_int = (matrix-length m1)) + (_double = -1.0) + (_matrix = m2) (_int = 1) + (m-out : _matrix = (matrix-copy m1)) (_int = 1) -> + _void -> + m-out))) + +(define matrix-scale + (get-ffi-obj 'cblas_dscal *blas* + (_fun (m s) :: + (_int = (matrix-length m)) + (_double* = s) + (m-out : _matrix = (matrix-copy m)) (_int = 1) -> + _void -> + m-out))) + +(define matrix-mul + (get-ffi-obj 'cblas_dgemm *blas* + (_fun (m1 m2) :: + (_cblas-order = 'col-major) + (_cblas-transpose = 'no-trans) + (_cblas-transpose = 'no-trans) + (_int = (matrix-rows m1)) + (_int = (matrix-cols m2)) + (_int = (matrix-rows m2)) + (_double = 1.0) + (_matrix = m1) + (_int = (matrix-rows m1)) + (_matrix = m2) + (_int = (matrix-rows m2)) + (_double = 0.0) + (m-out : (_matrix o (matrix-rows m1) (matrix-cols m2))) + (_int = (matrix-rows m1)) -> + _void -> + m-out))) + +(define matrix-f64vector-mul + (get-ffi-obj 'cblas_dgemv *blas* + (_fun (m v) :: + (_cblas-order = 'col-major) + (_cblas-transpose = 'no-trans) + (_int = (matrix-rows m)) + (_int = (matrix-cols m)) + (_double = 1.0) + (_matrix = m) + (_int = (matrix-rows m)) + (_f64vector = v) + (_int = 1) + (_double = 0.0) + (v-out : (_f64vector o (matrix-rows m))) + (_int = 1) -> + _void -> + v-out))) + +(define f64vector-matrix-mul + (get-ffi-obj 'cblas_dgemv *blas* + (_fun (v m) :: + (_cblas-order = 'col-major) + (_cblas-transpose = 'trans) + (_int = (matrix-rows m)) + (_int = (matrix-cols m)) + (_double = 1.0) + (_matrix = m) + (_int = (matrix-rows m)) + (_f64vector = v) + (_int = 1) + (_double = 0.0) + (v-out : (_f64vector o (matrix-cols m))) + (_int = 1) -> + _void -> + v-out))) + +(define matrix-lu-decomp + (get-ffi-obj 'dgetrf_ *lapack* + (_fun (m) :: + ((_ptr i _int) = (matrix-rows m)) + ((_ptr i _int) = (matrix-cols m)) + (m-out : _matrix = (matrix-copy m)) + ((_ptr i _int) = (matrix-rows m)) + (ipiv : (_u32vector o (matrix-rows m))) + (_ptr o _int) -> + _void -> + (values m-out ipiv)))) + +(define dgetri-lwork + (get-ffi-obj 'dgetri_ *lapack* + (_fun (m ipiv) :: + (n : (_ptr i _int) = (matrix-rows m)) + (_matrix = m) + ((_ptr i _int) = n) + (_u32vector = ipiv) + (lwork : (_ptr o _double)) + ((_ptr i _int) = -1) + (res : (_ptr o _int)) -> + _void -> + (values (inexact->exact (round lwork)) res)))) + +(define dgetri/lwork + (get-ffi-obj 'dgetri_ *lapack* + (_fun (m ipiv lwork) :: + (n : (_ptr i _int) = (matrix-rows m)) + (m-out : _matrix = (matrix-copy m)) + ((_ptr i _int) = n) + (_u32vector = ipiv) + (_f64vector o lwork) + ((_ptr i _int) = lwork) + (_ptr o _int) -> + _void -> + m-out))) + +(define (matrix-inverse m) + (let-values (((m-lu ipiv) + (matrix-lu-decomp m))) + (let-values (((lwork res) + (dgetri-lwork m-lu ipiv))) + (dgetri/lwork m-lu ipiv lwork)))) + +(define matrix-norm + (get-ffi-obj 'cblas_dnrm2 *blas* + (_fun (m) :: + (_int = (matrix-length m)) + (_matrix = m) + (_int = 1) -> + _double))) + +(define (matrix-transpose m) + (let ((r (matrix-rows m)) + (c (matrix-cols m))) + (matrix-ec c r (:range j r) (:range i c) (matrix-ref m j i)))) + +(define matrix-solve-many + (get-ffi-obj 'dgesv_ *lapack* + (_fun (m b) :: + ((_ptr i _int) = (matrix-rows m)) + ((_ptr i _int) = (matrix-cols b)) + (_matrix = (matrix-copy m)) + ((_ptr i _int) = (matrix-rows m)) + (_u32vector o (matrix-rows m)) + (x : _matrix = (matrix-copy b)) + ((_ptr i _int) = (matrix-rows b)) + (info : (_ptr o _int)) -> + _void -> + (if (> info 0) + (raise (make-exn:singular-matrix "singular matrix in dgesv" + (current-continuation-marks) + info)) + x)))) + +(define matrix-solve + (get-ffi-obj 'dgesv_ *lapack* + (_fun (m v) :: + ((_ptr i _int) = (matrix-rows m)) + ((_ptr i _int) = 1) + (_matrix = (matrix-copy m)) + ((_ptr i _int) = (matrix-rows m)) + (_u32vector o (matrix-rows m)) + (x : _f64vector = (f64vector-of-length-ec (f64vector-length v) (:f64vector x v) x)) + ((_ptr i _int) = (f64vector-length v)) + (info : (_ptr o _int)) -> + _void -> + (if (> info 0) + (raise (make-exn:singular-matrix "singular matrix in dgesv" + (current-continuation-marks) + info)) + x)))) + +(define (matrix-identity n) + (let ((m (my-make-matrix n n 0.0))) + (do-ec (:range i n) (matrix-set! m i i 1.0)) + m)) + +(define-syntax matrix-ec + (syntax-rules () + ((matrix-ec rrows ccols etc ...) + (let ((rows rrows) + (cols ccols)) + (apply my-matrix rows cols (list-ec etc ...)))))) + +(define-syntax :matrix + (syntax-rules (index) + ((:matrix cc var arg) + (:matrix cc var (index i j) arg)) + ((:matrix cc var (index i j) arg) + (:do cc + (let ((m arg) + (rows #f) + (cols #f)) + (set! rows (matrix-rows m)) + (set! cols (matrix-cols m))) + ((i 0) (j 0)) + (< j cols) + (let ((i+1 (+ i 1)) + (j+1 (+ j 1)) + (wrapping? #f) + (var (matrix-ref m i j))) + (set! wrapping? (>= i+1 rows))) + #t + ((if wrapping? 0 i+1) + (if wrapping? j+1 j)))))) + +(define ptr->matrix make-matrix) +(provide* (unsafe ptr->matrix)) +(define-unsafer matrix-unsafe!) hunk ./run-tests.ss 1 +#lang scheme/base + #| run-tests.ss: Run all tests for the plt-linalg package. Copyright (C) 2007 Will M. Farr hunk ./run-tests.ss 22 |# -(module run-tests mzscheme - (require (planet "text-ui.ss" ("schematics" "schemeunit.plt" 2)) - (planet "test.ss" ("schematics" "schemeunit.plt" 2)) - "matrix-test.ss" "vector-test.ss") - - (define all-tests (test-suite - "all tests" - matrix-test-suite vector-test-suite)) - - (test/text-ui all-tests 'verbose)) +(require (planet "text-ui.ss" ("schematics" "schemeunit.plt" 2)) + (planet "test.ss" ("schematics" "schemeunit.plt" 2)) + "matrix-test.ss" "vector-test.ss") + +(define all-tests (test-suite + "all tests" + matrix-test-suite vector-test-suite)) + +(test/text-ui all-tests 'verbose) hunk ./vector-test.ss 1 +#lang scheme/base + #| vector-test.ss: Test suite for vector.ss Copyright (C) 2007 Will M. Farr hunk ./vector-test.ss 21 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |# -(module vector-test mzscheme - (require (planet "test.ss" ("schematics" "schemeunit.plt" 2)) - "vector.ss" - (lib "math.ss") - (planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt" 1))) - - (provide vector-test-suite) - - (define-simple-check (check-close? eps a b) - (< (abs (- a b)) (abs eps))) - - (define vector-test-suite - (test-suite - "vector.ss test suite" - (test-case - "f64vector-norm" - (check-close? 1e-10 - (f64vector-norm (f64vector 1.0 2.0 3.0 4.0)) - (sqrt 30))) - (test-case - "f64vector-add, f64vector-sub, f64vector-scale" - (let ((v1 (f64vector-of-length-ec 10 (:range i 10) (random))) - (v2 (f64vector-of-length-ec 10 (:range i 10) (random)))) - (do-ec (:parallel (:f64vector x (f64vector-sub v1 v2)) - (:f64vector y (f64vector-add v1 (f64vector-scale v2 -1)))) - (check-close? 1e-10 x y)))) - (test-case - "f64vector-dot" - (let ((v (f64vector-of-length-ec 10 (:range i 10) (random)))) - (check-close? 1e-10 (f64vector-dot v v) (sqr (f64vector-norm v)))))))) +(require (planet "test.ss" ("schematics" "schemeunit.plt" 2)) + "vector.ss" + (lib "math.ss") + (planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt" 1))) + +(provide vector-test-suite) + +(define-simple-check (check-close? eps a b) + (< (abs (- a b)) (abs eps))) + +(define vector-test-suite + (test-suite + "vector.ss test suite" + (test-case + "f64vector-norm" + (check-close? 1e-10 + (f64vector-norm (f64vector 1.0 2.0 3.0 4.0)) + (sqrt 30))) + (test-case + "f64vector-add, f64vector-sub, f64vector-scale" + (let ((v1 (f64vector-of-length-ec 10 (:range i 10) (random))) + (v2 (f64vector-of-length-ec 10 (:range i 10) (random)))) + (do-ec (:parallel (:f64vector x (f64vector-sub v1 v2)) + (:f64vector y (f64vector-add v1 (f64vector-scale v2 -1)))) + (check-close? 1e-10 x y)))) + (test-case + "f64vector-dot" + (let ((v (f64vector-of-length-ec 10 (:range i 10) (random)))) + (check-close? 1e-10 (f64vector-dot v v) (sqr (f64vector-norm v))))))) hunk ./vector.ss 1 +#lang scheme/base + #| vector.ss: f64vectors in linear algebra. Copyright (C) 2007 Will M. Farr hunk ./vector.ss 21 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |# -(module vector mzscheme - (require (lib "foreign.ss") - (all-except (lib "contract.ss") ->) - "blas-lapack.ss" - (rename (lib "contract.ss") ->/c ->) - (lib "4.ss" "srfi")) - - (unsafe!) - - (provide (all-from (lib "4.ss" "srfi")) f64vector-same-length/c) - - (define (f64vector-same-length/c v1) - (let ((n (f64vector-length v1))) - (flat-named-contract - (format "length ~a f64vector" (f64vector-length v1)) - (lambda (v2) (= (f64vector-length v2) n))))) - - (provide/contract - (f64vector-norm (->/c f64vector? (>=/c 0))) - (f64vector-copy (->/c f64vector? f64vector?)) - (f64vector-scale (->/c f64vector? number? f64vector?)) - (f64vector-add (->r ((v1 f64vector?) - (v2 (and/c f64vector? - (f64vector-same-length/c v1)))) - f64vector?)) - (f64vector-sub (->r ((v1 f64vector?) - (v2 (and/c f64vector? - (f64vector-same-length/c v1)))) - f64vector?)) - (f64vector-dot (->/c f64vector? f64vector? number?))) - - (define f64vector-norm - (get-ffi-obj 'cblas_dnrm2 *blas* - (_fun (v) :: - (_int = (f64vector-length v)) - (_f64vector = v) - (_int = 1) -> - _double))) - - (define f64vector-copy - (get-ffi-obj 'cblas_dcopy *blas* - (_fun (v) :: - (_int = (f64vector-length v)) - (_f64vector = v) - (_int = 1) - (v-out : (_f64vector o (f64vector-length v))) - (_int = 1) -> - _void -> - v-out))) - - (define f64vector-scale - (get-ffi-obj 'cblas_dscal *blas* - (_fun (v s) :: - (_int = (f64vector-length v)) - (_double* = s) - (v-out : _f64vector = (f64vector-copy v)) - (_int = 1) -> - _void -> - v-out))) - - (define f64vector-add - (get-ffi-obj 'cblas_daxpy *blas* - (_fun (v1 v2) :: - (_int = (f64vector-length v1)) - (_double = 1.0) - (_f64vector = v1) - (_int = 1) - (v-out : _f64vector = (f64vector-copy v2)) - (_int = 1) -> - _void -> - v-out))) - - (define f64vector-sub - (get-ffi-obj 'cblas_daxpy *blas* - (_fun (v1 v2) :: - (_int = (f64vector-length v1)) - (_double = -1.0) - (_f64vector = v2) - (_int = 1) - (v-out : _f64vector = (f64vector-copy v1)) - (_int = 1) -> - _void -> - v-out))) - - (define f64vector-dot - (get-ffi-obj 'cblas_ddot *blas* - (_fun (v1 v2) :: - (_int = (f64vector-length v1)) - (_f64vector = v1) - (_int = 1) - (_f64vector = v2) - (_int = 1) -> - _double)))) +(require (lib "foreign.ss") + (except-in (lib "contract.ss") ->) + "blas-lapack.ss" + (rename-in (lib "contract.ss") (->/ c->)) + (lib "4.ss" "srfi")) + +(unsafe!) + +(provide (all-from (lib "4.ss" "srfi")) f64vector-same-length/c) + +(define (f64vector-same-length/c v1) + (let ((n (f64vector-length v1))) + (flat-named-contract + (format "length ~a f64vector" (f64vector-length v1)) + (lambda (v2) (= (f64vector-length v2) n))))) + +(provide/contract + (f64vector-norm (->/c f64vector? (>=/c 0))) + (f64vector-copy (->/c f64vector? f64vector?)) + (f64vector-scale (->/c f64vector? number? f64vector?)) + (f64vector-add (->r ((v1 f64vector?) + (v2 (and/c f64vector? + (f64vector-same-length/c v1)))) + f64vector?)) + (f64vector-sub (->r ((v1 f64vector?) + (v2 (and/c f64vector? + (f64vector-same-length/c v1)))) + f64vector?)) + (f64vector-dot (->/c f64vector? f64vector? number?))) + +(define f64vector-norm + (get-ffi-obj 'cblas_dnrm2 *blas* + (_fun (v) :: + (_int = (f64vector-length v)) + (_f64vector = v) + (_int = 1) -> + _double))) + +(define f64vector-copy + (get-ffi-obj 'cblas_dcopy *blas* + (_fun (v) :: + (_int = (f64vector-length v)) + (_f64vector = v) + (_int = 1) + (v-out : (_f64vector o (f64vector-length v))) + (_int = 1) -> + _void -> + v-out))) + +(define f64vector-scale + (get-ffi-obj 'cblas_dscal *blas* + (_fun (v s) :: + (_int = (f64vector-length v)) + (_double* = s) + (v-out : _f64vector = (f64vector-copy v)) + (_int = 1) -> + _void -> + v-out))) + +(define f64vector-add + (get-ffi-obj 'cblas_daxpy *blas* + (_fun (v1 v2) :: + (_int = (f64vector-length v1)) + (_double = 1.0) + (_f64vector = v1) + (_int = 1) + (v-out : _f64vector = (f64vector-copy v2)) + (_int = 1) -> + _void -> + v-out))) + +(define f64vector-sub + (get-ffi-obj 'cblas_daxpy *blas* + (_fun (v1 v2) :: + (_int = (f64vector-length v1)) + (_double = -1.0) + (_f64vector = v2) + (_int = 1) + (v-out : _f64vector = (f64vector-copy v1)) + (_int = 1) -> + _void -> + v-out))) + +(define f64vector-dot + (get-ffi-obj 'cblas_ddot *blas* + (_fun (v1 v2) :: + (_int = (f64vector-length v1)) + (_f64vector = v1) + (_int = 1) + (_f64vector = v2) + (_int = 1) -> + _double))) } Context: [TAG 1.4 Will M. Farr **20070525154423] Patch bundle hash: 6457d25e112f0453f5d2e16f5dbbf705964d629e