;; array-ec: SRFI 42 eager comprehensions over SRFI 25 arrays.
;;
;;   (array-ec shape qualifier ... expr)
;;     Builds a fresh array with the given SRFI 25 SHAPE (via make-array) and
;;     stores successive values of EXPR into it in row-major order (last
;;     index varying fastest), honouring the shape's lower bounds.  Elements
;;     the comprehension does not reach keep make-array's initial contents
;;     (unspecified per SRFI 25).  Producing more values than the array holds
;;     signals an error.
;;
;;   (:array x arr)
;;   (:array x (index k ...) arr)
;;     SRFI 42 generator binding X to each element of ARR in row-major order;
;;     the optional (index k ...) form also binds K ... to the components of
;;     the current index.
(module array-ec mzscheme
  (require (lib "42.ss" "srfi")
           (lib "25.ss" "srfi")
           (only (lib "1.ss" "srfi") fold))
  (provide (all-from (lib "42.ss" "srfi"))
           (all-from (lib "25.ss" "srfi"))
           array-ec :array)

  (define-syntax array-ec
    (syntax-rules ()
      ((array-ec shp qualifier ... final-expr)
       (let* ((a (make-array shp))
              ;; Walk the array's own index space rather than a flattened
              ;; 0-based view: this respects nonzero lower bounds and avoids
              ;; handing share-array a non-affine index map.
              (next-index (make-index-generator a)))
         (do-ec qualifier ...
                ;; Evaluate the element first, then claim the next slot.
                (let* ((value final-expr)
                       (idx (next-index)))
                  (if (eq? idx 'done)
                      (error 'array-ec
                             "comprehension produced more elements than the array holds")
                      (array-set! a idx value))))
         a))))

  ;; (make-index-generator arr) -> thunk
  ;; Each call of the returned thunk yields the next valid index of ARR, in
  ;; row-major order, as a vector of length (array-rank arr); once every
  ;; index has been produced it returns the symbol 'done (and keeps doing so).
  ;; The SAME vector object is returned and mutated on every call, so callers
  ;; must consume it before calling the thunk again.  An array with any empty
  ;; dimension yields 'done immediately; a rank-0 array yields #() once.
  (define (make-index-generator arr)
    (let* ((r (array-rank arr))
           (lbs (vector-of-length-ec r (:range k r) (array-start arr k)))
           (ubs (vector-of-length-ec r (:range k r) (array-end arr k)))
           (idx-v (vector-of-length-ec r (:vector lb lbs) lb))
           ;; 'fresh   : next call returns the first index (all lower bounds)
           ;; 'running : next call advances idx-v
           ;; 'done    : exhausted
           (state (if (any?-ec (:range k r)
                               (>= (vector-ref lbs k) (vector-ref ubs k)))
                      'done
                      'fresh)))
      (lambda ()
        (case state
          ((done) 'done)
          ((fresh) (set! state 'running) idx-v)
          (else
           ;; Odometer step: bump the last component; on reaching its upper
           ;; bound, reset it to its lower bound and carry leftwards.
           (let carry ((k (- r 1)))
             (if (< k 0)
                 (begin (set! state 'done) 'done)
                 (let ((next (+ (vector-ref idx-v k) 1)))
                   (if (< next (vector-ref ubs k))
                       (begin (vector-set! idx-v k next) idx-v)
                       (begin (vector-set! idx-v k (vector-ref lbs k))
                              (carry (- k 1)))))))))))

  ;; SRFI 42 generator over the elements of an array, built on :do and
  ;; make-index-generator.
  (define-syntax :array
    (syntax-rules (index)
      ((:array cc x (index k0 ...) arr-expr)
       (:do cc
            ;; Outer bindings: evaluate ARR-EXPR once and build its generator
            ;; (set! is needed because these bindings are made in parallel).
            (let ((arr arr-expr)
                  (gen #f))
              (set! gen (make-index-generator arr)))
            ;; Loop variable: the current index vector, or 'done.
            ((idx-v (gen)))
            (not (eq? idx-v 'done))
            ;; Per-element bindings: X is the element; K0 ... receive the
            ;; successive components of the index vector.
            (let ((i 0)
                  (x (array-ref arr idx-v))
                  (k0 #f)
                  ...)
              (begin (set! k0 (vector-ref idx-v i))
                     (set! i (+ i 1)))
              ...)
            #t
            ((gen))))
      ((:array cc x arr-expr)
       (:array cc x (index) arr-expr)))))