header.ss
#lang scheme/base

(require (prefix-in set: (planet dherman/set:3/set))
         (planet dherman/io:1:9/io)
         scheme/system
         scheme/file
         scheme/list
         scheme/match
         "syntax.ss"
         "eval.ss")
(require (for-syntax scheme/base)
         (for-syntax "parser.ss"))

(provide (struct-out abi))
(provide layout?
         ref-layout? struct-layout? union-layout? enum-layout? array-layout? pointer-layout?
         layout-size layout-offset
         struct-layout-lookup union-layout-lookup deref-layout)
(provide make-header header? (rename-out [build-header header]))
(provide abi-lookup abi-lookup-typedef abi-lookup-tag serialize-abi deserialize-abi write-abi read-abi)

;; I think these are too unsafe to export:
;; (provide abi-add-typedef abi-add-tag)

(provide compile-header system-compiler)
(provide (struct-out query:sizeof)
         (struct-out query:offset)
         (struct-out query:expr)
         query?)

;; =============================================================================
;; DATA TYPE DEFINITIONS
;; =============================================================================

(define-struct internal:layout (size) #:prefab)
(define-struct (internal:layout:ref internal:layout) (name) #:prefab)
(define-struct (internal:layout:struct internal:layout) (tag fields) #:prefab)
(define-struct (internal:layout:union internal:layout) (tag variants) #:prefab)
(define-struct (internal:layout:enum internal:layout) (tag variants) #:prefab)
(define-struct (internal:layout:array internal:layout) (type length) #:prefab)
(define-struct (internal:layout:pointer internal:layout) (type) #:prefab)

(define-struct layout (abi internal)
  #:transparent
  #:property prop:custom-write
  (lambda (wrapped port write?)
    (let ([internal (layout-internal wrapped)])
      (display
       (cond
         [(internal:layout:ref? internal) "#<ref-layout>"]
         [(internal:layout:struct? internal) "#<struct-layout>"]
         [(internal:layout:union? internal) "#<union-layout>"]
         [(internal:layout:enum? internal) "#<enum-layout>"]
         [(internal:layout:array? internal) "#<array-layout>"]
         [(internal:layout:pointer? internal) "#<pointer-layout>"]
         [else "#<layout>"])
       port))))

;; abi * (union x (layout x)) -> (layout x)
(define (wrap abi x)
  (make-layout abi (soft-unwrap x)))

;; (layout x) -> x
(define (unwrap x)
  (layout-internal x))

;; (union x (layout x)) -> x
(define (soft-unwrap x)
  (if (layout? x) (unwrap x) x))

;; (layout internal:layout) -> boolean
(define struct-layout? (compose internal:layout:struct? unwrap))
(define union-layout? (compose internal:layout:union? unwrap))
(define enum-layout? (compose internal:layout:enum? unwrap))
(define array-layout? (compose internal:layout:array? unwrap))
(define pointer-layout? (compose internal:layout:pointer? unwrap))
(define ref-layout? (compose internal:layout:ref? unwrap))

;; XXX: cycle detection

;; layout -> layout
(define (deref-layout l)
  (match l
    [(struct layout (abi (struct internal:layout:ref (_ name))))
     (deref-layout (abi-lookup abi name))]
    [_ l]))

(define-struct query:sizeof (type) #:transparent)
(define-struct query:offset (type field) #:transparent)
(define-struct query:expr (expr) #:transparent)
(define (query? x)
  (or (query:sizeof? x)
      (query:offset? x)
      (query:expr? x)))

(define-struct header (headers))

;; exported as `header'
(define build-header
  (procedure-rename (lambda headers
                      (make-header headers))
                    'header))

(define-struct abi (typedefs tags)
  #:property prop:procedure
  (lambda (this key)
    (abi-lookup this key)))

;; =============================================================================
;; ABI OPERATIONS
;; =============================================================================

;; TODO: do a dotted-path lookup

;; abi * symbol -> layout
(define (abi-lookup abi key)
  (wrap abi
        (hash-ref (abi-typedefs abi)
                  key
                  (lambda ()
                    (hash-ref (abi-tags abi)
                              key
                              (lambda ()
                                (error 'abi-lookup "no definition for ~a" key)))))))

;; abi * symbol -> layout
(define (abi-lookup-typedef abi key)
  (wrap abi
        (hash-ref (abi-typedefs abi)
                  key
                  (lambda ()
                    (error 'abi-lookup-typedef "no definition for ~a" key)))))

;; abi * symbol -> layout
(define (abi-lookup-tag abi key)
  (wrap abi
        (hash-ref (abi-tags abi)
                  key
                  (lambda ()
                    (error 'abi-lookup-tag "no definition for ~a" key)))))

;; abi * symbol * (union internal:layout layout) -> abi
(define (abi-add-typedef a name type)
  (struct-copy abi a [typedefs (hash-set (abi-typedefs a) name (soft-unwrap type))]))

;; abi * symbol * (union internal:layout layout) -> abi
(define (abi-add-tag a name type)
  (struct-copy abi a [tags (hash-set (abi-tags a) name (soft-unwrap type))]))

;; sexp -> abi
(define (deserialize-abi sexp)
  (match sexp
    [`(abi ,typedefs ,tags)
     (make-abi typedefs tags)]
    [_ (raise-syntax-error 'read-abi "invalid serialized ABI" sexp)]))

;; abi -> sexp
(define (serialize-abi a)
  (match a
    [(struct abi (typedefs tags))
     `(abi ,typedefs ,tags)]))

;; abi [* output-port] -> void
(define (write-abi abi [out (current-output-port)])
  (write (serialize-abi abi) out))

;; [input-port] -> abi
(define (read-abi [in (current-input-port)])
  (deserialize-abi (read in)))

;; =============================================================================
;; LAYOUT OPERATIONS
;; =============================================================================

;; layout -> uint
(define layout-size
  (procedure-rename (compose internal:layout-size unwrap) 'layout-size))

(define (split-javadot sym)
  (map string->symbol (regexp-split #rx"\\." (symbol->string sym))))

;; layout * (union symbol (listof symbol)) -> uint
(define (layout-offset layout path)
  (cond
    [(symbol? path)
     (layout-offset layout (split-javadot path))]
    [(null? path) 0]
    [(struct-layout? layout)
     (match-let ([(list _ offset layout*) (struct-layout-lookup (car path) layout)])
       (+ offset (layout-offset layout* (cdr path))))]
    [(union-layout? layout)
     (match-let ([(list _ layout*) (union-layout-lookup (car path) layout)])
       (layout-offset layout* (cdr path)))]
    [(ref-layout? layout)
     (layout-offset (deref-layout layout) path)]
    ;; XXX: enums
    [else (error 'layout-offset "not a layout: ~a" layout)]))

;; symbol * struct-layout -> layout
(define (struct-layout-lookup field-name layout)
  (cond
    [(assq field-name (internal:layout:struct-fields (unwrap layout)))
     => (match-lambda
          [(list name offset layout*)
           (list name offset (wrap (layout-abi layout) layout*))])]
    [else (error 'struct-layout-lookup "field ~a not found" field-name)]))

;; symbol * union-layout -> layout
(define (union-layout-lookup variant-name layout)
  (cond
    [(assq variant-name (internal:layout:union-variants (unwrap layout)))
     => (match-lambda
          [(list name layout*)
           (list name (wrap (layout-abi layout) layout*))])]
    [else (error 'union-layout-lookup "variant ~a not found" variant-name)]))

;; =============================================================================
;; COMPILATION MONAD
;; =============================================================================

;; (compilation x) = (listof uint) abi -> (values (listof uint) abi x)

;; x -> (compilation x)
(define (return x)
  (lambda (compiled abi)
    (values compiled abi x)))

;; (compilation x) * (x -> (compilation y)) -> (compilation y)
(define (bind c f)
  (lambda (compiled1 abi1)
    (let-values ([(compiled2 abi2 x) (c compiled1 abi1)])
      ((f x) compiled2 abi2))))

;; -> (compilation uint)
(define (pop)
  (lambda (compiled abi)
    (if (null? compiled)
        (error 'compile-header "unexpected end of compilation data")
        (values (cdr compiled) abi (car compiled)))))

;; symbol * internal:layout -> (compilation #f)
(define (save-typedef name type)
  (lambda (compiled abi)
    (values compiled (abi-add-typedef abi name type) #f)))

;; symbol * internal:layout -> (compilation #f)
(define (save-tag tag type)
  (lambda (compiled abi)
    (values compiled (abi-add-tag abi tag type) #f)))

(define-syntax do
  (syntax-rules (<-)
    [(do s) s]
    [(do (x <- e) s ...)
     (bind e (lambda (x)
               (do s ...)))]
    [(do s1 s ...)
     (bind s1 (lambda (_) (do s ...)))]))

;; header * compiler -> abi
(define (compile-header header compiler)
  (let-values ([(queries computation deps) (precompile-header header)])
    ;; XXX: extend the computation to query the dependencies too
    (let-values ([(queries-left abi _) (computation (compiler queries) (make-abi #hasheq() #hasheq()))])
      abi)))

;; =============================================================================
;; PRECOMPILATION MONAD
;; =============================================================================

;; (precompilation x y) = x -> (values (listof query) (compilation y) (setof symbol))

;; (precompilation header #f)
(define (precompile-header header)
  (if (header? header)
      (precompile-compound-header header)
      (precompile-decl header)))

;; (precompilation compound-header #f)
(define (precompile-compound-header ch)
  (precompile-map precompile-header (header-headers ch)))

;; (precompilation decl #f)
(define (precompile-decl decl)
  (match decl
    [(struct decl:type:def (root defs))
     (let-values ([(queries compile-type deps) (precompile-type root)])
       ;; XXX: generalize defs to (symbol * (type -> type))
       (values (cons (make-query:sizeof (caar defs)) queries)
               (do
                 (size <- (pop))
                 (type <- compile-type)
                 (save-typedef (caar defs) type))
               deps))]
    [(struct decl:type:tagged ((struct type:struct (tag fields))))
     (let-values ([(queries compile-fields deps) (precompile-map (precompile-struct-field tag) fields)])
       (values (cons (make-query:sizeof `(struct ,tag)) queries)
               (do
                 (size <- (pop))
                 (fields <- compile-fields)
                 (save-tag tag (make-internal:layout:struct size tag fields)))
               deps))]
    [(struct decl:type:tagged ((struct type:union (tag fields))))
     (let-values ([(queries compile-fields deps) (precompile-map precompile-union-field fields)])
       (values (cons (make-query:sizeof `(union ,tag)) queries)
               (do
                 (size <- (pop))
                 (fields <- compile-fields)
                 (save-tag tag (make-internal:layout:union size tag fields)))
               deps))]
    #;[(struct decl:type:tagged ((struct type:enum ())))
     #f]))

(define (precompile-optional-type type)
  (if type (precompile-type type) (values null (return #f) set:empty)))

(define (precompile-type type)
  (match type
    [(struct type:ref (name))
     (values (list (make-query:sizeof name))
             (do
               (size <- (pop))
               (return (make-internal:layout:ref size name)))
             (set:add name set:empty))]
    [(struct type:struct (#f fields))
     (let-values ([(queries compile-fields deps) (precompile-map (precompile-struct-field #f) fields)])
       (values (cons (make-query:sizeof `(struct ,@fields)) queries)
               (do
                 (size <- (pop))
                 (fields <- compile-fields)
                 (return (make-internal:layout:struct size #f fields)))
               deps))]
    [(struct type:struct (tag #f))
     (values (list (make-query:sizeof `(struct ,tag)))
             (do
               (size <- (pop))
               (return (make-internal:layout:struct size tag #f)))
             set:empty)]
    [(struct type:struct (tag fields))
     (let-values ([(queries compile-fields deps) (precompile-map (precompile-struct-field tag) fields)])
       (values (cons (make-query:sizeof `(struct ,tag)) queries)
               (do
                 (size <- (pop))
                 (fields <- compile-fields)
                 (save-tag tag (make-internal:layout:struct size tag fields))
                 (return (make-internal:layout:struct size tag fields)))
               deps))]
    [(struct type:union (#f fields))
     (let-values ([(queries compile-fields deps) (precompile-map precompile-union-field fields)])
       (values (cons (make-query:sizeof `(union ,@fields)) queries)
               (do
                 (size <- (pop))
                 (fields <- compile-fields)
                 (return (make-internal:layout:union size #f fields)))
               deps))]
    [(struct type:union (tag #f))
     (values (list (make-query:sizeof `(union ,tag)))
             (do
               (size <- (pop))
               (return (make-internal:layout:union size tag #f)))
             set:empty)]
    [(struct type:union (tag fields))
     (let-values ([(queries compile-fields deps) (precompile-map precompile-union-field fields)])
       (values (cons (make-query:sizeof `(union ,tag)) queries)
               (do
                 (size <- (pop))
                 (fields <- compile-fields)
                 (save-tag tag (make-internal:layout:union size tag fields))
                 (return (make-internal:layout:union size tag fields)))
               deps))]
    [(struct type:array (type length))
     (let-values ([(queries compile-type deps) (precompile-type type)])
       (values (cons (make-query:sizeof `(array ,type ,length))
                     (cons (make-query:expr length)
                           queries))
               (do
                 (size <- (pop))
                 (length <- (pop))
                 (type <- compile-type)
                 (return (make-internal:layout:array size type length)))
               deps))]
    [(struct type:pointer (type))
     #f]
    [(struct type:function (return args))
     #f]))

;; symbol -> (precompilation (cons symbol type) (list symbol uint layout))
(define ((precompile-struct-field tag) field)
  (match field
    [(cons name type)
     (let-values ([(queries compile-type deps) (precompile-optional-type type)])
       (values (cons (make-query:offset `(struct ,tag) name) queries)
               (do
                 (offset <- (pop))
                 (type <- compile-type)
                 (return (list name offset type)))
               deps))]))

;; (precompilation (cons symbol type) (list symbol layout))
(define (precompile-union-field field)
  (match field
    [(cons name type)
     (let-values ([(queries compile-type deps) (precompile-type type)])
       (values queries
               (do
                 (type <- compile-type)
                 (return (list name type)))
               deps))]))

;; (precompilation x y) * (listof x) -> (precompilation (listof x) (listof y))
(define (precompile-map f ls)
  (if (null? ls)
      (values null (return null) set:empty)
      (let-values ([(queries1 c1 deps1) (f (car ls))]
                   [(queries2 c2 deps2) (precompile-map f (cdr ls))])
        (values (append queries1 queries2)
                (bind c1 (lambda (x)
                           (bind c2 (lambda (y)
                                      (return (cons x y))))))
                (set:union deps1 deps2)))))

;; =============================================================================
;; SYSTEM HEADER COMPILER
;; =============================================================================

;; query -> void
(define (print-query query)
  (match query
    [(struct query:sizeof ((? symbol? type)))
     (printf "printf(\"%d ; (sizeof ~a)\\n\", sizeof(~a));\n" type type)]
    [(struct query:sizeof (`(struct ,tag)))
     (printf "printf(\"%d ; (sizeof (struct ~a))\\n\", sizeof(struct ~a));\n" tag tag)]
    [(struct query:sizeof ('pointer))
     (printf "printf(\"%d ; (sizeof pointer)\\n\", sizeof(void*));\n")]
    [(struct query:sizeof (type))
     (printf "printf(\"#f ; (sizeof ~a) -- *SKIPPED*\\n\");\n" type)]
    [(struct query:offset (`(struct ,tag) field-name))
     (printf "GET_OFFSET(struct ~a, ~a, off);\n" tag field-name)
     (printf "printf(\"%d ; (offset (struct ~a) ~a)\\n\", off);\n" tag field-name)]
    [(struct query:offset ((? symbol? type) field-name))
     (printf "GET_OFFSET(~a, ~a, off); \n" type field-name)
     (printf "printf(\"%d ; (offset ~a ~a)\\n\", off);\n" type field-name)]
    [(struct query:offset (type field-name))
     (printf "printf(\"#f ; (offset ~a ~a) -- *SKIPPED*\\n\");\n" type field-name)]
    #;[(struct query:expr (expr))
     (printf "printf(\"%d ; ~a\\n\", ~a);\n" expr (expression->C expr))]))

;; (listof query) [* (listof string)] -> void
(define (print-queries queries [includes null])
  (printf "#include <stdio.h>\n")
  (for ([include includes])
    (printf "#include ~a\n" include))
  
  (printf "#define GET_OFFSET(TYP, field, x) { TYP ___tmp___; x = (((int)(&___tmp___.field)) - ((int)(&___tmp___))); }\n")
  
  (printf "int main() {\n")
  (printf "int off;\n")
  
  (printf "printf(\"(\\n\");\n")
  (for-each print-query queries)
  (printf "printf(\")\\n\");\n")
  
  (printf "return 0;\n")
  (printf "}\n"))

#;(define (expression->C expr)
  (match expr
    [(? symbol?) (symbol->string expr)]
    [(struct expr:binop (op left right))
     (format "((~a) ~a (~a))" (expression->C left) op (expression->C right))]
    ;; TODO: this is grody and broken for chars, strings
    [(? expr:lit?)
     (format "~a" (expr:lit-value expr))]))

(define (system-compiler #:include<> [include<> null] #:include [include null] [exe gcc])
  (lambda (queries)
    (let ([headers (append (for/list ([i include<>])
                             (format "<~a>" i))
                           (for/list ([i include])
                             (format "\"~a\"" i)))])
      (let-values ([(stdout stderr) (exe (lambda ()
                                           (print-queries queries headers)))])
        (let ([result (read stdout)])
          result)))))

;; =============================================================================
;; TESTS (TEMPORARY)
;; =============================================================================

(define time.h
  (make-header
    #reader (planet dherman/c/reader) {
        struct tm {
            int tm_sec;
            int tm_min;
            int tm_hour;
            int tm_mday;
            int tm_mon;
            int tm_year;
            int tm_wday;
            int tm_yday;
            int tm_isdst;
        };
    }))

(define time
  (compile-header time.h (system-compiler #:include<> '("time.h"))))

(define windows.h
  (make-header
   #reader (planet dherman/c/reader) {
     typedef struct _COORD {
       SHORT X;
       SHORT Y;
     } COORD;

     typedef struct _SMALL_RECT {
       SHORT Left;
       SHORT Top;
       SHORT Right;
       SHORT Bottom;
     } SMALL_RECT;

     typedef struct _CONSOLE_SCREEN_BUFFER_INFO {
       COORD dwSize;
       COORD dwCursorPosition;
       WORD wAttributes;
       SMALL_RECT srWindow;
       COORD dwMaximumWindowSize;
     } CONSOLE_SCREEN_BUFFER_INFO;

     typedef struct _KEY_EVENT_RECORD {
       BOOL bKeyDown;
       WORD wRepeatCount;
       WORD wVirtualKeyCode;
       WORD wVirtualScanCode;
       union {
         WCHAR UnicodeChar;
         CHAR AsciiChar;
       } uChar;
       DWORD dwControlKeyState;
     } KEY_EVENT_RECORD;

     typedef struct _WINDOW_BUFFER_SIZE_RECORD {
       COORD dwSize;
     } WINDOW_BUFFER_SIZE_RECORD;

     typedef struct _INPUT_RECORD {
       WORD EventType;
       union {
         KEY_EVENT_RECORD KeyEvent;
         MOUSE_EVENT_RECORD MouseEvent;
         WINDOW_BUFFER_SIZE_RECORD WindowBufferSizeEvent;
         MENU_EVENT_RECORD MenuEvent;
         FOCUS_EVENT_RECORD FocusEvent;
       } Event;
     } INPUT_RECORD;
   }))

(define windows
  (compile-header windows.h (system-compiler #:include<> '("windows.h" "IpTypes.h") gcc)))