#| bh-tree.scm: Barnes-Hut octtrees.
Copyright (C) 2007 Will M. Farr

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|#

(module bh-tree mzscheme
  (require (planet "nbody-ics.scm" ("wmfarr" "nbody-ics.plt"))
           (lib "42.ss" "srfi")
           (lib "contract.ss")
           (lib "math.ss"))
  ;; cell: an interior node of the octtree, as built by bodies->tree.
  ;;   m         -- total mass of the cell's sub-trees
  ;;   q         -- mass-weighted centre of the cell's sub-trees
  ;;   bounds    -- the bounds record the cell was split from
  ;;   sub-trees -- vector of the cell's eight child trees, one per octant
  ;; (The trailing #f is define-struct's inspector argument.)
  (define-struct cell (m q bounds sub-trees) #f)
  ;; bounds: an axis-aligned box given by its low and high corner vectors.
  ;; in-bounds? treats low as inclusive and high as exclusive.
  (define-struct bounds (low high) #f)
  (provide (struct bounds (low high))
           (struct cell (m q bounds sub-trees)))
  ;; The empty tree is represented by the empty list.
  (define empty-tree? null?)
  ;; tree? : any -> boolean
  ;; A tree is either the empty tree, a single body (a leaf), or a cell.
  (define (tree? obj)
    (if (empty-tree? obj)
        #t
        (or (body? obj)
            (cell? obj))))
  ;; Exported operations; each contract is checked at the module boundary.
  (provide/contract
   ;; Building and recognising trees.
   (make-empty (-> empty-tree?))
   (empty-tree? (-> tree? boolean?))
   (tree? (-> any/c boolean?))
   (bodies->tree (-> nbody-system/c tree?))
   ;; Per-node summaries and geometry.
   (tree-m (-> tree? (>=/c 0.0)))
   (tree-q (-> tree? 3vector/c))
   (tree-size-squared (-> tree? (>=/c 0.0)))
   (in-bounds? (-> bounds? 3vector/c boolean?))
   ;; Traversal.
   (tree-fold (-> (-> tree? any/c any) any/c tree? any))
   (tree-fold/sc (-> (-> tree? any) (-> tree? any/c any) any/c tree? any)))
  ;; make-empty : -> empty-tree
  ;; Returns the empty tree (the empty list).
  (define make-empty
    (lambda () '()))
  ;; tree-m : tree -> real
  ;; Mass of a tree: 0.0 for the empty tree, body-m for a single body, and
  ;; the total mass stored in a cell.  Signals an error for any other value.
  (define (tree-m t)
    (cond
      ((empty-tree? t) 0.0)
      ((body? t) (body-m t))
      ((cell? t) (cell-m t))
      (else (error 'tree-m "argument not a tree: ~a" t))))
  ;; tree-q : tree -> 3-vector
  ;; Position of a tree: a fresh zero 3-vector for the empty tree, body-q for
  ;; a single body, and the q (centre of mass) stored in a cell.  Signals an
  ;; error for any other value.
  (define (tree-q t)
    (cond
      ((empty-tree? t) (make-vector 3 0.0))
      ((body? t) (body-q t))
      ((cell? t) (cell-q t))
      (else (error 'tree-q "argument not a tree: ~a" t))))
  ;; in-bounds? : bounds vector -> boolean
  ;; True when, on every axis, the coordinate of v lies in the half-open
  ;; interval [low, high) of bds.
  (define (in-bounds? bds v)
    (let ((lows (bounds-low bds))
          (highs (bounds-high bds)))
      (every?-ec (:parallel (:vector coord v)
                            (:vector lo lows)
                            (:vector hi highs))
                 (and (>= coord lo)
                      (> hi coord)))))
  ;; Relative padding used by expand-bounds.
  (define *epsilon* 1e-6)
  ;; expand-bounds : bounds -> bounds
  ;; Returns a new bounds record with the same low corner and every upper
  ;; coordinate pushed strictly above the old one.  A point on the old
  ;; (closed) box therefore satisfies in-bounds?, whose upper edge is
  ;; exclusive, against the result.
  ;; Normally the extent of each axis is scaled by (1 + *epsilon*).  When that
  ;; cannot move the upper coordinate (zero extent, e.g. all bodies share that
  ;; coordinate, or an extent lost to rounding), the coordinate is instead
  ;; padded by *epsilon* scaled to its magnitude.  Otherwise the axis would
  ;; stay zero-width and no body would ever be in bounds.
  (define expand-bounds
    (let ((efactor (+ 1.0 *epsilon*)))
      (lambda (bds)
        (let ((low (bounds-low bds))
              (high (bounds-high bds)))
          (make-bounds
           low
           (vector-of-length-ec (vector-length high)
             (:parallel (:vector lx low)
                        (:vector ux high))
             (let ((expanded (+ lx (* (- ux lx) efactor))))
               (if (> expanded ux)
                   expanded
                   ;; Relative scaling could not move ux; pad it absolutely.
                   (+ ux (* *epsilon* (max 1.0 (abs ux))))))))))))
  ;; bodies->bounds : vector-of-bodies -> bounds
  ;; Computes the tight axis-aligned box around the positions (body-q) of the
  ;; bodies in bs. Returns that box passed through expand-bounds, so the
  ;; extreme bodies are not lost to in-bounds?'s exclusive upper edge.
  (define (bodies->bounds bs)
    (let ((lo (make-vector 3 +inf.0))
          (hi (make-vector 3 -inf.0)))
      ;; Widen lo/hi coordinate-by-coordinate to cover every body position.
      (do-ec (:vector b bs)
             (:parallel (:vector x (index axis) (body-q b))
                        (:vector cur-lo lo)
                        (:vector cur-hi hi))
             (begin
               (when (< x cur-lo) (vector-set! lo axis x))
               (when (> x cur-hi) (vector-set! hi axis x))))
      (expand-bounds (make-bounds lo hi))))
  ;; high? : integer integer -> boolean
  ;; True when bit j of the octant index i is set.  sub-bounds uses this to
  ;; place octant i in the upper half along axis j.
  (define (high? i j)
    (not (zero? (bitwise-and i (arithmetic-shift 1 j)))))
  ;; sub-bounds : bounds -> vector of 8 bounds
  ;; Cuts bds at its midpoint into eight octant boxes.  Along axis j, octant
  ;; i covers the upper half [mid, high) when bit j of i is set (see high?)
  ;; and the lower half [low, mid) otherwise.
  (define (sub-bounds bds)
    (let* ((low (bounds-low bds))
           (high (bounds-high bds))
           (mid (vector-of-length-ec 3
                  (:parallel (:vector l low) (:vector h high))
                  (/ (+ l h) 2)))
           (octant
            (lambda (i)
              (make-bounds
               (vector-of-length-ec 3
                 (:parallel (:vector l (index j) low) (:vector m mid))
                 (if (high? i j) m l))
               (vector-of-length-ec 3
                 (:parallel (:vector m (index j) mid) (:vector h high))
                 (if (high? i j) h m))))))
      (vector-of-length-ec 8 (:range i 8)
        (octant i))))
  ;; split-bodies : vector-of-bodies bounds -> vector of 8 vectors of bodies
  ;; Distributes the bodies of bs among the eight octants of bds (see
  ;; sub-bounds).  Element i holds, in their original order, the bodies
  ;; whose position (body-q) satisfies in-bounds? for octant i.
  (define (split-bodies bs bds)
    (let ((sub-bds (sub-bounds bds)))
      (vector-of-length-ec 8 (:vector sb sub-bds)
        ;; SRFI-42: (if test) is a filter qualifier; b is the element kept.
        (vector-ec (:vector b bs)
                   (if (in-bounds? sb (body-q b)))
                   b))))
  ;; trees-total-mass : vector-of-trees -> real
  ;; Sum of tree-m over the trees in ts (exact 0 when ts is empty).
  (define (trees-total-mass ts)
    (fold-ec 0 (:vector sub ts) (tree-m sub) +))
  ;; trees-center-of-mass : vector-of-trees -> 3-vector
  ;; Mass-weighted mean of the positions (tree-q) of the trees in ts.  Trees
  ;; with zero mass (e.g. empty trees) contribute nothing.  When the total
  ;; mass is zero a fresh zero 3-vector is returned, matching tree-q of the
  ;; empty tree.
  (define (trees-center-of-mass ts)
    (let ((M (trees-total-mass ts))
          (com (make-vector 3 0.0)))
      (if (= M 0.0)
          com
          (begin
            ;; Accumulate sum of (m/M) * q into com, in place.
            (do-ec (:vector t ts)
                   (let ((m (tree-m t)))
                     (when (> m 0.0)
                       (let ((factor (/ m M))
                             (q (tree-q t)))
                         (do-ec (:parallel (:vector comx (index i) com)
                                           (:vector qx q))
                                (vector-set! com i (+ comx (* qx factor))))))))
            com))))
  ;; bodies->tree : vector-of-bodies -> tree
  ;; Builds a Barnes-Hut octtree over the bodies in bs:
  ;;   - no bodies -> the empty tree;
  ;;   - one body  -> that body itself (a leaf);
  ;;   - otherwise -> a cell holding the total mass and centre of mass of
  ;;                  its eight sub-trees, the bounds computed by
  ;;                  bodies->bounds, and the sub-trees built recursively
  ;;                  from the bodies falling in each octant of those bounds.
  ;; Signals an error when a split makes no progress (every body lands in the
  ;; same octant, e.g. all bodies share one position).  Recursing on the
  ;; same set would otherwise never terminate.
  (define (bodies->tree bs)
    (let ((n (vector-length bs)))
      (cond
        ((= n 0) (make-empty))
        ((= n 1) (vector-ref bs 0))
        (else
         (let* ((bds (bodies->bounds bs))
                (sub-bs (split-bodies bs bds)))
           (when (any?-ec (:vector sub sub-bs) (= (vector-length sub) n))
             (error 'bodies->tree "cannot separate bodies (coincident positions?): ~a" bs))
           (let ((sub-trees (vector-of-length-ec 8 (:vector sub sub-bs)
                              (bodies->tree sub))))
             (make-cell (trees-total-mass sub-trees)
                        (trees-center-of-mass sub-trees)
                        bds
                        sub-trees)))))))
  ;; tree-fold/sc : (tree -> any) (tree any -> any) any tree -> any
  ;; Folds f over t, threading an accumulator that starts at start, with an
  ;; early cut-off test cut?:
  ;;   - the empty tree leaves the accumulator unchanged;
  ;;   - a body (leaf) is passed to f as (f body acc);
  ;;   - a cell for which (cut? cell) is true is passed to f as a whole, and
  ;;     its children are not visited;
  ;;   - any other cell is descended into, folding over its sub-trees in
  ;;     vector order.
  ;; cut? is consulted only on cells, never on bodies or the empty tree.
  (define (tree-fold/sc cut? f start t)
    (cond
      ((empty-tree? t) start)
      ((body? t) (f t start))
      ((cut? t) (f t start))
      (else (let ((sub-ts (cell-sub-trees t))
                  ;; fold-ec calls its combiner as (tf element acc).
                  (tf (lambda (t acc) (tree-fold/sc cut? f acc t))))
              (fold-ec start (:vector st sub-ts) st tf)))))
  ;; tree-fold : (tree any -> any) any tree -> any
  ;; Folds f over every body (leaf) of t, starting from start.  This is
  ;; tree-fold/sc with a cut-off test that never fires, so every cell is
  ;; descended into.
  (define (tree-fold f start t)
    (let ((never-cut (lambda (tree) #f)))
      (tree-fold/sc never-cut f start t)))
  ;; tree-size-squared : tree -> nonnegative real
  ;; Squared length of the diagonal of the bounds stored in a cell, i.e. the
  ;; sum over axes of (high - low)^2.  0.0 for the empty tree and for a single
  ;; body.  Signals an error for a non-tree, like tree-m and tree-q.
  (define (tree-size-squared t)
    (cond
      ((empty-tree? t) 0.0)
      ((body? t) 0.0)
      ((cell? t)
       (let* ((bds (cell-bounds t))
              (low (bounds-low bds))
              (high (bounds-high bds)))
         (sum-ec (:parallel (:vector lx low)
                            (:vector hx high))
                 (sqr (- hx lx)))))
      ;; The final extra paren closes (module bh-tree ...).
      (else (error 'tree-size-squared "argument not a tree: ~a" t)))))