(module SO31-test mzscheme
(require "SO31.ss"
(planet "all.ss" ("wmfarr" "plt-linalg.plt" 1 2))
(planet "test.ss" ("schematics" "schemeunit.plt" 2))
(lib "42.ss" "srfi")
(planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt")))
(provide SO31-test-suite)
;; SchemeUnit check: passes when every pair of corresponding entries of m1
;; and m2 differs by strictly less than |eps|.  The matrices are walked in
;; lock-step with :parallel, which stops as soon as either one runs out.
(define-simple-check (check-matrix-close? eps m1 m2)
  (every?-ec (:parallel (:matrix a m1)
                        (:matrix b m2))
             (> (abs eps) (abs (- a b)))))
;; SchemeUnit check: passes when the scalars a and b differ by strictly less
;; than |eps|.
(define-simple-check (check-close? eps a b)
  (let ((diff (abs (- a b))))
    (< diff (abs eps))))
;; SchemeUnit check: passes when |a/b| is within 4 +/- 0.1, or within the
;; reciprocal interval (1/4.1, 1/3.9).
;;
;; The error of a second-order-accurate approximation shrinks by a factor of
;; 4 when the step is halved.  So with a = error at step h and b = error at
;; step h/2 the ratio is ~4, and with the arguments swapped it is ~1/4.  The
;; reciprocal window must be the inverse of the 4x window.  A window around
;; 0.5 would instead accept first-order (factor-of-2) behaviour.
;;
;; NOTE(review): an exact-zero b raises divide-by-zero; an inexact 0.0 gives
;; an infinite/NaN ratio and the check fails.
(define-simple-check (check-second-order? a b)
  (let ((ratio (abs (/ a b))))
    (or (< 3.9 ratio 4.1)
        (< (/ 1 4.1) ratio (/ 1 3.9)))))
;; The 4x4 metric diag(-1, 1, 1, 1): -1 in the first diagonal slot and +1 in
;; the other three.  A matrix M is in SO(3,1) only if M^T eta M = eta.
(define eta
  (matrix 4 4
          -1 0 0 0
           0 1 0 0
           0 0 1 0
           0 0 0 1))
;; Returns M^T * eta * M, computed as (M^T * eta) * M.  The result equals
;; eta exactly when M preserves the metric.
(define (eta-transform M)
  (let* ((Mt     (matrix-transpose M))
         (Mt-eta (matrix-mul Mt eta)))
    (matrix-mul Mt-eta M)))
;; One random parameter value: mzscheme's (random) sample from (0, 1),
;; shifted down by 0.5 so it lies in (-0.5, 0.5).
(define (random-param)
  (let ((u (random)))
    (- u 0.5)))
;; A fresh f64vector of six independent random-param samples, one per
;; transformation parameter.
(define (random-params)
  (f64vector-of-length-ec 6
                          (:range k 0 6)
                          (random-param)))
;; First-order update of a matrix: returns M + sum_k dps[k] * dMs[k].
;; dps is an f64vector of parameter increments and dMs a vector holding one
;; derivative matrix per parameter.  The two are walked in lock-step
;; (:parallel), so extra entries in the longer one are ignored.  Each scaled
;; term is added onto the running sum starting from M.
(define (M-+-dM M dps dMs)
  (fold-ec M
           (:parallel (:f64vector step dps)
                      (:vector dMk dMs))
           (matrix-scale dMk step)
           (lambda (term acc)
             (matrix-add term acc))))
;; The 4x4 identity: the expected product of a generator at p with the same
;; generator at -p.
(define id
  (matrix-identity 4))
;; SchemeUnit suite for SO31.ss.  It cross-checks params->matrix,
;; matrix->params, params->inverse-params, matrix->dM and params->dM against
;; each other and against the invariant M^T eta M = eta.  Inputs come from
;; (random), so every run samples different parameter values.
(define SO31-test-suite
  (test-suite
   "SO31.ss test suite"

   ;; With only parameter i non-zero, composing two transforms must match the
   ;; transform of the summed parameter vectors.
   (test-case
    "Additive when only one parameter is non-zero"
    (do-ec (:range i 6)
           (let* ((p1 (f64vector-of-length-ec 6 (:range j 6) (if (= i j) (random) 0.0)))
                  (p2 (f64vector-of-length-ec 6 (:range j 6) (if (= i j) (random) 0.0)))
                  (p3 (f64vector-add p1 p2))
                  (M1 (params->matrix p1))
                  (M2 (params->matrix p2))
                  (M3 (params->matrix p3)))
             (check-matrix-close? 1e-10 (matrix-mul M1 M2) M3))))

   ;; Each single-parameter constructor at p and at -p must multiply to id.
   (test-case
    "Double-check that boosts and rotations are additive"
    (do-ec (:list make-M (list Bx By Bz Rx Ry Rz))
           (let* ((p (random))
                  (M1 (make-M p))
                  (M2 (make-M (- p))))
             (check-matrix-close? 1e-10 (matrix-mul M1 M2) id))))

   (test-case
    "Each of the individual matrices preserves eta"
    (do-ec (:list make-M (list Rx Ry Rz Bx By Bz))
           (let* ((p (random-param))
                  (M (make-M p)))
             (check-matrix-close? 1e-10 eta (eta-transform M)))))

   ;; Six independent random samples of the full six-parameter transform.
   (test-case
    "Full LT Matrix Preserves eta"
    (do-ec (:range i 6)
           (let* ((p (random-params))
                  (M (params->matrix p)))
             (check-matrix-close? 1e-10 eta (eta-transform M)))))

   ;; With a single non-zero parameter, the inverse parameters are just -p.
   (test-case
    "params->inverse-params for single non-zero parameters"
    (do-ec (:range i 6)
           (let* ((p (f64vector-of-length-ec 6 (:range j 6) (if (= i j) (random) 0.0)))
                  (pinv (f64vector-scale p -1.0))
                  (pinv2 (params->inverse-params p)))
             ;; Element-wise comparison; the loop variables are named apart
             ;; from the vectors they range over.
             (do-ec (:parallel (:f64vector expected pinv)
                               (:f64vector actual pinv2))
                    (check-close? 1e-10 expected actual)))))

   (test-case
    "matrix->params and params->matrix are inverses"
    (do-ec (:range i 6)
           (let* ((p (random-params))
                  (M (params->matrix p))
                  (p2 (matrix->params M))
                  (M2 (params->matrix p2)))
             (check-matrix-close? 1e-10 M M2))))

   (test-case
    "params->inverse params agrees with matrix inverse"
    (do-ec (:range i 6)
           (let* ((p (random-params))
                  (M (params->matrix p))
                  (Minv (matrix-inverse M))
                  (pinv (params->inverse-params p))
                  (Mpinv (matrix->params Minv)))
             (do-ec (:parallel (:f64vector from-params pinv)
                               (:f64vector from-matrix Mpinv))
                    (check-close? 1e-10 from-params from-matrix)))))

   ;; The linearised update M0 + sum dp_k dM_k should have error quadratic in
   ;; the step.  Halving the step (dp2 = dp1 / 2) should therefore cut the
   ;; error by about 4.  This is checked for both derivative sources.
   (test-case
    "derivative gives linear approximation to increment"
    (do-ec (:range i 6)
           (let* ((eps 1e-6)
                  (p0 (random-params))
                  (p-inc (random-params))
                  (dp1 (f64vector-scale p-inc eps))
                  (dp2 (f64vector-scale p-inc (/ eps 2)))
                  (p1 (f64vector-add dp1 p0))
                  (p2 (f64vector-add dp2 p0))
                  (M0 (params->matrix p0))
                  (M1 (params->matrix p1))
                  (M2 (params->matrix p2))
                  (dMs-matrix (matrix->dM M0))
                  (dMs-params (params->dM p0))
                  (approx-M1-matrix (M-+-dM M0 dp1 dMs-matrix))
                  (approx-M2-matrix (M-+-dM M0 dp2 dMs-matrix))
                  (approx-M1-params (M-+-dM M0 dp1 dMs-params))
                  (approx-M2-params (M-+-dM M0 dp2 dMs-params))
                  (e1-matrix (matrix-norm (matrix-sub approx-M1-matrix M1)))
                  (e2-matrix (matrix-norm (matrix-sub approx-M2-matrix M2)))
                  (e1-params (matrix-norm (matrix-sub approx-M1-params M1)))
                  (e2-params (matrix-norm (matrix-sub approx-M2-params M2))))
             (check-second-order? e1-matrix e2-matrix)
             (check-second-order? e1-params e2-params))))))
) ; closes the enclosing (module SO31-test ...) form