contract.ss
#lang scheme

(require (for-syntax syntax/parse))

#|

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Contract Properties
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define-struct contract-impl [proj name first-order stronger])

(define (build-contract-impl proj
                             [name (default-contract-name proj)]
                             [first-order default-contract-first-order]
                             [stronger default-contract-stronger])
  (make-contract-impl proj name first-order stronger))

(define (default-contract-name c) (format "~a" c))
(define (default-contract-first-order c) (lambda (x) #t))
(define (default-contract-stronger a b) (eq? a b))

(define (contract-guard x y)
  (if (contract-impl? x)
      x
      (error 'contract-guard "invalid contract implementation: ~e" x)))

(define-values [ contract-prop contract-struct? contract->impl ]
  (make-struct-type-property
   'contract
   contract-guard
   (list (cons proj-prop contract-impl-proj)
         (cons name-prop contract-impl-name)
         (cons stronger-prop contract-impl-stronger)
         (cons first-order-prop contract-impl-first-order))))

(define-struct (flat-contract-impl contract-impl) [pred])

(define (build-flat-contract-impl first-order
                                  [name default-contract-name]
                                  [stronger default-contract-stronger])
  (make-flat-contract-impl (first-order->proj first-order)
                           name
                           first-order
                           stronger
                           first-order))

(define ((((first-order->proj first-order) c) pos neg src name where?) x)
  (if ((first-order c) x)
      x
      (raise-contract-error
       x src pos name
       "expected a/an ~a; got: ~e"
       name x)))

(define (flat-contract-guard x y)
  (if (flat-contract-impl? x)
      x
      (error 'flat-contract-guard
             "invalid flat contract implementation: ~e"
             x)))

(define (flat-contract-call i)
  (let ([get-pred (flat-contract-impl-pred i)])
    (lambda (c x)
      ((get-pred c) x))))

(define-values [ flat-contract-prop flat-contract-struct? flat-contract->impl ]
  (make-struct-type-property
   'flat-contract
   flat-contract-guard
   (list (cons contract-prop values)
         (cons flat-prop flat-contract-impl-pred)
         (cons prop:procedure flat-contract-call))))

|#

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Flat Contracts
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

#|

(define contract/predicate-impl
  (build-flat-contract-impl
   (lambda (c/p) (contract/predicate-test c/p))
   (lambda (c/p) (contract/predicate-name c/p))))

(define-struct contract/predicate [ name test ]
  #:property flat-contract-prop contract/predicate-impl)

(define contract-pred?
  (make-contract/predicate
   "<flat contract and predicate>"
   (lambda (x)
     (and (flat-contract? x)
          (procedure? x)
          (procedure-arity-includes? x 1)))))

(define nat?
  (make-contract/predicate "natural number" exact-nonnegative-integer?))

(define pos?
  (make-contract/predicate "positive integer" exact-positive-integer?))

(define truth?
  (make-contract/predicate "truth value" (lambda (x) #t)))

|#

;; Flat contracts for a few common domains:
;;   nat/c   - exact nonnegative integers
;;   pos/c   - exact positive integers
;;   truth/c - every value (any value may be used as a truth value)
(define-values [nat/c pos/c truth/c]
  (values (flat-named-contract '|natural number|
                               exact-nonnegative-integer?)
          (flat-named-contract '|positive integer|
                               exact-positive-integer?)
          (flat-named-contract '|truth value|
                               (lambda (ignored) #t))))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Function Contracts
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Procedure-shape contracts, named by arity and by the kind of result.
;; The plain variants require a boolean result. The "-like" variants check
;; the result against truth/c, which accepts every value.
(define-values [thunk/c
                unary/c
                binary/c
                predicate/c
                comparison/c
                predicate-like/c
                comparison-like/c]
  (values (-> any/c)
          (-> any/c any/c)
          (-> any/c any/c any/c)
          (-> any/c boolean?)
          (-> any/c any/c boolean?)
          (-> any/c truth/c)
          (-> any/c any/c truth/c)))

#|

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Polymorphic Contracts
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define-syntax (poly->/c stx)
  (syntax-parse stx
    [(_ (~optional (~seq #:gen gen:expr)
                   #:defaults ([gen #'memory/c]))
        [var:id ...]
        body:expr)
     #'(make-poly gen '(var ...) (lambda (var ...) body))]))

(define poly-proj
  (match-lambda
    [(struct poly [gen vars body])

     (lambda (pos neg src name position)
       (lambda (proc)

         (define (wrap-proc)

           (define c
             (coerce-contract
              'poly->/c
              (apply body (map (lambda (var) (gen var position)) vars))))

           ((((proj-get c) c) pos neg src name position) proc))

         (unless (procedure? proc)
           (raise-contract-error
            proc src pos name
            "expected a procedure, got: ~e"
            proc))

         (make-keyword-procedure
          (lambda (keys vals . args)
            (keyword-apply (wrap-proc) keys vals args))
          (case-lambda
            [() ((wrap-proc))]
            [(a) ((wrap-proc) a)]
            [(a b) ((wrap-proc) a b)]
            [(a b c) ((wrap-proc) a b c)]
            [(a b c d) ((wrap-proc) a b c d)]
            [(a b c d e) ((wrap-proc) a b c d e)]
            [(a b c d e f) ((wrap-proc) a b c d e f)]
            [(a b c d e f g) ((wrap-proc) a b c d e f g)]
            [(a b c d e f g h) ((wrap-proc) a b c d e f g h)]
            [args (apply (wrap-proc) args)]))))]))

(define (poly-name poly)
  `(poly->/c ,(poly-vars poly) ...))

(define (poly-stronger poly c) #f)

(define ((poly-first-order poly) v) #t)

(define poly-contract
  (make-contract-impl poly-proj poly-name poly-first-order poly-stronger))

(define-struct poly [gen vars body]
  #:property contract-prop poly-contract)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Anaphoric Contracts
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define (memory/c [name 'memory/c] [position #t])
  (make-memory name position (make-hasheq)))

(define (memory-proj memory)
  (lambda (pos neg src name position)
    (if (boolean=? (memory-position memory) position)
        (lambda (x)
          (if (hash-has-key? (memory-table memory) x)
              x
              (raise-contract-error
               x src pos name
               "expected a/an ~a, got: ~e"
               (memory-name memory)
               x)))
        (lambda (x)
          (hash-set! (memory-table memory) x #t)
          x))))

(define (memory-stronger memory c) #f)

(define ((memory-first-order memory) v) #t)

(define memory-contract
  (make-contract-impl memory-proj
                      (lambda (x) (memory-name x))
                      memory-first-order
                      memory-stronger))

(define-struct memory [name position table]
  #:property contract-prop memory-contract)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Parametric (Protect) Contracts
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define (protect/c [name 'protect/c] [position #t])

  (define-values [ type make pred getter setter ]
    (make-struct-type name #f 1 0))

  (make-protect name position make pred (lambda (x) (getter x 0))))

(define (protect-proj protect)
  (lambda (pos neg src name position)
    (if (boolean=? (protect-position protect) position)
        (lambda (x)
          (if (protect-pred x)
              (protect-get x)
              (raise-contract-error
               x src pos name
               "expected a/an ~a; got: ~e"
               (protect-name protect)
               x)))
        (lambda (x)
          (protect-make x)))))

(define (protect-stronger protect c) #f)

(define ((protect-first-order protect) v) #t)

(define protect-contract
  (make-contract-impl protect-proj
                      (lambda (x) (protect-name x))
                      protect-first-order
                      protect-stronger))

(define-struct protect [name position make pred get]
  #:property contract-prop protect-contract)

|#

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Contracted Sequences
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; sequence/c : Contract ... -> Contract
;; Contract for sequences whose every element consists of exactly
;; (length elem/cs) values. The i-th value of each element is checked
;; against the i-th element contract. Only `sequence?` is checked when the
;; contract is applied; elements are checked lazily as the wrapped sequence
;; is traversed. A wrong number of values per element is reported as a
;; contract violation blaming the positive party.
(define (sequence/c . elem/cs)
  (let* ([elem/cs (for/list ([elem/c (in-list elem/cs)])
                    (coerce-contract 'sequence/c elem/c))]
         [n-cs (length elem/cs)])
    (make-proj-contract
     (apply build-compound-type-name 'sequence/c elem/cs)
     (lambda (pos neg src name blame)
       ;; The element projections depend only on the blame information, so
       ;; build them once per blame context rather than once per element.
       (define elem-projs
         (for/list ([elem/c (in-list elem/cs)])
           (((proj-get elem/c) elem/c) pos neg src name blame)))
       (lambda (seq)
         (unless (sequence? seq)
           (raise-contract-error
            seq src pos name
            "expected a sequence, got: ~e"
            seq))
         (make-do-sequence
          (lambda ()
            (let-values ([(more? next) (sequence-generate seq)])
              (values
               ;; pos->element: pull the next element (possibly several
               ;; values) from the underlying generator and check each value.
               (lambda (idx)
                 (call-with-values next
                   (lambda elems
                     (define n-elems (length elems))
                     (unless (= n-elems n-cs)
                       (raise-contract-error
                        seq src pos name
                        "expected a sequence of ~a values, got ~a values: ~s"
                        n-cs n-elems elems))
                     (apply values
                            (for/list ([elem (in-list elems)]
                                       [elem-proj (in-list elem-projs)])
                              (elem-proj elem))))))
               ;; The position is a dummy; the generator holds the real state.
               (lambda (idx) idx)
               #f
               (lambda (idx) (more?))
               ;; These two receive ALL values of an element, so they must
               ;; accept any number of arguments (not exactly one); otherwise
               ;; multi-valued and zero-valued sequences fail with an arity
               ;; error.
               (lambda elems #t)
               (lambda (idx . elems) #t)))))))
     sequence?)))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Contracted Dictionaries
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; A CDict is (make-contracted-dictionary (Listof (Cons Proj Proj)) Dict)
;;   projections: one (key-proj . value-proj) pair per dict/c layer applied,
;;                outermost layer first (dict/c conses each new layer on front);
;;   bindings:    the underlying, uncontracted dictionary.
;; A Proj is (make-projection Contract Symbol Symbol Any Any Any), holding the
;; contract followed by the pos, neg, src, name and blame arguments that
;; dict/c received when the layer was created.
(define-struct contracted-dictionary [projections bindings])
(define-struct projection [contract out in source name blame])

;; dict/c : Contract Contract -> Contract
;; Contract for dictionaries whose keys satisfy key/c and whose values
;; satisfy value/c. A non-dictionary is rejected immediately. Otherwise the
;; dictionary is wrapped: one more (key . value) projection layer is pushed
;; in front of any layers it already carries, and keys/values are checked
;; lazily as they pass through the wrapper.
(define (dict/c key/c value/c)
  (let ([kc (coerce-contract 'dict/c key/c)])
    (let ([vc (coerce-contract 'dict/c value/c)])
      (make-proj-contract
       (build-compound-type-name 'dict/c kc vc)
       (lambda (pos neg src name blame)
         (lambda (dict)
           (when (not (dict? dict))
             (raise-contract-error dict src pos name
                                   "expected a dictionary, got: ~e"
                                   dict))
           (let* ([layer (cons (make-projection kc pos neg src name blame)
                               (make-projection vc pos neg src name blame))]
                  [layers (cons layer (dict->projections dict))])
             (wrap layers (dict->bindings dict)))))
       dict?))))

;; (cdict projs binds) matches a contracted dictionary (any subtype),
;; binding its projection-layer list and its underlying dictionary.
(define-match-expander cdict
  (syntax-rules ()
    [(_ projs-pat binds-pat)
     (struct contracted-dictionary [projs-pat binds-pat])]))

;; (proj c out in src name blame) matches a projection record field by field.
(define-match-expander proj
  (syntax-rules ()
    [(_ ctc-pat out-pat in-pat src-pat name-pat blame-pat)
     (struct projection [ctc-pat out-pat in-pat src-pat name-pat blame-pat])]))

;; -ref : CDict Any [Any] -> Any
;; dict-ref for contracted dictionaries. The key is projected inward through
;; every layer, looked up in the underlying bindings, and the found value is
;; projected back outward. With a failure argument, a missing key yields
;; (failure) if failure is a procedure and failure itself otherwise; that
;; result is NOT projected.
(define -ref
  (case-lambda
    [(dict key)
     (match dict
       [(cdict layers table)
        (value-out layers (dict-ref table (key-in layers key)))])]
    [(dict key failure)
     (match dict
       [(cdict layers table)
        (let ([inner-key (key-in layers key)])
          ;; Escape past value-out so the failure result bypasses projection.
          (let/ec escape
            (value-out layers
                       (dict-ref table
                                 inner-key
                                 (lambda ()
                                   (escape (if (procedure? failure)
                                               (failure)
                                               failure)))))))])]))

;; -set! : CDict Any Any -> Void
;; dict-set! for contracted dictionaries: project the key, then the value,
;; inward, and mutate the underlying bindings.
(define (-set! dict key value)
  (match dict
    [(cdict layers table)
     (let* ([inner-key (key-in layers key)]
            [inner-value (value-in layers value)])
       (dict-set! table inner-key inner-value))]))

;; -set : CDict Any Any -> CDict
;; Functional dict-set for contracted dictionaries. It projects the key, then
;; the value, inward, updates the underlying bindings functionally, and
;; re-wraps the result with the same projection layers.
(define (-set dict key value)
  (match dict
    [(cdict layers table)
     (let* ([inner-key (key-in layers key)]
            [inner-value (value-in layers value)]
            [updated (dict-set table inner-key inner-value)])
       (wrap layers updated))]))

;; -rem! : CDict Any -> Void
;; dict-remove! for contracted dictionaries: project the key inward, then
;; remove it from the underlying bindings in place.
(define (-rem! dict key)
  (match dict
    [(cdict layers table)
     (let ([inner-key (key-in layers key)])
       (dict-remove! table inner-key))]))

;; -rem : CDict Any -> CDict
;; Functional dict-remove for contracted dictionaries. It projects the key
;; inward, removes it from the underlying bindings, and re-wraps the result
;; with the same projection layers.
(define (-rem dict key)
  (match dict
    [(cdict layers table)
     (let* ([inner-key (key-in layers key)]
            [remaining (dict-remove table inner-key)])
       (wrap layers remaining))]))

;; -size : CDict -> Natural
;; dict-count for contracted dictionaries; delegates to the underlying
;; bindings without any projection.
(define (-size dict)
  (match dict
    [(cdict _ table) (dict-count table)]))

;; -fst : CDict -> Iterator
;; dict-iterate-first for contracted dictionaries; iterator positions are
;; those of the underlying bindings.
(define (-fst dict)
  (match dict
    [(cdict _ table) (dict-iterate-first table)]))

;; -nxt : CDict Iterator -> Iterator
;; dict-iterate-next for contracted dictionaries; delegates to the
;; underlying bindings.
(define (-nxt dict iter)
  (match dict
    [(cdict _ table) (dict-iterate-next table iter)]))

;; -key : CDict Iterator -> Any
;; dict-iterate-key for contracted dictionaries: fetch the underlying key
;; and project it outward through every layer.
(define (-key dict iter)
  (match dict
    [(cdict layers table)
     (let ([raw-key (dict-iterate-key table iter)])
       (key-out layers raw-key))]))

;; -val : CDict Iterator -> Any
;; dict-iterate-value for contracted dictionaries: fetch the underlying value
;; and project it outward through every layer.
(define (-val dict iter)
  (match dict
    [(cdict layers table)
     (let ([raw-value (dict-iterate-value table iter)])
       (value-out layers raw-value))]))

;; Projection layers are stored outermost-first, as (key-proj . value-proj)
;; pairs.
;;   *-in:  a key/value supplied by a client flows inward, so the layers are
;;          applied from the front of the list (outermost) to the back.
;;   *-out: a key/value read from the underlying bindings flows outward, so
;;          the layers are applied from the back of the list to the front.

;; key-in : (Listof (Cons Proj Proj)) Any -> Any
(define (key-in projs key)
  (for/fold ([k key]) ([layer (in-list projs)])
    (project-in (car layer) k)))

;; value-in : (Listof (Cons Proj Proj)) Any -> Any
(define (value-in projs value)
  (for/fold ([v value]) ([layer (in-list projs)])
    (project-in (cdr layer) v)))

;; key-out : (Listof (Cons Proj Proj)) Any -> Any
(define (key-out projs key)
  (foldr (lambda (layer k) (project-out (car layer) k)) key projs))

;; value-out : (Listof (Cons Proj Proj)) Any -> Any
(define (value-out projs value)
  (foldr (lambda (layer v) (project-out (cdr layer) v)) value projs))

;; project-in : Proj Any -> Any
;; Applies the projection's contract to a value flowing INTO the dictionary.
;; The stored out/in parties are passed swapped and the stored blame flag is
;; negated, relative to project-out.
(define (project-in p x)
  (match p
    [(proj ctc out-party in-party source cname flag)
     (let* ([projector ((proj-get ctc) ctc)]
            [checker (projector in-party out-party source cname (not flag))])
       (checker x))]))

;; project-out : Proj Any -> Any
;; Applies the projection's contract to a value flowing OUT of the
;; dictionary, using the stored parties and blame flag unchanged.
(define (project-out p x)
  (match p
    [(proj ctc out-party in-party source cname flag)
     (let* ([projector ((proj-get ctc) ctc)]
            [checker (projector out-party in-party source cname flag)])
       (checker x))]))

;; dict->bindings : Dict -> Dict
;; The underlying dictionary of a contracted dictionary, or the argument
;; itself when it is not contracted.
(define (dict->bindings dict)
  (if (contracted-dictionary? dict)
      (contracted-dictionary-bindings dict)
      dict))

;; dict->projections : Dict -> (Listof (Cons Proj Proj))
;; The projection layers of a contracted dictionary, or the empty list when
;; the argument is not contracted.
(define (dict->projections dict)
  (if (contracted-dictionary? dict)
      (contracted-dictionary-projections dict)
      null))

;; wrap : (Listof (Cons Proj Proj)) Dict -> CDict
;; Builds a contracted dictionary over binds, using the struct type whose
;; prop:dict operations match binds' capabilities.
(define (wrap projs binds)
  (let ([make-wrapper (dict->wrapper binds)])
    (make-wrapper projs binds)))

;; dict->wrapper : Dict -> ((Listof (Cons Proj Proj)) Dict -> CDict)
;; Chooses a constructor from three capabilities: mutable, functional set,
;; and key removal.
(define (dict->wrapper dict)
  (let* ([mutable? (dict-mutable? dict)]
         [functional? (dict-can-functional-set? dict)]
         [removable? (dict-can-remove-keys? dict)])
    (cond
      [(and mutable? functional? removable?) make-:!+-]
      [(and mutable? functional?)            make-:!+_]
      [(and mutable? removable?)             make-:!_-]
      [mutable?                              make-:!__]
      [(and functional? removable?)          make-:_+-]
      [functional?                           make-:_+_]
      [removable?                            make-:__-]
      [else                                  make-:___])))

;; prop:dict operation vectors, one per capability combination of the
;; underlying dictionary. Each vector lists, in prop:dict order:
;;   ref, set!, set, remove!, remove, count,
;;   iterate-first, iterate-next, iterate-key, iterate-value
;; with #f for every operation the underlying dictionary cannot support.
;; Name positions: mutable (!) or not (_); functional set (+) or not (_);
;; key removal (-) or not (_) -- remove! when mutable, remove otherwise.
;; The __- case (removal without functional or mutable update) is nonsensical,
;; so it gets the same read-only vector as ___.
(define prop:!+- (vector -ref -set! -set -rem! -rem -size -fst -nxt -key -val))
(define prop:!+_ (vector -ref -set! -set  #f    #f  -size -fst -nxt -key -val))
(define prop:!_- (vector -ref -set!  #f  -rem!  #f  -size -fst -nxt -key -val))
(define prop:!__ (vector -ref -set!  #f   #f    #f  -size -fst -nxt -key -val))
(define prop:_+- (vector -ref  #f   -set  #f   -rem -size -fst -nxt -key -val))
;; No key removal: unlike _+-, the remove slot must be #f here.
(define prop:_+_ (vector -ref  #f   -set  #f    #f  -size -fst -nxt -key -val))
(define prop:__- (vector -ref  #f    #f   #f    #f  -size -fst -nxt -key -val))
(define prop:___ (vector -ref  #f    #f   #f    #f  -size -fst -nxt -key -val))

;; Concrete contracted-dictionary struct types, one per capability
;; combination of the wrapped dictionary (selected by dict->wrapper). Each one
;; implements prop:dict with the identically named operation vector.
;; Name positions: mutable (!) or not (_); functional set (+) or not (_);
;; key removal (-) or not (_).
;; The __- case (removal without functional or mutable update) is nonsensical.
(define-struct (:!+- contracted-dictionary) [] #:property prop:dict prop:!+-)
(define-struct (:!+_ contracted-dictionary) [] #:property prop:dict prop:!+_)
(define-struct (:!_- contracted-dictionary) [] #:property prop:dict prop:!_-)
(define-struct (:!__ contracted-dictionary) [] #:property prop:dict prop:!__)
(define-struct (:_+- contracted-dictionary) [] #:property prop:dict prop:_+-)
(define-struct (:_+_ contracted-dictionary) [] #:property prop:dict prop:_+_)
(define-struct (:__- contracted-dictionary) [] #:property prop:dict prop:__-)
(define-struct (:___ contracted-dictionary) [] #:property prop:dict prop:___)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;;  Exports
;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Public, contract-protected exports of this module. Entries inside block
;; comments are disabled; sequence/c is defined in this module but its export
;; is currently disabled.
(provide/contract

 #|
 [contract-pred? contract-pred?]

 [nat? contract-pred?]
 [pos? contract-pred?]
 [truth? contract-pred?]
 |#

 ;; Flat contracts.
 [nat/c flat-contract?]
 [pos/c flat-contract?]
 [truth/c flat-contract?]

 ;; Procedure-shape contracts.
 [thunk/c contract?]
 [unary/c contract?]
 [binary/c contract?]
 [predicate/c contract?]
 [comparison/c contract?]
 [predicate-like/c contract?]
 [comparison-like/c contract?]

 #|
 [contract-prop struct-type-property?]
 [contract-struct? (-> any/c boolean?)]
 [contract-impl? (-> any/c boolean?)]
 [rename
  build-contract-impl make-contract-impl
  (->* [(-> contract-struct?
            (-> symbol? symbol? (or/c syntax? #f) string? boolean?
                (-> any/c
                    any/c)))]
       [(-> contract-struct? printable/c)
        (-> contract-struct? (-> any/c boolean?))
        (-> contract-struct? contract? boolean?)]
       contract-impl?)]

 [flat-contract-prop struct-type-property?]
 [flat-contract-struct? (-> any/c boolean?)]
 [flat-contract-impl? (-> any/c boolean?)]
 [rename
  build-flat-contract-impl make-flat-contract-impl
  (->* [(-> flat-contract-struct? (-> any/c boolean?))]
       [(-> flat-contract-struct? printable/c)
        (-> flat-contract-struct? contract? boolean?)]
       flat-contract-impl?)]

 [make-contract/predicate (-> string? predicate/c contract-pred?)]

 [memory/c (->* [] [symbol? boolean?] contract?)]
 [protect/c (->* [] [symbol? boolean?] contract?)]
 |#

 #|
 [sequence/c (->* [] [] #:rest (listof contract?) contract?)]
 |#
 ;; Contracted dictionaries.
 [dict/c (-> contract? contract? contract?)]
 )

#|
(provide poly->/c)
|#