dict.ss
#lang scheme

(require "define.ss" "contract.ss")

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  "Missing" Functions
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Fallback dict-has-key?, installed via define-if-unbound (from define.ss;
;; presumably only when no binding already exists -- confirm there).
(define-if-unbound dict-has-key?
  (let ()
    (with-contract
     dict-has-key?
     ([dict-has-key? (-> dict? any/c boolean?)])
     ;; dict-has-key? : Dict Any -> Boolean
     ;; Reports whether DICT binds KEY. The lookup's failure thunk escapes
     ;; with #f; if the lookup returns normally the key is present.
     (define (dict-has-key? dict key)
       (call/ec
        (lambda (escape)
          (dict-ref dict key (lambda () (escape #f)))
          #t))))
    dict-has-key?))

;; Fallback dict-ref!, installed via define-if-unbound (from define.ss;
;; presumably only when no binding already exists -- confirm there).
(define-if-unbound dict-ref!
  (let ()
    (with-contract
     dict-ref!
     ([dict-ref! (-> (and/c dict? dict-mutable?)
                     any/c
                     (or/c (-> any/c) any/c)
                     any/c)])
     ;; dict-ref! : MutableDict Any (U (-> Any) Any) -> Any
     ;; Returns DICT's value for KEY. On a miss, the default is FAILURE's
     ;; result when FAILURE is a procedure, or FAILURE itself otherwise;
     ;; that default is stored under KEY with dict-set! and then returned.
     (define (dict-ref! dict key failure)
       (define (compute-and-store!)
         (define value
           (cond [(procedure? failure) (failure)]
                 [else failure]))
         (dict-set! dict key value)
         value)
       (dict-ref dict key compute-and-store!)))
    dict-ref!))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Ref Wrappers
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; dict-ref/check : Dict Any -> Any
;; Plain lookup with no failure argument, so a missing KEY raises from
;; dict-ref. The exported contract additionally requires KEY to be present.
(define dict-ref/check
  (lambda (dict key)
    (dict-ref dict key)))

;; dict-ref/identity : Dict Any -> Any
;; Looks KEY up in DICT; when KEY is absent, KEY itself is the result.
;; (A thunk is used so a procedure-valued KEY is returned, not called.)
(define (dict-ref/identity dict key)
  (define (echo-key) key)
  (dict-ref dict key echo-key))

;; dict-ref/default : Dict Any Any -> Any
;; Looks KEY up in DICT, returning DEFAULT verbatim when KEY is absent.
;; DEFAULT is wrapped in a thunk so a procedure DEFAULT is not invoked.
(define (dict-ref/default dict key default)
  (let ([fallback (lambda () default)])
    (dict-ref dict key fallback)))

;; dict-ref/failure : Dict Any (-> Any) -> Any
;; Looks KEY up in DICT; when KEY is absent, returns the result of calling
;; the thunk FAILURE.
(define (dict-ref/failure dict key failure)
  (define (on-miss) (failure))
  (dict-ref dict key on-miss))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Extra Accessors
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; dict-domain : Dict -> (Listof Any)
;; Returns DICT's keys as a list, in the dictionary's iteration order.
;; Only keys are visited; values are never fetched.
(define (dict-domain dict)
  (reverse
   (for/fold ([keys '()]) ([k (in-dict-keys dict)])
     (cons k keys))))

;; dict-range : Dict -> (Listof Any)
;; Returns DICT's values as a list, in the dictionary's iteration order.
(define (dict-range dict)
  (reverse
   (for/fold ([vals '()]) ([v (in-dict-values dict)])
     (cons v vals))))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Union
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; dict-duplicate-error : Symbol -> (Any Any Any -> None)
;; Builds the default #:combine procedure for dict-union / dict-union!:
;; it raises an error, attributed to NAME, reporting the clashing key and
;; both values.
(define (dict-duplicate-error name)
  (lambda (key value1 value2)
    (error name "duplicate values for key ~e: ~e and ~e" key value1 value2)))

;; dict-union : Dict Dict ... [#:combine (Key Val Val -> Val)] -> Dict
;; Functionally merges each dictionary in REST into ONE, left to right.
;; When a key already exists, COMBINE receives the key, the value
;; accumulated so far, and the incoming value; by default this is an error.
(define (dict-union #:combine [combine (dict-duplicate-error 'dict-union)]
                    one . rest)
  (define (merge-one acc k v)
    (dict-set acc k
              (cond [(dict-has-key? acc k) (combine k (dict-ref acc k) v)]
                    [else v])))
  (for/fold ([acc one]) ([two (in-list rest)])
    (for/fold ([acc acc]) ([(k v) (in-dict two)])
      (merge-one acc k v))))

;; dict-union! : MutableDict Dict ... [#:combine (Key Val Val -> Val)] -> Void
;; Destructively merges each dictionary in REST into ONE, left to right,
;; resolving keys already present with COMBINE (an error by default).
(define (dict-union! #:combine [combine (dict-duplicate-error 'dict-union!)]
                     one . rest)
  (for ([two (in-list rest)])
    (for ([(k v) (in-dict two)])
      (dict-set! one k
                 (cond [(dict-has-key? one k) (combine k (dict-ref one k) v)]
                       [else v])))))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Contracted Dictionaries
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; A CDict is (make-contracted-dictionary (Listof (Cons Proj Proj)) Dict)
;;   projections: one (key-proj . value-proj) pair per dict/c layer; dict/c
;;                conses each new layer onto the front, so the outermost
;;                (most recently applied) layer comes first.
;;   bindings:    the wrapped dictionary that actually stores the data
;;                (dict/c unwraps an existing CDict before re-wrapping).
;; A Proj is (make-projection Contract Symbol Symbol Any Any)
;;   contract: the key or value contract for one layer.
;;   out / in: blame parties for values leaving / entering the dictionary;
;;             dict/c fills them with its pos / neg parties.
;;             NOTE(review): "Symbol" is the original author's typing of
;;             these parties -- confirm against the contract library.
;;   source:   passed to `contract` as source information.
;;   name:     recorded, but not read by project-in / project-out.
;; Field order matters: the cdict / proj match expanders match positionally.
(define-struct contracted-dictionary [projections bindings])
(define-struct projection [contract out in source name])

;; dict/c : Contract Contract -> Contract
;; A contract for dictionaries whose keys satisfy KEY/C and values satisfy
;; VALUE/C. Non-dictionaries are rejected immediately. A dictionary is
;; wrapped in a contracted-dictionary that checks keys and values lazily
;; as they flow in and out. An already-contracted dictionary is unwrapped
;; and gets the new layer pushed onto its existing projection stack.
(define (dict/c key/c value/c)
  (define key-ctc (coerce-contract 'dict/c key/c))
  (define value-ctc (coerce-contract 'dict/c value/c))
  (define (projector pos neg src name)
    (lambda (dict)
      (if (dict? dict)
          (wrap
           (cons (cons (make-projection key-ctc pos neg src name)
                       (make-projection value-ctc pos neg src name))
                 (dict->projections dict))
           (dict->bindings dict))
          (raise-contract-error dict src pos name
                                "expected a dictionary, got: ~e"
                                dict))))
  (make-proj-contract
   (build-compound-type-name 'dict/c key-ctc value-ctc)
   projector
   dict?))

;; Match pattern (cdict projections-pat bindings-pat): matches any
;; contracted-dictionary, including its capability subtypes, binding the
;; projection list and the wrapped dictionary.
(define-match-expander cdict
  (syntax-rules ()
    [(_ projections-pat bindings-pat)
     (struct contracted-dictionary (projections-pat bindings-pat))]))

;; Match pattern (proj contract out in source name): destructures a
;; projection struct field by field, in declaration order.
(define-match-expander proj
  (syntax-rules ()
    [(_ ctc-pat out-pat in-pat src-pat name-pat)
     (struct projection (ctc-pat out-pat in-pat src-pat name-pat))]))

;; prop:dict `ref` method for contracted dictionaries. KEY is projected
;; inward through every layer's key contract before the lookup, and a found
;; value is projected outward through the value contracts. With FAILURE, a
;; missing key yields (FAILURE) if FAILURE is a procedure, else FAILURE;
;; without it, the underlying dict-ref signals the miss.
(define -ref
  (case-lambda
    [(dict key)
     (match dict
       [(cdict projections table)
        (value-out projections (dict-ref table (key-in projections key)))])]
    [(dict key failure)
     (match dict
       [(cdict projections table)
        (let ([inner-key (key-in projections key)])
          (cond [(dict-has-key? table inner-key)
                 (value-out projections (dict-ref table inner-key))]
                [(procedure? failure) (failure)]
                [else failure]))])]))

;; prop:dict `set!` method: checks KEY and VALUE against every layer's
;; contracts, then mutates the wrapped dictionary in place.
(define (-set! dict key value)
  (match dict
    [(cdict projections table)
     (let ([k (key-in projections key)]
           [v (value-in projections value)])
       (dict-set! table k v))]))

;; prop:dict functional `set` method: checks KEY and VALUE, functionally
;; updates the wrapped dictionary, and re-wraps the result with the same
;; projection stack.
(define (-set dict key value)
  (match dict
    [(cdict projections table)
     (let* ([k (key-in projections key)]
            [v (value-in projections value)])
       (wrap projections (dict-set table k v)))]))

;; prop:dict `remove!` method: checks KEY, then removes it from the wrapped
;; dictionary in place.
(define (-rem! dict key)
  (match dict
    [(cdict projections table)
     (let ([k (key-in projections key)])
       (dict-remove! table k))]))

;; prop:dict functional `remove` method: checks KEY, removes it from the
;; wrapped dictionary functionally, and re-wraps the result.
(define (-rem dict key)
  (match dict
    [(cdict projections table)
     (let ([k (key-in projections key)])
       (wrap projections (dict-remove table k)))]))

;; prop:dict `count` method: delegates to the wrapped dictionary.
(define (-size dict)
  (match dict
    [(cdict _ table) (dict-count table)]))

;; prop:dict `iterate-first` method: the wrapped dictionary's first position.
(define (-fst dict)
  (match dict
    [(cdict _ table) (dict-iterate-first table)]))

;; prop:dict `iterate-next` method: advances ITER in the wrapped dictionary.
(define (-nxt dict iter)
  (match dict
    [(cdict _ table) (dict-iterate-next table iter)]))

;; prop:dict `iterate-key` method: fetches the key at ITER and projects it
;; outward through every layer's key contract.
(define (-key dict iter)
  (match dict
    [(cdict projections table)
     (let ([raw-key (dict-iterate-key table iter)])
       (key-out projections raw-key))]))

;; prop:dict `iterate-value` method: fetches the value at ITER and projects
;; it outward through every layer's value contract.
(define (-val dict iter)
  (match dict
    [(cdict projections table)
     (let ([raw-value (dict-iterate-value table iter)])
       (value-out projections raw-value))]))

;; key-in : (Listof (Cons Proj Proj)) Any -> Any
;; Projects KEY as it enters the dictionary, applying each layer's key
;; projection (car of each pair) from the head of PROJS (outermost) inward.
(define (key-in projs key)
  (for/fold ([k key]) ([layer (in-list projs)])
    (project-in (car layer) k)))

;; value-in : (Listof (Cons Proj Proj)) Any -> Any
;; Projects VALUE as it enters the dictionary, applying each layer's value
;; projection (cdr of each pair) from the head of PROJS (outermost) inward.
(define (value-in projs value)
  (for/fold ([v value]) ([layer (in-list projs)])
    (project-in (cdr layer) v)))

;; key-out : (Listof (Cons Proj Proj)) Any -> Any
;; Projects KEY as it leaves the dictionary. The innermost layer's key
;; projection is applied first and the head of PROJS (outermost) last.
(define (key-out projs key)
  (foldr (lambda (layer k) (project-out (car layer) k))
         key
         projs))

;; value-out : (Listof (Cons Proj Proj)) Any -> Any
;; Projects VALUE as it leaves the dictionary. The innermost layer's value
;; projection is applied first and the head of PROJS (outermost) last.
;; Each pair is (key-proj . value-proj), so values use the cdr. Applying the
;; car (the key contract) here would check values against the key contract.
(define (value-out projs value)
  (if (null? projs)
      value
      (project-out (cdar projs) (value-out (cdr projs) value))))

;; project-in : Proj Any -> Any
;; Applies P's contract to X as it flows INTO the dictionary: the `in`
;; party takes positive blame and the `out` party takes negative blame.
(define (project-in p x)
  (match p
    [(proj ctc out-party in-party src _)
     (contract ctc x in-party out-party src)]))

;; project-out : Proj Any -> Any
;; Applies P's contract to X as it flows OUT of the dictionary: the `out`
;; party takes positive blame and the `in` party takes negative blame.
(define (project-out p x)
  (match p
    [(proj ctc out-party in-party src _)
     (contract ctc x out-party in-party src)]))

;; dict->bindings : Dict -> Dict
;; The dictionary underneath a contracted dictionary; any other dictionary
;; is returned unchanged.
(define (dict->bindings dict)
  (if (contracted-dictionary? dict)
      (contracted-dictionary-bindings dict)
      dict))

;; dict->projections : Dict -> (Listof (Cons Proj Proj))
;; The projection stack of a contracted dictionary; the empty list for any
;; other dictionary.
(define (dict->projections dict)
  (if (contracted-dictionary? dict)
      (contracted-dictionary-projections dict)
      null))

;; wrap : (Listof (Cons Proj Proj)) Dict -> CDict
;; Builds a contracted dictionary over BINDS carrying the projection stack
;; PROJS. The concrete wrapper struct is chosen from BINDS' own capabilities
;; (mutable? / functional set? / key removal?), so the wrapper exposes
;; exactly the operations the underlying dictionary supports.
;; Takes two arguments, matching every caller (dict/c, -set, -rem). The old
;; three-argument signature made each of those calls an arity error.
(define (wrap projs binds)
  ((dict->wrapper binds) projs binds))

;; dict->wrapper : Dict -> ((Listof (Cons Proj Proj)) Dict -> CDict)
;; Picks the wrapper-struct constructor matching DICT's capabilities.
;; In the name, `!` = mutable, `+` = functional set, `-` = key removal,
;; and `_` = capability absent.
(define (dict->wrapper dict)
  (let ([mutable? (dict-mutable? dict)]
        [functional? (dict-can-functional-set? dict)]
        [removable? (dict-can-remove-keys? dict)])
    (cond [(and mutable? functional? removable?) make-:!+-]
          [(and mutable? functional?) make-:!+_]
          [(and mutable? removable?) make-:!_-]
          [mutable? make-:!__]
          [(and functional? removable?) make-:_+-]
          [functional? make-:_+_]
          [removable? make-:__-]
          [else make-:___])))

;; prop:dict method tables for the wrapper structs. Slots, in order: ref,
;; set!, set, remove!, remove, count, iterate-first, iterate-next,
;; iterate-key, iterate-value. #f marks an operation the wrapper does not
;; support. In each name, `!` = mutable (set!, and remove! when `-`),
;; `+` = functional set (set, and remove when `-`), `-` = key removal, and
;; `_` = capability absent.
;; The __- case (removal without functional or mutable update) is nonsensical,
;; so it gets the same table as ___.
(define prop:!+- (vector -ref -set! -set -rem! -rem -size -fst -nxt -key -val))
(define prop:!+_ (vector -ref -set! -set  #f    #f  -size -fst -nxt -key -val))
(define prop:!_- (vector -ref -set!  #f  -rem!  #f  -size -fst -nxt -key -val))
(define prop:!__ (vector -ref -set!  #f   #f    #f  -size -fst -nxt -key -val))
(define prop:_+- (vector -ref  #f   -set  #f   -rem -size -fst -nxt -key -val))
;; Functional set WITHOUT removal: the remove slot must be #f. Otherwise the
;; wrapper would advertise dict-remove for a dictionary that cannot remove keys.
(define prop:_+_ (vector -ref  #f   -set  #f    #f  -size -fst -nxt -key -val))
(define prop:__- (vector -ref  #f    #f   #f    #f  -size -fst -nxt -key -val))
(define prop:___ (vector -ref  #f    #f   #f    #f  -size -fst -nxt -key -val))

;; Concrete wrapper struct types, one per capability combination. Each is a
;; subtype of contracted-dictionary with the same-named prop:dict method
;; table; dict->wrapper chooses among their constructors (make-:!+- etc.).
;; The __- case (removal without functional or mutable update) is nonsensical.
(define-struct (:!+- contracted-dictionary) [] #:property prop:dict prop:!+-)
(define-struct (:!+_ contracted-dictionary) [] #:property prop:dict prop:!+_)
(define-struct (:!_- contracted-dictionary) [] #:property prop:dict prop:!_-)
(define-struct (:!__ contracted-dictionary) [] #:property prop:dict prop:!__)
(define-struct (:_+- contracted-dictionary) [] #:property prop:dict prop:_+-)
(define-struct (:_+_ contracted-dictionary) [] #:property prop:dict prop:_+_)
(define-struct (:__- contracted-dictionary) [] #:property prop:dict prop:__-)
(define-struct (:___ contracted-dictionary) [] #:property prop:dict prop:___)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Exports
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Exported without provide/contract. When defined in this module, these
;; two carry their own contracts via with-contract.
(provide dict-has-key? dict-ref!)
;; Everything else is exported under the contracts below.
(provide/contract
 [dict/c (-> contract? contract? contract?)]
 [dict-ref/identity (-> dict? any/c any/c)]
 [dict-ref/default (-> dict? any/c any/c any/c)]
 [dict-ref/failure (-> dict? any/c (-> any/c) any/c)]
 ;; The precondition requires KEY to be bound in TABLE before the lookup.
 [dict-ref/check
  (->d ([table dict?] [key any/c]) ()
       #:pre-cond (dict-has-key? table key)
       [_ any/c])]
 [dict-domain (-> dict? list?)]
 [dict-range (-> dict? list?)]
 ;; dict-union needs a functionally-updatable first dictionary;
 ;; dict-union! needs a mutable one and returns void.
 [dict-union (->* [(and/c dict? dict-can-functional-set?)]
                  [#:combine (-> any/c any/c any/c any/c)]
                  #:rest (listof dict?)
                  (and/c dict? dict-can-functional-set?))]
 [dict-union! (->* [(and/c dict? dict-mutable?)]
                   [#:combine (-> any/c any/c any/c any/c)]
                   #:rest (listof dict?)
                   void?)])