#lang racket
(require (for-syntax "../utilities.rkt"))
(require "../utilities.rkt")
(require "first-order.rkt")
(require (only-in lang/posn posn?))
(require (only-in 2htdp/image image? pen? color?))
(provide pyret-struct-instance? pyret-struct-lookup)
(provide beginner-define-struct define-struct)
(provide pyret-struct?)
(provide struct-name struct-field-id-list)
;; struct-name : pyret-struct-instance -> symbol
;; Returns the structure name recorded in a struct instance.
(define struct-name
  (λ (instance)
    (pyret-struct-instance-struct-id instance)))
;; struct-field-id-list : pyret-struct-instance -> list
;; Returns the list of field names stored in a struct instance.
(define struct-field-id-list
  (λ (instance)
    (pyret-struct-instance-field-ids instance)))
;; beginner-define-struct: forwards every subform to pyret-define-struct
;; with the first-order flag set to #t.
(define-syntax (beginner-define-struct stx)
  (syntax-case stx ()
    [(_ clause ...)
     ;; syntax/loc keeps the user's source location on the expansion so
     ;; errors from pyret-define-struct point at the original form.
     (syntax/loc stx (pyret-define-struct #t clause ...))]))
;; define-struct: forwards every subform to pyret-define-struct with the
;; first-order flag set to #f.
(define-syntax (define-struct stx)
  (syntax-case stx ()
    [(_ clause ...)
     ;; syntax/loc keeps the user's source location on the expansion so
     ;; errors from pyret-define-struct point at the original form.
     (syntax/loc stx (pyret-define-struct #f clause ...))]))
;; Runtime representation of an instance of a user-defined structure.
;; Each field is guarded by the contract listed beside it.
;; #:transparent makes instances printable and comparable field-by-field.
(define-struct/contract pyret-struct-instance
;; struct-id : the structure's name, as a symbol
([struct-id symbol?]
;; fields    : table from field name to field value
[fields hash?]
;; field-ids : the field names (pyret-struct-create-instance stores them in
;;             the order they were supplied)
[field-ids list?])
#:transparent )
;; pyret-struct-lookup : pyret-struct-instance symbol any -> any
;; Returns the value stored under field SYM in struct instance PS.
;; If PS has no such field, reports an error through raise-pyret-error,
;; passing SL along (presumably the source location of the lookup
;; expression -- confirm against raise-pyret-error in utilities.rkt).
(define (pyret-struct-lookup ps sym sl)
  (define fields (pyret-struct-instance-fields ps))
  ;; Use a failure thunk rather than a #f default: a field whose value is
  ;; legitimately #f must be returned, not reported as missing.
  (hash-ref fields
            sym
            (λ ()
              (raise-pyret-error
               (format (string-append "structure lookup: struct ~a "
                                      "has no field named ~a")
                       (pyret-struct-instance-struct-id ps)
                       sym)
               sl))))
;; pyret-struct-create-instance : symbol (listof (cons symbol any)) -> pyret-struct-instance
;; Builds a struct instance named S-ID from an association list of
;; (field-name . value) pairs.
;; The field-name list keeps the order of LOF-FIELDS.
(define (pyret-struct-create-instance s-id lof-fields)
  (define table (make-hash))
  (for ([entry (in-list lof-fields)])
    (hash-set! table (car entry) (cdr entry)))
  (pyret-struct-instance s-id table (map car lof-fields)))
;; check-in-top-level : syntax -> void
;; Compile-time guard that raises a Pyret-style syntax error unless the
;; current expansion context is the top level or a module body.
;; Any other context, including internal-definition contexts
;; (which syntax-local-context reports as a list), is rejected.
(define-for-syntax (check-in-top-level stx)
  (case (syntax-local-context)
    [(top-level module module-begin) (void)]
    [else
     (raise-pyret-error/stx
      (string-append "structure definitions are only "
                     "allowable at the top level")
      stx)]))
;; pyret-define-struct : (pyret-define-struct first-order? name (field ...))
;; Shared implementation of define-struct and beginner-define-struct.
;; Expands into two definitions:
;;   name     - constructor taking exactly one argument per field and building
;;              an instance via pyret-struct-create-instance
;;   is_name  - predicate recognizing pyret-struct-instances whose struct-id
;;              is 'name
;; When first-order? is #t both are introduced with define-first-order,
;; otherwise with plain define.
;; Raises a syntax error in any of these cases:
;;   - the form is used outside the top level
;;   - name is not a valid identifier
;;   - a field is not an identifier or is repeated
;;   - anything follows the field list
(define-syntax (pyret-define-struct stx)
  (check-in-top-level stx)
  (syntax-case stx ()
    [(_ first-order? name . __)
     (not (identifier/non-kw? (syntax name)))
     (raise-syntax-error #f
                         (format "\"~a\" is not a valid structure name\n"
                                 (syntax-e (syntax name)))
                         stx)]
    [(_ first-order?_ name_ (field_ ...) . rest_)
     (let ([fields (syntax->list (syntax (field_ ...)))]
           [seen (make-hasheq)])
       ;; Each field must be an identifier, and field names must be unique.
       (for ([field (in-list fields)])
         (unless (identifier? field)
           (raise-syntax-error #f "invalid field name" stx field))
         (let ([sym (syntax-e field)])
           (when (hash-ref seen sym #f)
             (raise-syntax-error #f
                                 "found a field name used more than once"
                                 stx
                                 field))
           (hash-set! seen sym #t)))
       ;; Nothing may follow the field list.
       ;; Pass a syntax object (the first extra form, or the whole tail when
       ;; it is not a proper list) rather than a list of syntax.
       ;; That way the error can report a source location.
       (let ([extra (syntax->list (syntax rest_))])
         (unless (null? extra)
           (raise-syntax-error #f
                               "You don't need anything after the field names"
                               stx
                               (if (pair? extra) (car extra) (syntax rest_)))))
       (with-syntax ([pred (datum->syntax
                            (syntax name_)
                            (string->symbol
                             (string-append "is_"
                                            (symbol->string
                                             (syntax->datum (syntax name_))))))]
                     ;; Both modes generate the same code; only the
                     ;; defining form differs.
                     [definer (if (syntax-e (syntax first-order?_))
                                  (syntax define-first-order)
                                  (syntax define))]
                     ;; The field count is known statically, so compute it
                     ;; once at expansion time.
                     [field-count (length fields)])
         (syntax/loc stx
           (begin
             (definer name_
               (lambda args
                 (if (= (length args) field-count)
                     (pyret-struct-create-instance (quote name_)
                                                   (map cons '(field_ ...) args))
                     ;; raise-arity-error expects the offending arguments
                     ;; spread out, so the message reports the real count.
                     (apply raise-arity-error (quote name_) field-count args))))
             (definer pred
               (lambda (strct)
                 (and (pyret-struct-instance? strct)
                      (equal? (pyret-struct-instance-struct-id strct)
                              (quote name_)))))))))]))
;; pyret-struct? : any -> any
;; Recognizes values the language treats as structures:
;;   - user-defined struct instances
;;   - images, posns, colors and pens
;; The recognizers are tried in order, and the first true result is returned.
(define (pyret-struct? v)
  (for/or ([recognizer (in-list (list pyret-struct-instance?
                                      image?
                                      posn?
                                      color?
                                      pen?))])
    (recognizer v)))