language/infix.rkt
#lang racket

(require (only-in "sequences.rkt"
                  sequence+
                  sequence*))
(require (prefix-in pysem: "../semantics/beginner-funs.rkt"))
(require (prefix-in pysem: "../semantics/beginner-syntax.rkt"))
(require (prefix-in pysem: (only-in "../semantics/hash-percents.rkt"
                                    #%app)))
(require "../utilities.rkt")
(require "error-msgs.rkt")

(require (for-syntax "../utilities.rkt"))
(require (for-syntax "indent.rkt"))
(require (for-syntax "error-msgs.rkt"))

;; + (surface operator)
;;
;; Three shapes are accepted:
;;   <num> + <num>             numeric addition
;;   + <num>                   unary plus
;;   <sequence> + <sequence>   sequence combination
;;
;; The binary form is dispatched at run time by pypar:do+, which hands any
;; non-numeric left operand to sequence+ and lets it do the error reporting.
;; Operand locations are captured as vectors so runtime errors can point at
;; the offending sub-expression.
(define-syntax (pypar:+ stx)
  (syntax-case stx ()
    [(_)
     (pypar-syntax-error MSG-PLUS-NO-ARGS stx)]
    [(plus-kw operand)
     (begin
       (check-indent 'SLGC #'plus-kw #'operand)
       (with-syntax ([kw-loc (syntax->vector #'plus-kw)]
                     [operand-loc (syntax->vector #'operand)])
         (syntax/loc stx
           (do-unary '+ operand kw-loc operand-loc))))]
    [(plus-kw lhs rhs)
     (begin
       (check-indent 'SLGC #'lhs #'plus-kw)
       (check-indent 'SLGC #'plus-kw #'rhs)
       (with-syntax ([lhs-loc (syntax->vector #'lhs)]
                     [rhs-loc (syntax->vector #'rhs)])
         (syntax/loc stx
           (pypar:do+ lhs rhs lhs-loc rhs-loc))))]))

;; pypar:do+ : runtime dispatch for binary +.
;;   l, r       - operand values
;;   lpos, rpos - location vectors of the operands
;; Two numbers are added with pysem:+. A numeric left operand paired with a
;; non-numeric right operand is reported at the right operand's location.
;; Any non-numeric left operand is delegated to sequence+.
(define (pypar:do+ l r lpos rpos)
  (cond
    [(not (number? l))
     (sequence+ l r lpos rpos)]
    [(number? r)
     ;; Both operands are numbers and there are exactly two of them, so
     ;; pysem:+ never needs its own call-site location (which would refer
     ;; to this helper rather than the user's code).
     (pysem:#%app pysem:+ l r)]
    [else
     (raise-pyret-error (msg-unexpected '+ "a number" r)
                        rpos)]))

;; - (surface operator)
;;
;; Unlike +, subtraction has no sequence meaning. Both the unary form
;; (- <num>) and the binary form (<num> - <num>) work on numbers only.
(define-syntax (pypar:- stx)
  (syntax-case stx ()
    [(_)
     (pypar-syntax-error MSG-MINUS-NO-ARGS stx)]
    [(minus-kw operand)
     (begin
       (check-indent 'SLGC #'minus-kw #'operand)
       (with-syntax ([kw-loc (syntax->vector #'minus-kw)]
                     [operand-loc (syntax->vector #'operand)])
         (syntax/loc stx
           (do-unary '- operand kw-loc operand-loc))))]
    [(minus-kw lhs rhs)
     (begin
       (check-indent 'SLGC #'lhs #'minus-kw)
       (check-indent 'SLGC #'minus-kw #'rhs)
       (with-syntax ([lhs-loc (syntax->vector #'lhs)]
                     [rhs-loc (syntax->vector #'rhs)])
         (syntax/loc stx
           (pypar:do- lhs rhs lhs-loc rhs-loc))))]))

;; pypar:do- : runtime helper for binary -.
;; The left operand is checked first, then the right one. Each non-number
;; is reported at its own location. Two numbers are subtracted with pysem:-.
(define (pypar:do- l r lpos rpos)
  (cond
    [(not (number? l))
     (raise-pyret-error (msg-unexpected '- "a number" l)
                        lpos)]
    [(not (number? r))
     (raise-pyret-error (msg-unexpected '- "a number" r)
                        rpos)]
    [else
     (pysem:#%app pysem:- l r)]))

;; do-unary : runtime helper for the one-operand forms of + and -.
;;   op     - the operator symbol, '+ or '-
;;   arg    - the operand value
;;   oppos  - location vector of the operator (used for error reporting)
;;   argpos - location vector of the operand (currently unused)
;; An unknown op is an internal error. A non-numeric operand is reported
;; at the operator's location.
(define (do-unary op arg oppos argpos)
  (define dispatch (assq op (list (cons '+ +) (cons '- -))))
  (unless dispatch
    (error 'do-unary "function not recognized: ~e" op))
  (cond
    [(not (number? arg))
     (raise-pyret-error (format (string-append "~a: when one argument is given, "
                                               "it must be a number")
                                op)
                        oppos)]
    [else
     ((cdr dispatch) arg)]))

;; * (surface operator) -- the trickiest of the arithmetic operators.
;;
;; Besides numeric multiplication, * also applies to sequences (via
;; sequence*), in which case one operand is expected to be a number.
;; Which meaning applies can only be decided by looking at both operand
;; values at run time, so all of that work happens in pypar:do*.
(define-syntax (pypar:* stx)
  (syntax-case stx ()
    [(_)
     (pypar-syntax-error MSG-TIMES-NO-ARGS stx)]
    [(times-kw lhs rhs)
     (begin
       (check-indent 'SLGC #'lhs #'times-kw)
       (check-indent 'SLGC #'times-kw #'rhs)
       (with-syntax ([kw-loc (syntax->vector #'times-kw)]
                     [lhs-loc (syntax->vector #'lhs)]
                     [rhs-loc (syntax->vector #'rhs)])
         (syntax/loc stx
           (pypar:do* lhs rhs lhs-loc rhs-loc kw-loc))))]))

;; pypar:do* : runtime dispatch for binary *.
;;   - two numbers: multiplied with pysem:*
;;   - a number and a sequence (either order): passed to sequence* with
;;     the number first
;;   - anything else: an error located at the operator (oppos)
;; lpos and rpos are accepted but not used.
(define (pypar:do* l r lpos rpos oppos)
  (define l-num? (number? l))
  (define r-num? (number? r))
  (cond
    [(and l-num? r-num?)
     (pysem:#%app pysem:* l r)]
    [(and l-num? (sequence? r))
     (sequence* l r)]
    [(and r-num? (sequence? l))
     (sequence* r l)]
    [else
     (raise-pyret-error "*: this operator must be applied to numbers or sequences"
                        oppos)]))

;; / (surface operator)
;;
;; Numeric division. Besides type errors, the only special case is a zero
;; divisor, which pypar:do-divide reports at the location of the whole
;; division expression.
(define-syntax (pypar:/ stx)
  (syntax-case stx ()
    [(_)
     (pypar-syntax-error MSG-DIVIDE-NO-ARGS stx)]
    [(div-kw lhs rhs)
     (begin
       (check-indent 'SLGC #'lhs #'div-kw)
       (check-indent 'SLGC #'div-kw #'rhs)
       (with-syntax ([expr-loc (syntax->vector stx)]
                     [lhs-loc (syntax->vector #'lhs)]
                     [rhs-loc (syntax->vector #'rhs)])
         (syntax/loc stx
           (pypar:do-divide lhs rhs expr-loc lhs-loc rhs-loc))))]))

;; pypar:do-divide : runtime helper for binary /.
;; The checks run in order:
;;   1. left operand is a number (reported at lloc)
;;   2. right operand is a number (reported at rloc)
;;   3. right operand is non-zero (reported at wholeloc)
;; If all pass, the division is delegated to pysem:/.
(define (pypar:do-divide l r wholeloc lloc rloc)
  (cond
    [(not (number? l))
     (raise-pyret-error (msg-unexpected '/ "a number" l)
                        lloc)]
    [(not (number? r))
     (raise-pyret-error (msg-unexpected '/ "a number" r)
                        rloc)]
    [(zero? r)
     (raise-pyret-error MSG-DIVISION-BY-ZERO
                        wholeloc)]
    [else
     (pysem:#%app pysem:/ l r)]))

;; % (surface operator)
;;
;; Binary modulo. The location of the whole expression is passed along so
;; that pypar:do% can report a zero divisor against it.
(define-syntax (pypar:% stx)
  (syntax-case stx ()
    [(_)
     (pypar-syntax-error MSG-MOD-NO-ARGS stx)]
    [(mod-kw lhs rhs)
     (begin
       (check-indent 'SLGC #'lhs #'mod-kw)
       (check-indent 'SLGC #'mod-kw #'rhs)
       (with-syntax ([expr-loc (syntax->vector stx)]
                     [lhs-loc (syntax->vector #'lhs)]
                     [rhs-loc (syntax->vector #'rhs)])
         (syntax/loc stx
           (pypar:do% lhs rhs expr-loc lhs-loc rhs-loc))))]))
;; pypar:do% : runtime helper for binary %.
;; Both operands are first validated against number? by
;; check-argument-values (presumably raising at the offending operand's
;; location -- see utilities). A zero right operand is then reported at the
;; whole expression's location; otherwise pysem:modulo computes the result.
(define (pypar:do% l r wloc lloc rloc)
  (define operand-specs
    (list (vector l number? "a number" lloc)
          (vector r number? "a number" rloc)))
  (check-argument-values '% operand-specs)
  (cond
    [(zero? r)
     (raise-pyret-error MSG-MOD-BY-ZERO
                        wloc)]
    [else
     (pysem:#%app pysem:modulo l r)]))

;; ** (surface operator)
;;
;; Binary exponentiation. Only the operand locations are needed at run
;; time, for pypar:do**'s type checks.
(define-syntax (pypar:** stx)
  (syntax-case stx ()
    [(_)
     (pypar-syntax-error MSG-EXPT-NO-ARGS stx)]
    [(pow-kw base exponent)
     (begin
       (check-indent 'SLGC #'base #'pow-kw)
       (check-indent 'SLGC #'pow-kw #'exponent)
       (with-syntax ([base-loc (syntax->vector #'base)]
                     [exponent-loc (syntax->vector #'exponent)])
         (syntax/loc stx
           (pypar:do** base exponent base-loc exponent-loc))))]))
;; pypar:do** : runtime helper for binary **.
;; Validates that both operands are numbers via check-argument-values, then
;; delegates to pysem:expt.
(define (pypar:do** l r lloc rloc)
  (check-argument-values '**
                         (for/list ([value (list l r)]
                                    [loc (list lloc rloc)])
                           (vector value number? "a number" loc)))
  (pysem:#%app pysem:expt l r))


;; Export the arithmetic operators under their surface names; the pypar:
;; prefix is only used inside this module.
(provide (rename-out [pypar:+ +]
                     [pypar:- -]
                     [pypar:* *]
                     [pypar:/ /]
                     [pypar:% %]
                     [pypar:** **]))


;; ---------------------------------------------------------------------------
;; `not', `and', and `or'

;; not (surface operator)
;;
;; Only the one-operand form (not <expr>) is accepted. The bare keyword and
;; the zero-operand form are both reported with MSG-NOT-NO-ARGS. The operand
;; is passed to pysem:not, presumably responsible for any runtime checking.
(define-syntax (pypar:not stx)
  (syntax-case stx ()
    [id
     (identifier? #'id)
     (pypar-syntax-error MSG-NOT-NO-ARGS
                         stx)]
    [(kw)
     (pypar-syntax-error MSG-NOT-NO-ARGS
                         stx)]
    [(kw expr)
     ;; A syntax-case clause is [pattern result] or [pattern fender result].
     ;; The indentation check must therefore be sequenced with `begin'.
     ;; Written as a separate expression it would act as a fender, and the
     ;; clause would silently stop matching whenever check-indent returned #f.
     (begin
       (check-indent 'SLGC #'kw #'expr)
       (syntax/loc stx
         (pysem:#%app pysem:not expr)))]))

;; make-binary-op : generates a boolean connective (used for `and' and `or').
;;
;;   (make-binary-op macro-name function-name real-name semantics error-message)
;;
;; defines
;;   macro-name    - the surface macro.
;;                   * zero or one operand, or the bare keyword: syntax
;;                     error with error-message
;;                   * two operands: go through function-name
;;                   * three or more operands: passed straight to semantics
;;   function-name - runtime helper for the two-operand case. A non-boolean
;;                   left operand is reported at the left operand's location,
;;                   and a non-boolean right operand at the right operand's.
;;                   Both messages are prefixed with real-name.
;;
;; The right operand of the two-operand form is wrapped in a thunk. That
;; thunk is only forced inside the `semantics' form, after the left operand
;; has been checked. If `semantics' is a short-circuiting syntactic form
;; (like Racket's `and'/`or'), `false and <expr>' and `true or <expr>' no
;; longer evaluate <expr>, which matches the 3+-operand case. If `semantics'
;; is an ordinary function, both operands are still evaluated as before.
(define-syntax (make-binary-op stx)
  (syntax-case stx ()
    [(_ macro-name function-name real-name semantics error-message)
     (syntax/loc stx
       (begin
         (define-syntax (macro-name stx)
           (syntax-case stx ()
             [(_)
              (pypar-syntax-error error-message stx)]
             [(_ expr)
              (pypar-syntax-error error-message stx)]
             [(kw e1 e2)
              (begin
                #;(check-indent 'SLGC #'e1 #'kw)
                #;(check-indent 'SLGC #'kw #'e2)
                (with-syntax ([lpos (syntax->vector #'e1)]
                              [rpos (syntax->vector #'e2)])
                  (syntax/loc stx
                    (function-name e1 (lambda () e2) lpos rpos))))]
             [(_ e1 e2 e3 (... ...))
              ;; this case is automatically generated
              (syntax/loc stx
                (semantics e1 e2 e3 (... ...)))]
             ;; The bare keyword (e.g. used as a value) or any other shape.
             [_
              (pypar-syntax-error error-message stx)]))
         (define (function-name l r-thunk lloc rloc)
           (cond
             [(boolean? l)
              (semantics
               l
               ;; Forced only when `semantics' actually needs the right operand.
               (let ([r (r-thunk)])
                 (if (boolean? r)
                     r
                     (raise-pyret-error (string-append real-name
                                                       ": the expression to the right "
                                                       "should evaluate to a boolean "
                                                       "value")
                                        rloc))))]
             [else
              (raise-pyret-error (string-append real-name
                                                ": the expression to the left "
                                                "should evaluate to a boolean "
                                                "value")
                                 lloc)]))))]))

;; Instantiate the boolean connectives. Each call defines the surface macro
;; (pypar:and / pypar:or) and its two-operand runtime helper
;; (pypar:do-and / pypar:do-or), which requires both operands to be booleans.
(make-binary-op pypar:and pypar:do-and "and" pysem:and MSG-AND-WRONG-ARGS)
(make-binary-op pypar:or pypar:do-or "or" pysem:or MSG-OR-WRONG-ARGS)

;; make-boolean-binary-arity-error : string -> string   (phase 1)
;; Builds the syntax-error message for a binary comparison operator written
;; without an operand on each side. For example, "<" yields
;; "<: expected an expression before and after the `<' keyword".
(define-for-syntax (make-boolean-binary-arity-error op)
  (apply string-append
         (list op
               ": expected an expression before and after the `"
               op
               "' keyword")))



;; make-boolean-comp-op : generates an ordering comparison (used for < <= >= >).
;;
;;   (make-boolean-comp-op macro-name function-name
;;                         num-semantics string-semantics op error-message)
;;
;; defines
;;   macro-name    - the surface macro, which requires exactly two operands.
;;                   Every other use is reported with error-message instead of
;;                   a generic "bad syntax". This covers no operands, one
;;                   operand, more than two, or the bare keyword used as a
;;                   value.
;;   function-name - runtime helper.
;;                   * two numbers are compared with num-semantics
;;                   * two strings are compared with string-semantics
;;                   * a right operand that does not match the left one's
;;                     kind is reported at the right operand's location
;;                   * a left operand that is neither a number nor a string
;;                     is reported at the left operand's location
;;                   `op' (a string) is used in the messages.
(define-syntax (make-boolean-comp-op stx)
  (syntax-case stx ()
    [(_ macro-name function-name num-semantics string-semantics op error-message)
     (syntax/loc stx
       (begin
         (define-syntax (macro-name stx1)
           (syntax-case stx1 ()
             [(kw e1 e2)
              (begin
                #;(check-indent 'SLGC #'e1 #'kw)
                #;(check-indent 'SLGC #'kw #'e2)
                (with-syntax ([lpos (syntax->vector #'e1)]
                              [rpos (syntax->vector #'e2)])
                  (syntax/loc stx1
                    (function-name e1 e2 lpos rpos))))]
             ;; Anything that is not exactly two operands.
             [_
              (pypar-syntax-error error-message stx1)]))
         (define (function-name l r lpos rpos)
           (cond
             [(number? l)
              (cond
                [(number? r)
                 (pysem:#%app num-semantics l r)]
                [else
                 (raise-pyret-error (string-append op
                                                   ": the expression to the right "
                                                   "of the \"" op "\" sign "
                                                   "should evaluate to a number")
                                    rpos)])]
             [(string? l)
              (cond
                [(string? r)
                 (pysem:#%app string-semantics l r)]
                [else
                 (raise-pyret-error (string-append op
                                                   ": the expression to the right "
                                                   "of the \"" op "\" sign "
                                                   "should evaluate to a string")
                                    rpos)])]
             [else
              (raise-pyret-error (string-append op
                                                ": the expression to the left "
                                                "of the \"" op "\" sign "
                                                "should evaluate to a number or "
                                                "a string")
                                 lpos)]))))]))
              

;; Ordering comparisons. Each accepts two numbers (compared with the numeric
;; pysem operator) or two strings (compared with the matching pysem string
;; comparison).
(make-boolean-comp-op pypar:< pypar:do< pysem:< pysem:string_lt "<" (make-boolean-binary-arity-error "<"))
(make-boolean-comp-op pypar:<= pypar:do<= pysem:<= pysem:string_leq "<=" (make-boolean-binary-arity-error "<="))
;; = and != are not generated with make-boolean-comp-op (these two forms are
;; disabled). They are defined further down as plain functions over
;; pysem:equal.
#;(make-boolean-comp-op pypar:= pypar:do= pysem:= pysem:string_equal "=" (make-boolean-binary-arity-error "="))
#;(make-boolean-comp-op pypar:!= pypar:do!= (lambda (l r) (pysem:not (pysem:= l r))) (lambda (l r) (pysem:not (pysem:string_equal l r))) "!=" (make-boolean-binary-arity-error "!="))
(make-boolean-comp-op pypar:>= pypar:do>= pysem:>= pysem:string_geq ">=" (make-boolean-binary-arity-error ">="))
(make-boolean-comp-op pypar:> pypar:do> pysem:> pysem:string_gt ">" (make-boolean-binary-arity-error ">"))

;; = and != are plain two-argument functions with no operand type checks
;; of their own. Equality is decided by pysem:equal, and != negates it.
(define (pypar:= o t)
  (pysem:#%app pysem:equal o t))
(define (pypar:!= o t)
  (not (pypar:= o t)))

;; Export the boolean connectives and comparisons under their surface
;; names; the pypar: prefix is only used inside this module.
(provide (rename-out [pypar:not not]
                     [pypar:and and]
                     [pypar:or or]
                     [pypar:< <]
                     [pypar:<= <=]
                     [pypar:= =]
                     [pypar:!= !=]
                     [pypar:>= >=]
                     [pypar:> >]))