semantics/wrap-prim.rkt
#lang racket

#|

File: semantics/wrap-prim.rkt
Author: Bill Turtle (wrturtle)

|#

(require plai/datatype)
(require "../utilities.rkt")
(provide (all-defined-out))

;; Arity: how many arguments a wrapped primitive accepts.
;;   (at-least n) -- n or more arguments
;;   (at-most n)  -- n or fewer arguments
;;   (exactly n)  -- precisely n arguments
;;   (variable)   -- any number of arguments (check-arity skips the check)
;; Counts are exact nonnegative integers.  0 is allowed so that nullary
;; primitives can be described with (exactly 0).  Non-counts such as 1.5
;; are rejected.
(define-type Arity
  [at-least (n exact-nonnegative-integer?)]
  [at-most (n exact-nonnegative-integer?)]
  [exactly (n exact-nonnegative-integer?)]
  [variable])

;; sig: a named argument signature.
;;   name -- human-readable description of the expected value; check-arg
;;           passes it to msg-unexpected when an argument is rejected
;;   proc -- predicate applied to the argument; a true result accepts it
(define-struct/contract sig ((name string?) (proc procedure?)) #:transparent)

;; Predefined signatures used as per-argument contracts for wrapped primitives.
(define any-sig (sig "anything" any/c))
(define integer-sig (sig "integer" integer?))
;; Guard with real? rather than number? before comparing.  >= raises on
;; complex numbers (e.g. 1+2i), which would escape as a Racket exception
;; instead of being reported as a signature mismatch.
(define nonnegative-sig
  (sig "nonnegative number" (lambda (c) (and (real? c) (>= c 0)))))
(define real-sig (sig "real number" real?))
(define rat-sig (sig "rational number" rational?))
(define number-sig (sig "number" number?))
(define boolean-sig (sig "boolean" boolean?))
(define list-sig (sig "list" list?))
(define string-sig (sig "string" string?))
(define real-positive-sig (sig "real positive number" (and/c real? positive?)))

;; 3-equal? : any any any -> boolean
;; True when `one` equals `two` and `two` equals `three` under equal?.
(define (3-equal? one two three)
  (cond
    [(not (equal? one two)) #f]
    [else (equal? two three)]))

;; gen-args-list : natural (listof any) -> (listof any)
;; Produces a list of exactly `desired-size` contracts.  Contracts are taken
;; from `l-of-contracts` in order.  Once that list runs out, the last
;; contract taken is repeated.  If `l-of-contracts` is empty, the padding
;; value is #f.  Extra contracts beyond `desired-size` are dropped.
(define (gen-args-list desired-size l-of-contracts)
  (let fill ([remaining desired-size]
             [pending l-of-contracts]
             [fallback #f])
    (cond
      [(zero? remaining) empty]
      [(empty? pending)
       ;; Out of contracts: keep repeating the most recent one.
       (cons fallback (fill (sub1 remaining) pending fallback))]
      [else
       (define head (first pending))
       (cons head (fill (sub1 remaining) (rest pending) head))])))

;; check-arg-values : symbol (listof any) (listof sig) (listof any) -> void
;; Checks each argument in `args` against its signature, using check-arg
;; for the primitive `name`.  Signatures are stretched or trimmed to the
;; argument count by gen-args-list, so the last signature covers any extra
;; arguments.  `locs` is paired element-wise with `args`; presumably these
;; are source locations used for error reporting (confirm with callers).
(define (check-arg-values name args l-of-sigs locs)
  (define numargs (length args))
  (define l-to-use (gen-args-list numargs l-of-sigs))
  (define produced (length l-to-use))
  ;; Sanity check: gen-args-list should always return numargs signatures.
  (unless (equal? produced numargs)
    (error 'check-arg-values
           "pyret: internal error: number of contract do not match number of args: ~a ~a"
           produced
           numargs))
  (define checker (check-arg name))
  (for-each checker args l-to-use locs))

;; check-arg : symbol -> (any any any -> any)
;; Returns a checker for one argument of the primitive named `n`.  The
;; checker takes an argument `a`, a signature `s` and a location `l`.  It
;; returns `a` when the predicate of `s` accepts it.  Otherwise it calls
;; raise-pyret-error with a message from msg-unexpected, reported at `l`.
;; If `s` is not a sig, that internal error goes through raise-pyret-error too.
(define (check-arg n)
  (λ (a s l)
    (when (not (sig? s))
      (raise-pyret-error
        "pyret: internal error: check-arg needs a signature as the second arg"
        l))
    (define accepts? (sig-proc s))
    (cond
      [(accepts? a) a]
      [else
       (raise-pyret-error (msg-unexpected n (sig-name s) a) l)])))

;; check-arity : symbol -> (number Arity any -> void)
;; Returns a checker that compares `given`, the number of arguments supplied
;; to the primitive `name`, against the Arity `expected`.  When the count is
;; out of range, it calls raise-pyret-error with a message from msg-arity,
;; reported at `loc`.  A `variable` arity accepts any count.
(define (check-arity name)
  (lambda (given expected loc)
    ;; Report an arity error unless `ok?` holds.  `qualifier` and `bound`
    ;; describe the expected count in the message.
    (define (require-count ok? qualifier bound)
      (unless ok?
        (raise-pyret-error (msg-arity name qualifier bound given) loc)))
    (type-case Arity expected
      [at-least (n) (require-count (>= given n) "at least" n)]
      [at-most (n) (require-count (<= given n) "at most" n)]
      [exactly (n) (require-count (= given n) "exactly" n)]
      [variable () (void)])))

;; wrap : defines `name` as a checked wrapper around a Racket procedure.
;;
;;   (wrap name function-to-use arity contract)
;;   (wrap name name-to-use function-to-use arity contract)
;;
;; The 4-part form is shorthand for the 5-part form with name-to-use = name.
;; - name-to-use is the symbol reported in arity and argument errors.
;; - arity is an Arity value (dispatched on by check-arity).
;; - contract is a list of sigs; check-arg-values stretches it to the
;;   argument count via gen-args-list.
;;
;; Each call of the generated procedure:
;;   1. gets location info from app-locations-first and passes it to
;;      validate-app-locs;
;;   2. if that info is a list, checks the arity against its first element.
;;      It then checks the arguments, pairing the remaining elements with
;;      them one-to-one;
;;   3. applies function-to-use to the original arguments.
;; NOTE(review): when the location info is not a list, all checks are
;; skipped and function-to-use is called directly.
;; NOTE(review): `arity` and `contract` sit inside the lambda, so they are
;; re-evaluated on every call.
(define-syntax (wrap stx)
  (syntax-case stx ()
    ;; 4-part form: report errors under `name` itself.
    [(_ name function-to-use arity contract)
     (syntax/loc stx
       (wrap name name function-to-use arity contract))]
    [(_ name name-to-use function-to-use arity contract)
     (syntax/loc stx
       (define name
         (lambda args
           (let ([numargs (length args)]
                 [locs (app-locations-first)])
            (validate-app-locs locs)
            (when (list? locs)
              ; check the arity; the first location is where arity errors are reported
              ((check-arity (quote name-to-use)) numargs arity (first locs))
              ; check the argument values; (rest locs) must hold one location per argument
              (if (equal? numargs (length (rest locs)))
                  (check-arg-values (quote name-to-use) args contract (rest locs))
                  ; the diagnostic reports the number of 'my-app-locs continuation
                  ; marks, not the length of locs
                  (error 'wrap
                         "pyret: internal error: location list does not match the size of the args list; length of marks: ~a"
                         (length (continuation-mark-set->list (current-continuation-marks) 'my-app-locs)))))
            ; checks passed (or were skipped): call the wrapped procedure
            (apply function-to-use args)))))]))