semantics/structures.rkt
#lang racket

#|

File: semantics/structures.rkt
Author: Bill Turtle (wrturtle)

Contains the implementation of structures for Pyret.

|#

(require (for-syntax "../utilities.rkt"))
(require "../utilities.rkt")
(require "first-order.rkt")
(require (only-in lang/posn posn?))
(require (only-in 2htdp/image image? pen? color?))

(provide print-struct pyret-struct-instance? pyret-struct-lookup)
(provide beginner-define-struct define-struct)
(provide pyret-struct?)

;; beginner-define-struct: structure definition form for the beginner level.
;; Forwards every subform unchanged to pyret-define-struct with the
;; first-order? flag set to #t, keeping the source location of the use site.
(define-syntax (beginner-define-struct form)
  (syntax-case form ()
    [(_ arg ...)
     (syntax/loc form
       (pyret-define-struct #t arg ...))]))
;; define-struct: structure definition form for the non-beginner levels.
;; Forwards every subform unchanged to pyret-define-struct with the
;; first-order? flag set to #f, keeping the source location of the use site.
(define-syntax (define-struct form)
  (syntax-case form ()
    [(_ arg ...)
     (syntax/loc form
       (pyret-define-struct #f arg ...))]))

;; pyret-struct-instance: runtime representation of a Pyret structure value.
;;   struct-id : symbol naming the structure (the name given in the
;;               structure definition)
;;   fields    : hash table mapping each field name to its value; read by
;;               pyret-struct-lookup
;;   field-ids : list recording the fields in declaration order; as built by
;;               pyret-struct-create-instance it holds (field-name . value)
;;               pairs, which print-struct walks to print the fields in order
(define-struct/contract pyret-struct-instance
  ([struct-id symbol?]
   [fields hash?]
   [field-ids list?])
  #:transparent ; it bugs me that this isn't the default
  )
;; pyret-struct-lookup: pyret-struct-instance * symbol * srcloc -> any
;;
;; Returns the value stored under field `sym' in the structure instance `ps'.
;; If `ps' has no such field, a Pyret error is raised; the `sl' argument is
;; handed to raise-pyret-error for better error reporting.
;;
;; The error is raised from hash-ref's failure thunk rather than by testing
;; the looked-up value, so a field whose value is #f is returned normally
;; instead of being reported as missing.
(define (pyret-struct-lookup ps sym sl)
  (hash-ref (pyret-struct-instance-fields ps)
            sym
            (λ ()
              (raise-pyret-error (format (string-append "structure lookup: struct ~a "
                                                        "has no field named ~a")
                                         (pyret-struct-instance-struct-id ps)
                                         sym)
                                 sl))))

;; pyret-struct-create-instance: symbol * (listof (consof symbol any)) -> pyret-struct-instance
;;
;; Builds a pyret-struct-instance for the structure named `s-id'.
;; `lof-fields' is an association list of (field-name . value) pairs. It is
;; copied into a mutable hash table used for lookups, and it is also stored
;; unchanged so that the declaration order of the fields is remembered.
;;
;; N.B. This procedure assumes that its arguments are valid (malformed or
;;      duplicate field names are rejected during macro expansion).
(define (pyret-struct-create-instance s-id lof-fields)
  (define field-table (make-hash lof-fields))
  (pyret-struct-instance s-id field-table lof-fields))

;; check-in-top-level: syntax -> void   (phase 1)
;;
;; Raises a Pyret syntax error, reported at `stx', unless the form currently
;; being expanded sits at the top level or at module level.
(define-for-syntax (check-in-top-level stx)
  (case (syntax-local-context)
    [(top-level module module-begin) (void)]
    [else
     (raise-pyret-error/stx
      (string-append "structure definitions are only "
                     "allowable at the top level")
      stx)]))

;; pyret-define-struct: shared implementation of the structure definition forms.
;;
;;   (pyret-define-struct first-order? name (field ...))
;;
;; Defines a constructor `name' taking exactly one argument per field, and a
;; predicate `is_name'. When `first-order?' is #t both are bound with
;; define-first-order; otherwise they are bound with plain define.
(define-syntax (pyret-define-struct stx)
  ;; Make sure that we are being used in the correct spot.
  (check-in-top-level stx)
  (syntax-case stx ()
    ;; First, check for a valid struct name.
    [(_ first-order? name . __)
     (not (identifier/non-kw? (syntax name)))
     (raise-syntax-error #f
                         (format "\"~a\" is not a valid structure name\n"
                                 (syntax-e (syntax name)))
                         stx)]
    ;; Main case.
    [(_ first-order?_ name_ (field_ ...) . rest)
     (begin
       ;; Every field must be an identifier, and no field name may repeat.
       (let ([seen (make-hasheq)])
         (for ([field (in-list (syntax->list (syntax (field_ ...))))])
           (unless (identifier? field)
             (raise-syntax-error #f
                                 "invalid field name"
                                 stx
                                 field))
           (let ([sym (syntax-e field)])
             (when (hash-has-key? seen sym)
               (raise-syntax-error #f
                                   "found a field name used more than once"
                                   stx
                                   field))
             (hash-set! seen sym #t))))
       ;; Nothing may follow the field list. Point the error at the first
       ;; extra form, or at the whole tail when it is not a proper list.
       (let ([extra (syntax->list (syntax rest))])
         (unless (and extra (null? extra))
           (raise-syntax-error #f
                               "You don't need anything after the field names"
                               stx
                               (if (pair? extra) (car extra) (syntax rest)))))
       ;; Looks like we can go ahead and create the structure.
       ;; `pred' is the predicate is_<name>, created in the lexical context of
       ;; `name_' so the user can refer to it. `definer' selects the binding
       ;; form; the two variants are otherwise identical.
       (with-syntax ([pred (datum->syntax
                            (syntax name_)
                            (string->symbol
                             (string-append "is_"
                                            (symbol->string
                                             (syntax->datum (syntax name_))))))]
                     [definer (if (syntax-e (syntax first-order?_))
                                  (syntax define-first-order)
                                  (syntax define))])
         (syntax/loc stx
           (begin
             ;; Constructor: check the argument count, then pair each field
             ;; name with its argument in declaration order.
             (definer name_
               (lambda args
                 (let ([field-list '(field_ ...)])
                   (if (= (length args) (length field-list))
                       (pyret-struct-create-instance (quote name_)
                                                     (map cons field-list args))
                       ;; raise-arity-error expects the offending arguments
                       ;; as separate values, hence apply.
                       (apply raise-arity-error
                              (quote name_)
                              (length field-list)
                              args)))))
             ;; Predicate: true only for instances built by this constructor.
             (definer pred
               (lambda (strct)
                 (and (pyret-struct-instance? strct)
                      (equal? (pyret-struct-instance-struct-id strct)
                              (quote name_)))))))))]))
;; pyret-struct?: any -> boolean
;;
;; Recognizes every value treated as a structure: user-defined Pyret structure
;; instances plus the built-in image, posn, color and pen values. The
;; recognizers are tried in order and the first true result is returned.
(define (pyret-struct? v)
  (for/or ([recognizer (in-list (list pyret-struct-instance?
                                      image?
                                      posn?
                                      color?
                                      pen?))])
    (recognizer v)))

;; print-struct: pyret-struct-instance -> void
;;
;; Displays `p' on the current output port in the form
;;   struct: name (field1 = value1, field2 = value2)
;; with the fields listed in declaration order (the order of field-ids).
;;
;; Values are read directly from the instance's field table rather than via
;; pyret-struct-lookup, so a field holding #f is printed instead of being
;; reported as a missing field. Every name in field-ids is a key of that
;; table, because both are built from the same association list.
(define/contract (print-struct p)
  (-> pyret-struct-instance? void?)
  (let ([field-table (pyret-struct-instance-fields p)])
    (display (format "struct: ~a (" (pyret-struct-instance-struct-id p)))
    (display
     (string-join
      (for/list ([field-pair (in-list (pyret-struct-instance-field-ids p))])
        (let ([field-name (car field-pair)])
          (format "~a = ~a" field-name (hash-ref field-table field-name))))
      ", "))
    (display ")")))