common/ad-hoc-typing.ss
#lang scheme
;; This module provides some type-related macros.

(require "point.ss")

(provide assert-type
         type-fun
         deftyped
         type-case-lambda
         
         ;; Typing utilities (to construct complex types)
         list-of
         pair-of
         )

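;; Former procedure version of assert-type, kept below for reference only.
;; It is superseded by the assert-type macro further down, which takes
;; just src, val and type?.
;; assert-type src var val type type?
;;   src: symbol for source of the error.
;;   var: variable (name) to type-check
;;   val: value to type-check
;;   type: type of variable (name symbol for type? procedure)
;;   type?: checking procedure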
;(define (assert-type src var val type type?)
;  (if (not (type? val))
;      (error src "argument ~s is not of the correct type. Given: ~s. Wanted: ~s"
;             var val type)
;      #t))
(require (prefix-in s: (only-in scheme and or)))
;; (make-predicate spec subject)
;;   Expands a predicate specification into a boolean test of SUBJECT.
;;   spec:    a predicate expression, or (and spec ...) / (or spec ...),
;;            nested arbitrarily; these short-circuit like scheme's and/or.
;;   subject: substituted at every leaf, so it should be a variable
;;            (assert-type binds the value to one first).
(define-syntax make-predicate
  (syntax-rules (and or)
    ;; Empty combinators are rejected (at run time) instead of silently
    ;; meaning #t / #f; these two clauses must stay ahead of the general
    ;; and/or clauses, which would also match them.
    [(_ (and) subject)
     (error "predicate conjunction must have arguments!")]
    [(_ (or) subject)
     (error "predicate disjunction must have arguments!")]
    [(_ (and spec ...) subject)
     (s:and (make-predicate spec subject) ...)]
    [(_ (or spec ...) subject)
     (s:or (make-predicate spec subject) ...)]
    ;; Leaf: apply the predicate expression directly.
    [(_ pred subject)
     (pred subject)]))

;; (assert-type src val type?)
;;   src:   identifier used as the error source. The macro quotes it
;;          itself, so pass it bare (f, not 'f).
;;   val:   expression to check; it is evaluated exactly once.
;;   type?: predicate spec as accepted by make-predicate (a predicate
;;          expression, or nested (and ...) / (or ...) of them).
;;   Evaluates to (void) when the check passes. Otherwise it raises an
;;   error naming src and showing the quoted val form, its value, and
;;   the quoted type? spec.
(define-syntax assert-type
  (syntax-rules ()
    [(_ who expr spec)
     (let ([v expr])
       (unless (make-predicate spec v)
         (error 'who "argument ~s is not of the correct type. Given: ~s. Wanted: ~s"
                'expr v 'spec)))]))


;; Macro: (type-fun (function [var type] ...) stmt ...)
;;   function: identifier of the function to type-check
;;   var:      parameter name, checked on every call
;;   type:     predicate spec accepted by assert-type
;;
;; Without stmts: binds `function` to a wrapper that checks each argument
;; and then calls the previous value of `function`.
;; NOTE(review): this form `define`s the same name its initializer reads.
;; Inside a module that is presumably a duplicate definition (or a read of
;; the new, not-yet-initialized binding), so it likely only works at the
;; top level / REPL. Confirm before relying on it in module code.
;;
;; With stmts: defines `function` directly; the checks run before the body.
(define-syntax type-fun
  (syntax-rules ()
    [(_ (name [arg pred] ...))
     (define name
       (let ([previous name])
         (lambda (arg ...)
           (assert-type name arg pred) ...
           (previous arg ...))))]
    [(_ (name [arg pred] ...) body ...)
     (define (name arg ...)
       (assert-type name arg pred) ...
       body ...)]))

;; Macro: (deftyped typed-fun (fun [var type] ...) stmt ...)
;;   Defines typed-fun as a procedure of the vars that first checks each
;;   var against its type; failures are reported with `fun` as the source.
;;   Without stmts, the checked arguments are forwarded to `fun`.
;;   With stmts, the stmts are the body and `fun` only labels the errors.
(define-syntax deftyped
  (syntax-rules ()
    ;; No body given: the body is simply a call to the wrapped function.
    [(_ name (f [arg pred] ...))
     (deftyped name (f [arg pred] ...) (f arg ...))]
    [(_ name (f [arg pred] ...) first-stmt more-stmts ...)
     (define name
       (lambda (arg ...)
         (assert-type f arg pred) ...
         first-stmt more-stmts ...))]))


;; (type-case-lambda clause ...)
;;   Like case-lambda, but each clause declares a type for its parameters
;;   (clause forms are handled by tcase-lambda-clause). Produces a variadic
;;   procedure that picks a clause by the number of arguments received,
;;   trying the clauses in order.
;;   Evaluating (type-case-lambda) with no clauses raises an error.
(define-syntax type-case-lambda
  (syntax-rules ()
    [(_) (error 'case-lambda "No arguments")]
    [(_ (formals body ...) more ...)
     (lambda arg-list
       ;; Bind the length once; the clause helper substitutes this
       ;; variable into every arity test.
       (let ([n (length arg-list)])
         (tcase-lambda-clause arg-list n
                              (formals body ...)
                              more ...)))]))

;; (tcase-lambda-clause args l clause ...)
;;   Helper for type-case-lambda; tries the clauses in order.
;;   args: variable bound to the argument list
;;   l:    variable bound to (length args)
;;   Clause forms:
;;     [([var type] ...) body ...]  fixed arity: taken when l equals the
;;                                  number of vars; each var is checked
;;                                  with assert-type before the body runs.
;;     [[var type] body ...]        rest clause: always taken; var is bound
;;                                  to the whole argument list and checked
;;                                  with (list-of type). Any clauses after
;;                                  it are ignored.
;;   When no clause is taken, raises a wrong-number-of-arguments error.
(define-syntax tcase-lambda-clause
  (syntax-rules ()
    [(_ args l
        [([a1 t1] ...) e1 ...]
        clauses ...)
     (if (= l (length '(a1 ...)))
         (apply (lambda (a1 ...)
                  ;; The error source is passed bare because assert-type
                  ;; quotes it itself. Passing 'type-case-lambda would hand
                  ;; `error` the list (quote type-case-lambda) instead of a
                  ;; symbol and garble the message.
                  (assert-type type-case-lambda a1 t1) ...
                  e1 ...)
                args)
         (tcase-lambda-clause args l
                              clauses ...))]
    [(_ args l
        ([a1 t1] e1 ...)
        clauses ...)
     (let ([a1 args])
       (assert-type type-case-lambda a1 (list-of t1))
       e1 ...)]
    [(_ args l)
     (error (quote type-case-lambda) "Wrong number of arguments to function.")]))

;([type-case-params (syntax-rules ()
;                                        [(_ fun (([arg type] ...) stmts ...))
;                                         (arg ...)]
;                                        [(_ fun ([args type] stmts ...))
;                                         args])]
;                    [type-case-body (syntax-rules ()
;                                      [(_ fun (([arg type] ...) stmts ...))
;                                       (begin
;                                         (assert-type fun arg type) ...
;                                         stmts ...)]
;                                      [(_ fun ([args type] stmts ...))
;                                       (begin
;                                         (assert-type fun args (list-of type))
;                                         stmts ...)])]
;                    [build-cases (lambda (stx)
;                                   (syntax-case stx ()
;                                     [(_ acc case)
;                                      #`[#,#'(type-case-params case)
;                                         #,#'(type-case-body case)]]
;                                     [(_ acc case cases ...)
;                                      #`(acc
;                                         . ([#,#'(type-case-params case)
;                                             #,#'(type-case-body case)]))]))])
;; Test
;(define (a b c d)
;    (string-append (symbol->string b)
;                   (number->string (+ c d))))
;(deftyped a* (a [b symbol?] [c positive?] [d (or positive? negative?)]))


;; Typing utilities:
;; Let's assume all lists are homogeneous (faster ;-) )
;; (list-of predicate) => procedure
;;   Returns a test that accepts proper lists whose first element
;;   satisfies PREDICATE. Lists are assumed homogeneous, so only the first
;;   element is inspected. The empty list is accepted: it previously
;;   crashed in `first`, breaking zero-argument calls to rest clauses.
(define ((list-of predicate) o)
  (and (list? o)
       (or (null? o)
           (predicate (first o)))))

;; (pair-of predicate) => procedure
;;   Returns a test that accepts any pair whose car satisfies PREDICATE.
;;   The cdr is not inspected; non-pairs yield #f.
(define (pair-of predicate)
  (lambda (o)
    (if (pair? o)
        (predicate (car o))
        #f)))


;; Demo of type-case-lambda.
;;   With exactly three arguments (integer, number, symbol), `a` displays
;;   them in order. Any other arity falls through to the rest clause: the
;;   argument list is checked with (list-of number?) and then displayed.
;;   Parameter names are kept as-is because they appear in type errors.
(define a
  (type-case-lambda
   [([a integer?] [b number?] [c symbol?])
    (for-each display (list a b c))]
   [[args number?]
    (display args)]))