;;; PLT Scheme Science Collection
;;; random-distributions/discrete.ss
;;; Copyright (c) 2004-2006 M. Douglas Williams
;;; This library is free software; you can redistribute it and/or
;;; modify it under the terms of the GNU Lesser General Public
;;; License as published by the Free Software Foundation; either
;;; version 2.1 of the License, or (at your option) any later version.
;;; This library is distributed in the hope that it will be useful,
;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;;; Lesser General Public License for more details.
;;; You should have received a copy of the GNU Lesser General Public
;;; License along with this library; if not, write to the Free
;;; Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
;;; 02111-1307 USA.
;;; -------------------------------------------------------------------
;;; This module implements a general discrete distribution.
;;; Version  Date      Description
;;; 1.0.0    09/28/04  Marked as ready for Release 1.0.  Added
;;;                    contracts for functions.  (Doug Williams)
;;; 1.1.0    04/18/06  Made random-discrete use a binary search.
;;;                    (Doug Williams)
;;; 1.1.1    04/20/06  Changed random-discrete to use Walker's O(1)
;;;                    algorithm.  (Doug Williams)

(module discrete mzscheme
  ;; Data Definition
  ;; discrete structure
  ;; An instance of the discrete structure, created by make-discrete,
  ;; represents a general discrete distribution.  Only the discrete?
  ;; function is exported.
  ;;
  ;; make-struct-type returns five values, so five names must be bound:
  ;;   struct:discrete      - the structure type descriptor
  ;;   discrete-constructor - builds an instance from its 5 field values
  ;;   discrete?            - predicate for instances
  ;;   discrete-field-ref   - generic positional accessor (instance, index)
  ;;   set-discrete-field!  - generic positional mutator
  ;; The type has no supertype (#f), 5 initialized fields and 0 automatic
  ;; fields.
  (define-values (struct:discrete
                  discrete-constructor
                  discrete?
                  discrete-field-ref
                  set-discrete-field!)
    (make-struct-type 'discrete #f 5 0))
  ;; Contracts
  ;; Each exported name is paired with the contract it must satisfy.
  ;; random-discrete has two arities, hence the case-> contract: the
  ;; random source argument is optional.
  ;; NOTE(review): random-source? is not defined in this file; it is
  ;; presumably provided by "../random-source.ss" -- confirm.
  (require (lib "contract.ss"))
  (provide/contract
   (discrete?
    (-> any/c boolean?))
   (make-discrete
    (-> (vectorof (>/c 0.0)) discrete?))
   (random-discrete
    (case-> (-> random-source? discrete? natural-number/c)
            (-> discrete? natural-number/c)))
   (discrete-pdf
    (-> discrete? integer? (real-in 0.0 1.0)))
   (discrete-cdf
    (-> discrete? integer? (real-in 0.0 1.0))))
  (require "../random-source.ss")
  ;; Named accessors for the five fields of a discrete instance, each
  ;; specialized from the generic positional accessor.  Field order
  ;; matches the constructor call in make-discrete: n a f p c.
  (define-values (discrete-n discrete-a discrete-f discrete-p discrete-c)
    (values (make-struct-field-accessor discrete-field-ref 0 'n)
            (make-struct-field-accessor discrete-field-ref 1 'a)
            (make-struct-field-accessor discrete-field-ref 2 'f)
            (make-struct-field-accessor discrete-field-ref 3 'p)
            (make-struct-field-accessor discrete-field-ref 4 'c)))
  ;; make-discrete: vector -> discrete
  ;; This function accepts a vector of weights and returns a discrete
  ;; probability structure that can be passed to random-discrete to
  ;; generate random variates.  The weights do not have to sum to 1.0.
  ;;
  ;; The returned structure holds, in order:
  ;;   n - the number of weights
  ;;   a - for each index, the alternate index returned by
  ;;       random-discrete when the uniform variate is not below f
  ;;   f - for each index, the cutoff compared against the uniform
  ;;       variate (stored as (cutoff + index) / n, see below)
  ;;   p - the normalized weights (pdf)
  ;;   c - the running sum of the normalized weights (cdf)
  ;;
  ;; w must be non-empty: an empty vector makes n zero and (/ 1.0 n)
  ;; divides by exact zero.
  (define (make-discrete w)
    (let* ((n (vector-length w))
           (a (make-vector n))
           (f (make-vector n))
           (sum 0.0)
           (cumm 0.0)
           (mean (/ 1.0 n))
           (smalls '())
           (bigs '())
           (e (make-vector n))
           (p (make-vector n))
           (c (make-vector n)))
      ;; find sum
      (do ((i 0 (+ i 1)))
          ((= i n) (void))
        (let ((wi (vector-ref w i)))
          (set! sum (+ sum wi))))
      ;; normalize weights and partition into bigs and smalls
      ;; also compute pdf and cdf values
      (do ((i 0 (+ i 1)))
          ((= i n) (void))
        (let* ((wi (vector-ref w i))
               (q (/ wi sum)))
          ;; normalize
          (vector-set! e i q)
          ;; compute pdf
          (vector-set! p i q)
          ;; compute cdf; clamp at 1.0 so that accumulated rounding
          ;; error cannot push a stored value outside [0.0, 1.0]
          (set! cumm (+ cumm q))
          (vector-set! c i (min 1.0 cumm))
          ;; partition
          (if (< q mean)
              (set! smalls (cons i smalls))
              (set! bigs (cons i bigs)))))
      ;; work through the smalls: each small s is paired with a big b
      ;; that donates enough probability to bring s up to the mean
      (let loop ()
        (when (not (null? smalls))
          (let ((s (car smalls)))
            (set! smalls (cdr smalls))
            (if (null? bigs)
                ;; no big left to pair with: s is its own alternate
                (begin
                  (vector-set! a s s)
                  (vector-set! f s 1.0))
                (let ((b (car bigs)))
                  (set! bigs (cdr bigs))
                  (vector-set! a s b)
                  (vector-set! f s (* n (vector-ref e s)))
                  ;; move the deficit d of s from b to s
                  (let ((d (- mean (vector-ref e s))))
                    (vector-set! e s (+ (vector-ref e s) d))
                    (vector-set! e b (- (vector-ref e b) d)))
                  ;; re-file b according to what it has left
                  (cond ((< (vector-ref e b) mean)
                         (set! smalls (cons b smalls)))
                        ((> (vector-ref e b) mean)
                         (set! bigs (cons b bigs)))
                        (else
                         (vector-set! a b b)
                         (vector-set! f b 1.0))))))
          (loop)))
      ;; work through remaining bigs
      (let loop ()
        (when (not (null? bigs))
          (let ((b (car bigs)))
            (set! bigs (cdr bigs))
            (vector-set! a b b)
            (vector-set! f b 1.0))
          (loop)))
      ;; apply Knuth's convention: store (f + i) / n so that
      ;; random-discrete can compare the uniform variate to f directly
      (do ((i 0 (+ i 1)))
          ((= i n) (void))
        (vector-set! f i (/ (+ (vector-ref f i) i) n)))
      ;; return the discrete
      (discrete-constructor n a f p c)))
  ;; random-discrete: random-source x discrete -> integer
  ;; random-discrete: discrete -> integer
  ;; This function returns a random variate from the given discrete
  ;; distribution given by d.
  ;;
  ;; A single uniform variate u selects a column c = floor(u * n).  The
  ;; result is c itself when u falls below the stored cutoff f for that
  ;; column (or the cutoff is exactly 1.0), and the column's alternate
  ;; index otherwise.  When no random source is supplied the current
  ;; one is used.
  ;; NOTE(review): random-uniform and current-random-source are not
  ;; defined in this file; presumably they come from
  ;; "../random-source.ss".  This code assumes u < 1.0 so that c is a
  ;; valid index -- confirm against random-uniform.
  (define random-discrete
    (case-lambda
      ((r d)
       (let* ((u (random-uniform r))
              (c (inexact->exact (floor (* u (discrete-n d)))))
              (f (vector-ref (discrete-f d) c)))
         (if (= f 1.0)
             c
             (if (< u f)
                 c
                 (vector-ref (discrete-a d) c)))))
      ((d)
       (random-discrete (current-random-source) d))))
  ;; discrete-pdf: discrete x integer -> real
  ;; This function computes the probability density p(k) at k for a
  ;; discrete distribution given by d.
  ;; Indices outside 0 .. n-1 carry no probability mass, so the result
  ;; there is 0.0; this guard also keeps k from being used as an
  ;; out-of-range vector index.
  (define (discrete-pdf d k)
    (let* ((p (discrete-p d))
           (n (vector-length p)))
      (if (or (< k 0)
              (>= k n))
          0.0
          (vector-ref p k))))
  ;; discrete-cdf: discrete x integer -> real
  ;; This function computes the cumulative distribution d(k) at k for a
  ;; discrete distribution given by d.
  ;; Below the first index the result is 0.0, and at or beyond n it is
  ;; 1.0.  The upper guard is >= rather than > because the vector only
  ;; has indices 0 .. n-1, so k = n must not reach vector-ref.
  (define (discrete-cdf d k)
    (let* ((c (discrete-c d))
           (n (vector-length c)))
      (cond ((< k 0)
             0.0)
            ((>= k n)
             1.0)
            (else
             (vector-ref c k)))))