#lang racket/base
(require
(for-syntax
srfi/26
racket/base
racket/list
racket/struct-info
racket/syntax)
srfi/26
racket/promise
racket/set
"encoding.rkt"
"reflection.rkt")
;; (define-primitive-type name type reader writer)
;;
;; Expands to a definition of `primitive:<name>` whose value is
;; (primitive-info 'name 'type reader writer).  `name` and `type` are
;; quoted in the expansion, so they are stored as symbols; `reader` and
;; `writer` are ordinary expressions.  The generated identifier takes
;; its lexical context from the macro use, so the binding is visible
;; where the macro is invoked.
(define-syntax (define-primitive-type stx)
  (syntax-case stx ()
    [(_ name type reader writer)
     (let ([info-id
            (datum->syntax stx (format-symbol "primitive:~a" #'name))])
       #`(define #,info-id
           (primitive-info 'name 'type reader writer)))]))
;; Integer primitives.  Each form defines `primitive:<name>` as a
;; primitive-info record carrying the quoted name, the wire-type symbol
;; 'int*, and the given reader/writer pair.
;; The 32- and 64-bit variants share the same reader and writer; no size
;; limit is passed here (see primitive:uint* for a size-bounded variant).
(define-primitive-type int32 int*
read-int* write-int*)
(define-primitive-type int64 int*
read-int* write-int*)
;; Unsigned variants use read-uint* / write-uint* instead.
(define-primitive-type uint32 int*
read-uint* write-uint*)
(define-primitive-type uint64 int*
read-uint* write-uint*)
;; Builds a primitive-info record named 'uint* (wire type 'int*) whose
;; reader and writer forward to read-uint* / write-uint* with `max-size`
;; appended as the final argument.
;; NOTE(review): the meaning and unit of `max-size` are defined by
;; read-uint* / write-uint* in encoding.rkt -- confirm there.
(define (primitive:uint* max-size)
  ;; One-argument reader: (read-uint* a max-size)
  (define (bounded-read a)
    (read-uint* a max-size))
  ;; Two-argument writer: (write-uint* a b max-size)
  (define (bounded-write a b)
    (write-uint* a b max-size))
  (primitive-info 'uint* 'int* bounded-read bounded-write))
;; Signed-integer primitives.  Both widths are tagged with the wire-type
;; symbol 'int* and share read-sint* / write-sint*.
;; NOTE(review): presumably these use a different signed encoding than
;; int32/int64 (which use read-int*) -- confirm in encoding.rkt.
(define-primitive-type sint32 int*
read-sint* write-sint*)
(define-primitive-type sint64 int*
read-sint* write-sint*)
;; Builds a primitive-info record named 'sint* (wire type 'int*) whose
;; reader and writer forward to read-sint* / write-sint* with `max-size`
;; appended as the final argument.
;; NOTE(review): the meaning and unit of `max-size` are defined by
;; read-sint* / write-sint* in encoding.rkt -- confirm there.
(define (primitive:sint* max-size)
  ;; One-argument reader: (read-sint* a max-size)
  (define (bounded-read a)
    (read-sint* a max-size))
  ;; Two-argument writer: (write-sint* a b max-size)
  (define (bounded-write a b)
    (write-sint* a b max-size))
  (primitive-info 'sint* 'int* bounded-read bounded-write))
;; Fixed-width primitives, tagged with wire-type symbols '32bit / '64bit.
;; Each has its own dedicated reader/writer pair.
(define-primitive-type fixed32 32bit
read-fixed32 write-fixed32)
(define-primitive-type fixed64 64bit
read-fixed64 write-fixed64)
(define-primitive-type sfixed32 32bit
read-sfixed32 write-sfixed32)
(define-primitive-type sfixed64 64bit
read-sfixed64 write-sfixed64)
;; Booleans share the 'int* wire type with the integer primitives.
(define-primitive-type bool int*
read-bool write-bool)
;; Floating-point primitives reuse the '32bit / '64bit wire types.
(define-primitive-type float 32bit
read-float write-float)
(define-primitive-type double 64bit
read-double write-double)
;; Byte strings and strings are tagged with the wire-type symbol 'sized.
(define-primitive-type bytes sized
read-sized-bytes write-sized-bytes)
(define-primitive-type string sized
read-sized-string write-sized-string)
;; Export every primitive type descriptor defined in this module,
;; including the two size-parameterised constructors primitive:uint*
;; and primitive:sint* (procedures taking a max-size argument).
(provide
primitive:int32 primitive:int64
primitive:uint32 primitive:uint64 primitive:uint*
primitive:sint32 primitive:sint64 primitive:sint*
primitive:fixed32 primitive:fixed64
primitive:sfixed32 primitive:sfixed64
primitive:bool
primitive:float primitive:double
primitive:bytes primitive:string)
;; (define-enum-type name ([alt tag] ...))
;;
;; Generates three definitions at the use site:
;;   integer-><name> : maps each literal `tag` to the symbol `alt`
;;   <name>->integer : maps each symbol `alt` back to its `tag`
;;   enum:<name>     : (enum-info 'name integer-><name> <name>->integer)
;; Both converters raise an error for values that are not listed in the
;; declaration.
(define-syntax (define-enum-type stx)
  ;; Build an identifier from a format string, giving it the lexical
  ;; context of the macro use so the binding is visible to the caller.
  (define (derive fmt base)
    (datum->syntax stx (format-symbol fmt base)))
  (syntax-case stx ()
    [(_ name ([alt tag] ...))
     (with-syntax ([decode (derive "integer->~a" #'name)]
                   [encode (derive "~a->integer" #'name)]
                   [descriptor (derive "enum:~a" #'name)])
       #'(begin
           ;; tag -> symbol
           (define (decode n)
             (case n
               [(tag) 'alt] ...
               [else (error 'decode "unknown enumeration tag: ~e" n)]))
           ;; symbol -> tag
           (define (encode sym)
             (case sym
               [(alt) tag] ...
               [else (error 'encode "unknown enumeration value: ~e" sym)]))
           ;; Defined last because it captures both converters by value.
           (define descriptor
             (enum-info 'name decode encode))))]))
;; (required-tags (label tag) ...)
;;
;; Expands to an expression that builds the list of `tag` values whose
;; `label` is the literal identifier `required`, preserving declaration
;; order.  Entries with any other label are dropped; with no entries the
;; expansion is `null`.
(define-syntax (required-tags stx)
  (syntax-case stx (required)
    ;; Nothing left to scan.
    [(_)
     #'null]
    ;; A required field: keep its tag and continue with the rest.
    [(_ (required tag) . rest)
     #'(cons tag (required-tags . rest))]
    ;; Any other label: skip this entry.
    [(_ (_ tag) . rest)
     #'(required-tags . rest)]))
;; (define-message-type name ([label type field tag . default] ...))
;;
;; Defines a mutable, transparent struct `name` (a substructure of
;; `message`) with one auto field per declared field, a keyword-argument
;; constructor `name*`, and then replaces every generated accessor with
;; a wrapper that supplies a default when the field is still unset
;; (i.e. holds void).  `label` is compared against required / optional /
;; repeated / packed in the expansion.
(define-syntax (define-message-type stx)
(syntax-case stx ()
[(define-message-type name
([label type field tag . default] ...))
(let ([fields (syntax-e #'(field ...))]
[defaults (syntax-e #'(default ...))])
;; NOTE(review): `nfields` is bound here but is not referenced in the
;; template below.
(with-syntax ([nfields
(length fields)]
;; Keyword formals for the constructor: `#:field [field (void)]` for
;; each field, so every argument is optional and defaults to void.
;; `v` is not used in the loop body.
[(field-arg ...)
(append*
(for/list ([s (in-list fields)] [v (in-list defaults)])
(list
(string->keyword (symbol->string (syntax->datum s)))
(list s #'(void)))))]
;; The keyword constructor is named `<name>*`.
[constructor
(datum->syntax stx (format-symbol "~a*" #'name))]
;; An empty default tail becomes `((void))`, so the thunk
;; `(λ () . field-default)` returns void; otherwise the declared
;; default forms are spliced in as the thunk body.
[(field-default ...)
(for/list ([v (in-list defaults)])
(if (null? (syntax->datum v)) #'((void)) v))]
;; Names matching what `struct` generates: <name>-<field> and
;; set-<name>-<field>!, given the lexical context of the macro use.
[(accessor ...)
(datum->syntax
stx (map (cut format-symbol "~a-~a" #'name <>)
fields))]
[(mutator ...)
(datum->syntax
stx (map (cut format-symbol "set-~a-~a!" #'name <>)
fields))])
#'(begin
;; Every field is #:auto with auto-value (void); void is the "unset"
;; marker tested by the constructor and the accessor wrappers.
(struct name message
([field #:auto] ...)
#:transparent
#:mutable
#:auto-value (void)
;; The message-info is wrapped in `delay`, so it is only built when
;; forced.  It receives a tag -> field-info table (eqv-keyed) and the
;; set of tags whose label is `required`.
#:property prop:protobuf
(delay
(message-info
'name constructor
(make-immutable-hasheqv
(list
(cons
tag
;; Second argument is true for repeated/packed labels, third is #t
;; only for packed.
(field-info
type (memq 'label '(repeated packed)) (eq? 'label 'packed)
accessor mutator))
...))
(list->seteqv
(required-tags (label tag) ...)))))
;; Keyword constructor: creates the struct with an empty hasheqv as its
;; single positional argument, then stores each argument that was
;; actually supplied (i.e. is not void).
;; NOTE(review): the positional argument is presumably the extensions
;; table inherited from `message` -- confirm in reflection.rkt.
(define (constructor field-arg ...)
(let ([msg (name (hasheqv))])
(unless (void? field) (mutator msg field)) ...
msg))
;; Rebind each struct accessor to a wrapper with an optional `alt`
;; argument.  The wrapper returns the stored value when it is not void;
;; otherwise it calls `alt` if it is a procedure, or returns `alt`.
;; The default `alt` depends on the label: required -> thunk raising
;; an error, optional -> thunk evaluating the declared default,
;; repeated/packed -> null.
(set! accessor
(let ([message-ref accessor]
[default-thunk
(case 'label
[(required)
(λ () (error 'accessor "missing required field"))]
[(optional)
(λ () . field-default)]
[(repeated packed)
null])])
(λ (msg [alt default-thunk])
(let ([v (message-ref msg)])
(cond
[(not (void? v)) v]
[(procedure? alt) (alt)]
[else alt])))))
...)))]))
;; (define-message-extension name [label type field tag . default])
;;
;; Adds one extension field to the existing message struct `name`.
;; `name` must be bound to struct info (it is looked up with
;; `syntax-local-value` at expansion time).  The expansion:
;;   * defines the accessor <name>-<field>, taking a message and an
;;     optional `alt`; it looks `tag` up in (message-extensions msg) and
;;     falls back to `alt` (called if it is a procedure) when absent;
;;   * defines the mutator set-<name>-<field>!, which functionally
;;     updates the extensions table and stores it back in the message;
;;   * registers a field-info for `tag` in the message's forced
;;     message-info and adds `tag` to its required set when `label` is
;;     `required`.
;; Both generated procedures raise a type error naming `name` when given
;; a value that does not satisfy the struct's predicate.
(define-syntax (define-message-extension stx)
  (syntax-case stx ()
    [(_ name
        [label type field tag . default])
     (with-syntax ([(struct:info _ message? _ _ _)
                    ;; Pull the struct type descriptor and predicate
                    ;; identifiers out of the static struct info.
                    (extract-struct-info (syntax-local-value #'name))]
                   ;; Empty default tail -> thunk body `(void)`.
                   [field-default
                    (if (null? (syntax->datum #'default)) #'((void)) #'default)]
                   [accessor
                    (datum->syntax
                     stx (format-symbol "~a-~a" #'name #'field))]
                   [mutator
                    (datum->syntax
                     stx (format-symbol "set-~a-~a!" #'name #'field))])
       #'(begin
           (define accessor
             (let ([default-thunk
                    (case 'label
                      [(required)
                       (λ () (error 'accessor "missing required field"))]
                      [(optional)
                       (λ () . field-default)]
                      ;; `packed` is handled like `repeated`, matching
                      ;; define-message-type and the field-info
                      ;; registration below; without it a packed
                      ;; extension would default to void, not null.
                      [(repeated packed)
                       null])])
               (λ (msg [alt default-thunk])
                 (unless (message? msg)
                   ;; `name` must be quoted: in expression position the
                   ;; struct id denotes the constructor procedure, which
                   ;; symbol->string would reject.
                   (raise-type-error 'accessor (symbol->string 'name) msg))
                 ;; `void` as failure thunk yields void for a missing tag.
                 (let ([v (hash-ref (message-extensions msg) tag void)])
                   (cond
                     [(not (void? v)) v]
                     [(procedure? alt) (alt)]
                     [else alt])))))
           (define (mutator msg v)
             (unless (message? msg)
               (raise-type-error 'mutator (symbol->string 'name) msg))
             (set-message-extensions! msg
               (hash-set (message-extensions msg) tag v)))
           ;; Register the extension in the (forced) message-info so
           ;; reflection-based code sees the new field.
           (let ([info (force (protobuf-ref struct:info))])
             (set-message-info-fields! info
               (hash-set (message-info-fields info) tag
                 (field-info
                  type (memq 'label '(repeated packed)) (eq? 'label 'packed)
                  accessor mutator)))
             (set-message-info-required! info
               (set-union
                (message-info-required info)
                (list->seteqv
                 (required-tags (label tag))))))))]))
;; Export the three type-definition macros.  Note that
;; define-primitive-type and required-tags are not exported.
(provide
define-enum-type
define-message-type
define-message-extension)