vector.ss
#|  vector.ss: f64vectors in linear algebra.
    Copyright (C) 2007 Will M. Farr <farr@mit.edu>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|#

(module vector mzscheme
  ;; Thin wrappers around CBLAS level-1 routines operating on SRFI-4
  ;; f64vectors.  Every foreign call below passes a stride (inc) of 1.
  ;; The vector-valued operations never mutate their arguments: BLAS writes
  ;; its result into a freshly allocated vector, which is then returned.
  (require (lib "foreign.ss")
           (all-except (lib "contract.ss") ->)
           "blas-lapack.ss"
           (rename (lib "contract.ss") ->/c ->)
           (lib "4.ss" "srfi"))
  
  ;; Enables the unsafe bindings of (lib "foreign.ss") used below.
  (unsafe!)
  
  (provide (all-from (lib "4.ss" "srfi")) f64vector-same-length/c)
  
  ;; (f64vector-same-length/c v1) -> flat contract
  ;; Builds a flat contract that accepts a value whose f64vector-length
  ;; equals the length of V1.  The predicate calls f64vector-length on its
  ;; argument, so combine it with f64vector? (e.g. via and/c) when the
  ;; argument might not be an f64vector.
  (define (f64vector-same-length/c v1)
    (let ((n (f64vector-length v1)))
      (flat-named-contract
       (format "length ~a f64vector" n)
       (lambda (v2) (= (f64vector-length v2) n)))))
  
  (provide/contract
   (f64vector-norm (->/c f64vector? (>=/c 0)))
   (f64vector-copy (->/c f64vector? f64vector?))
   ;; The scale factor is marshalled as a C double (via _double*), which
   ;; only accepts real numbers.
   (f64vector-scale (->/c f64vector? real? f64vector?))
   (f64vector-add (->r ((v1 f64vector?)
                        (v2 (and/c f64vector?
                                   (f64vector-same-length/c v1))))
                       f64vector?))
   (f64vector-sub (->r ((v1 f64vector?)
                        (v2 (and/c f64vector?
                                   (f64vector-same-length/c v1))))
                       f64vector?))
   ;; cblas_ddot is told N = (f64vector-length v1) and reads N elements from
   ;; both vectors, so a shorter V2 would make the foreign call read past
   ;; the end of its storage.  Require equal lengths, as for add/sub.
   (f64vector-dot (->r ((v1 f64vector?)
                        (v2 (and/c f64vector?
                                   (f64vector-same-length/c v1))))
                       number?)))
  
  ;; (f64vector-norm v) -> real
  ;; Euclidean norm of V, computed by cblas_dnrm2.
  (define f64vector-norm
    (get-ffi-obj 'cblas_dnrm2 *blas*
                 (_fun (v) ::
                       (_int = (f64vector-length v))
                       (_f64vector = v)
                       (_int = 1) ->
                       _double)))
  
  ;; (f64vector-copy v) -> f64vector
  ;; Returns a new f64vector with the same contents as V.  The output
  ;; vector is allocated by the `o' (output) mode of _f64vector with V's
  ;; length, and cblas_dcopy fills it.
  (define f64vector-copy
    (get-ffi-obj 'cblas_dcopy *blas*
                 (_fun (v) ::
                       (_int = (f64vector-length v))
                       (_f64vector = v)
                       (_int = 1)
                       (v-out : (_f64vector o (f64vector-length v)))
                       (_int = 1) ->
                       _void ->
                       v-out)))
  
  ;; (f64vector-scale v s) -> f64vector
  ;; Returns S * V as a new vector.  cblas_dscal scales its vector argument
  ;; in place, so it is handed a copy of V and that copy is returned; V
  ;; itself is left unchanged.
  (define f64vector-scale
    (get-ffi-obj 'cblas_dscal *blas*
                 (_fun (v s) ::
                       (_int = (f64vector-length v))
                       ;; _double* also accepts exact reals, converting them.
                       (_double* = s)
                       (v-out : _f64vector = (f64vector-copy v))
                       (_int = 1) ->
                       _void ->
                       v-out)))
  
  ;; (f64vector-add v1 v2) -> f64vector
  ;; Returns V1 + V2 as a new vector, using cblas_daxpy (y := alpha*x + y)
  ;; with alpha = 1, x = V1 and y = a fresh copy of V2.
  (define f64vector-add
    (get-ffi-obj 'cblas_daxpy *blas*
                 (_fun (v1 v2) ::
                       (_int = (f64vector-length v1))
                       (_double = 1.0)
                       (_f64vector = v1)
                       (_int = 1)
                       (v-out : _f64vector = (f64vector-copy v2))
                       (_int = 1) ->
                       _void ->
                       v-out)))
  
  ;; (f64vector-sub v1 v2) -> f64vector
  ;; Returns V1 - V2 as a new vector, using cblas_daxpy (y := alpha*x + y)
  ;; with alpha = -1, x = V2 and y = a fresh copy of V1.
  (define f64vector-sub
    (get-ffi-obj 'cblas_daxpy *blas*
                 (_fun (v1 v2) ::
                       (_int = (f64vector-length v1))
                       (_double = -1.0)
                       (_f64vector = v2)
                       (_int = 1)
                       (v-out : _f64vector = (f64vector-copy v1))
                       (_int = 1) ->
                       _void ->
                       v-out)))
  
  ;; (f64vector-dot v1 v2) -> real
  ;; Inner product of V1 and V2, computed by cblas_ddot over
  ;; (f64vector-length v1) elements.  The exported contract guarantees the
  ;; two vectors have equal length.
  (define f64vector-dot
    (get-ffi-obj 'cblas_ddot *blas*
                 (_fun (v1 v2) ::
                       (_int = (f64vector-length v1))
                       (_f64vector = v1)
                       (_int = 1)
                       (_f64vector = v2)
                       (_int = 1) ->
                       _double))))