;; find-optimal-join.ss
(module find-optimal-join mzscheme
  (require (lib "etc.ss")
           (lib "contract.ss")
           (planet "comprehensions.ss" ("dyoo" "srfi-alias.plt" 1))
           (planet "contract-utils.ss" ("cobbe" "contract-utils.plt" 3 0)))
  
  
  ;; make-table: number number -> table
  ;; Create a table whose values are initially all +inf.0.
  (define (make-table m n)
    ;; Allocate the outer vector of m rows first, then give every row
    ;; its own freshly allocated vector of n cells, each holding +inf.0,
    ;; so that no two rows share storage.
    (let ([rows (make-vector m)])
      (let fill-row ([i 0])
        (when (< i m)
          (vector-set! rows i (make-vector n +inf.0))
          (fill-row (add1 i))))
      rows))
  
  
  ;; table-ref: table number number -> any
  ;; Gets the (i,j)th entry in the table.
  (define (table-ref a-table i j)
    ;; Look up row i first, then read column j out of that row.
    (let ([row (vector-ref a-table i)])
      (vector-ref row j)))
  
  
  ;; table-set!: table number number any -> void
  ;; Sets the (i,j)th entry in the table.
  (define (table-set! a-table i j val)
    ;; Look up row i first, then overwrite column j of that row in place.
    (let ([row (vector-ref a-table i)])
      (vector-set! row j val)))
  
  
  ;; A little helper to do the math for running along the diagonals.
  ;; along-diagonals: natural-number (natural-number natural-number -> void) -> void
  ;; Applies a function along the diagonals of a table.
  ;; i.e. for N=3, we'll run through (i, j) = [(0, 0), (1, 1), (2, 2),
  ;;                                           (0, 1), (1, 2),
  ;;                                           (0, 2)]
  (define (along-diagonals N f)
    ;; `offset` is how far the current diagonal lies from the main one
    ;; (0 = main diagonal); `row` walks down that diagonal, so the
    ;; visited cell is (row, row + offset).  The diagonal at a given
    ;; offset has N - offset cells.
    (let next-diagonal ([offset 0])
      (when (< offset N)
        (let next-cell ([row 0])
          (when (< row (- N offset))
            (f row (+ offset row))
            (next-cell (add1 row))))
        (next-diagonal (add1 offset)))))
  
  
  ;; compute-recurrence: (vectorof X) (X -> number) (number number -> number) -> (values table table)
  ;;
  ;; Computes the optimal K values to solve a recurrence that
  ;; minimizes overall cost.
  ;;
  ;; Done with a dynamic programming technique.  Efficiency is O(n^3).
  ;; Idea: Let F[i, j] represent the minimum depth from joining
  ;; the t_i through t_j trees.  Then
  
  ;; F[i, j] = {
  ;;            depth[i] if i = j
  ;;
  ;;            otherwise:
  ;;            min on k of cost+(F[i, k], F[k+1, j]) where i <= k < j
  ;;           }
  ;;
  ;; We keep an auxiliary data structure K which represents the
  ;; point where we do the partitioning.
  ;; We return both F and K:
  ;;
  ;; F[i,j]: the cost of the optimal solution incorporating the
  ;; elements from i to j, inclusive.
  ;;
  ;; K[i, j]: the value of k chosen to minimize F[i, j].  Used
  ;; to reconstruct the optimal solution.
  (define (compute-recurrence forest-vec initial-cost-f cost+)
    (let* ([size (vector-length forest-vec)]
           ;; Both tables start out filled with +inf.0.
           [F (make-table size size)]
           [K (make-table size size)]
           ;; try-split!: number number number -> void
           ;; Evaluates splitting [i, j] into [i, k] and [k+1, j], and
           ;; records the split when it is strictly cheaper than the
           ;; best one seen so far (so ties keep the smallest k).
           [try-split!
            (lambda (i j k)
              (let ([candidate (cost+ (table-ref F i k)
                                      (table-ref F (add1 k) j))])
                (when (< candidate (table-ref F i j))
                  (table-set! F i j candidate)
                  (table-set! K i j k))))])
      ;; Walking diagonal by diagonal means every shorter interval has
      ;; already been filled in by the time a longer one consults it.
      (along-diagonals
       size
       (lambda (i j)
         (if (= i j)
             ;; Main diagonal: a single element costs its initial cost.
             (table-set! F i i
                         (initial-cost-f (vector-ref forest-vec i)))
             ;; Off-diagonal: try every split point k with i <= k < j.
             (let try-next ([k i])
               (when (< k j)
                 (try-split! i j k)
                 (try-next (add1 k)))))))
      (values F K)))
  
  
  ;; concatenate-with-K: (vectorof X) (X X -> X) table -> X
  ;; Using the optimal partitioning information from K, concatenate
  ;; the nodes together.
  (define (concatenate-with-K forest-vec join-f K)
    ;; join-range: number number -> X
    ;; Builds the joined value for the elements lo through hi, inclusive,
    ;; splitting at the point recorded in K for that interval.
    (define (join-range lo hi)
      (if (= lo hi)
          ;; A single element needs no joining.
          (vector-ref forest-vec lo)
          ;; let* keeps the original order: left half first, then right.
          (let* ([split (table-ref K lo hi)]
                 [left (join-range lo split)]
                 [right (join-range (add1 split) hi)])
            (join-f left right))))
    (join-range 0 (sub1 (vector-length forest-vec))))
  
  
  ;; join-forest: (listof X) (X X -> X) (X -> natural-number) -> X
  (define (join-forest forest join-f depth-f)
    ;; Delegates to join-forest/cost+ with a fixed combining rule: the
    ;; cost of a join is one more than the larger of its two operands'
    ;; costs.
    (join-forest/cost+ forest
                       join-f
                       depth-f
                       (lambda (left-cost right-cost)
                         (add1 (max left-cost right-cost)))))
  
  
  ;; join-forest/cost+: (listof X) (X X -> X) (X -> natural-number) (natural-number natural-number -> natural-number) -> X
  (define (join-forest/cost+ forest join-f initial-cost-f cost+)
    ;; Work on a vector copy of the list so elements can be indexed.
    (let ([forest-vec (list->vector forest)])
      ;; compute-recurrence returns the cost table and the split table;
      ;; only the split table is needed to rebuild the joined result.
      (let-values ([(costs splits)
                    (compute-recurrence forest-vec initial-cost-f cost+)])
        (concatenate-with-K forest-vec join-f splits))))
  
  
  ;; Public interface of this module.  Only the two join-forest entry
  ;; points are exported; the table helpers and the recurrence itself
  ;; stay private.
  ;;
  ;; NOTE(review): nelistof/c comes from the contract-utils library and
  ;; presumably means "non-empty list of" -- confirm against that
  ;; library.  concatenate-with-K starts from the interval
  ;; [0, length - 1], so an empty forest is not handled past this point.
  (provide/contract [join-forest
                     ;; forest, join-f, depth-f -> joined result
                     ((nelistof/c any/c)
                      (any/c any/c . -> . any)
                      (any/c . -> . natural-number/c)
                      . -> . any)]
                    
                    [join-forest/cost+
                     ;; forest, join-f, initial-cost-f, cost+ -> joined result
                     ((nelistof/c any/c)
                      (any/c any/c . -> . any)
                      (any/c . -> . natural-number/c)
                      (natural-number/c natural-number/c
                                        . -> . natural-number/c)
                      . -> . any)]))