;; flash-domain-policy-server.rkt
#lang racket/base

#|

#  rkt-flash-domain-policy
#  Flash Cross Domain Policy Server
#  License: MIT
#
#  This is a simple implementation of a Flash
#  cross-domain policy server written in Racket.
#
#  How to use it:
#
#  (run-flash-domain-policy-server aaf ...)
#  Where each aaf is a list:
#    '(from-domain to-ports)
#    '(from-domain to-ports secure)
#  Optional keyword arguments:
#  #:notify-proc proc
#  Where proc is a procedure with two arguments:
#    (lambda (a b) (printf "~s ~s" a b))
#  #:site-control ctrl where ctrl is one of:
#    'none
#    'master-only
#    'by-content-type
#    'by-ftp-filename
#    'all           
#  #:http-reqs (list httpreqs ...)
#  Where each httpreqs is a list:
#    '(from-domain headers)
#    '(from-domain headers secure)
#  #:identities (list certs ...)
#  Where each certs is a list:
#    '(signature signature-algorithm)
#  Parameters for tcp-listen:
#  #:max-allow-wait number
#  #:hostname string
#  #:port port-number
#
# Examples:
#
#  Runs a server that allows access from any domain (*)
#  to port 2000, with the "master-only" site control:
#
#    (run-flash-domain-policy-server
#      #:site-control 'master-only
#      #:notify-proc (lambda (a b) (printf "~s ~s" a b))
#      '("*" 2000))
#
#  Runs a server that allows access from *.example.com
#  to three different ports:
#
#    (run-flash-domain-policy-server '("*.example.com" (123 234 345)))
#
#  Runs a server with everything:
#
#    (run-flash-domain-policy-server
#     #:notify-proc (lambda (a b) 
#                     (printf "~a~n" b))
#     #:site-control 'master-only
#     #:http-reqs (list '("*" "Jack" #t))
#     #:identities (list '("ABC" "DEF"))
#     '("*" 2000 #t)
#     '("*.example.com" (123 455 200) #t))
#

|#

(require racket/tcp
         racket/list)

(provide run-flash-domain-policy-server
         (struct-out flash-domain-policy-server-event-connection)
         (struct-out flash-domain-policy-server-event-sending)
         (struct-out flash-domain-policy-server-event-closing)
         (struct-out flash-domain-policy-server-event-timeout))

;; Event structures handed to the #:notify-proc callback as its first
;; argument (the second argument is a human-readable message string).
;; All of them are transparent so they print with their field values.

;; Reported right after a client connection has been accepted.
(struct flash-domain-policy-server-event-connection
  (remote-host   ; address of the connecting peer, as given by tcp-addresses
   local-host    ; address of this end of the connection
   )
  #:transparent)
;; Reported when a policy-file request was recognized, just before the
;; policy file is written to the client.
(struct flash-domain-policy-server-event-sending
  ()
  #:transparent)
;; Reported just before the handler closes the connection's ports.
(struct flash-domain-policy-server-event-closing
  ()
  #:transparent)
;; Reported when the handler has not finished by the time the
;; per-connection timeout expires.
(struct flash-domain-policy-server-event-timeout
  ()
  #:transparent)

;; Upper bound on the number of request bytes read from a client before the
;; request is judged.  The genuine request ("<policy-file-request/>" plus a
;; NUL byte) is only a couple of dozen bytes, so anything longer is cut off
;; instead of being buffered without limit.
(define max-policy-request-bytes 1024)

;; read-policy-request : input-port -> bytes
;; Reads bytes from `in` up to (and not including) the first NUL byte, EOF,
;; or the size cap above, and returns them as a byte string.
(define (read-policy-request in)
  (define buf (open-output-bytes))
  (let loop ([remaining max-policy-request-bytes])
    (when (positive? remaining)
      (let ([b (read-byte in)])
        (unless (or (eof-object? b) (eqv? b 0))
          (write-byte b buf)
          (loop (sub1 remaining))))))
  (get-output-bytes buf))

;; comma-separated : any -> string
;; Renders a list as its elements joined by commas; any other value is
;; rendered on its own.  Elements are formatted with ~a, so both numbers
;; (2000) and strings ("*", "1000-2000", "X-Header") are accepted.
(define (comma-separated v)
  (if (list? v)
      (apply string-append
             (add-between (map (lambda (x) (format "~a" x)) v) ","))
      (format "~a" v)))

;; secure-attribute : list -> string
;; Renders the optional third element of a rule as a secure="..." attribute
;; (with a trailing space); rules without a third element yield "".
(define (secure-attribute rule)
  (cond
    [(<= (length rule) 2) ""]
    [(third rule) "secure=\"true\" "]
    [else "secure=\"false\" "]))

;; policy-file-xml : (or/c #f any) list (or/c #f list) (or/c #f list) -> string
;; Builds the complete cross-domain policy document.
;;   sitectrl : value for <site-control>, or #f to omit the element
;;   config   : list of (from-domain to-ports [secure]) rules
;;   httpreqs : list of (from-domain headers [secure]) rules, or #f
;;   certs    : list of (fingerprint fingerprint-algorithm) entries, or #f
(define (policy-file-xml sitectrl config httpreqs certs)
  (string-append
   "<?xml version=\"1.0\"?>"
   "<!DOCTYPE cross-domain-policy SYSTEM \"/xml/dtds/cross-domain-policy.dtd\">"
   "<cross-domain-policy>"
   (if sitectrl
       (format "<site-control permitted-cross-domain-policies=\"~a\"/>" sitectrl)
       "")
   ;; Socket access rules: the port list belongs in the to-ports attribute.
   (apply string-append
          (for/list ([rule config])
            (format "<allow-access-from domain=~s to-ports=\"~a\" ~a/>"
                    (car rule)
                    (comma-separated (second rule))
                    (secure-attribute rule))))
   ;; HTTP header rules: the header list belongs in the headers attribute.
   (apply string-append
          (for/list ([rule (or httpreqs '())])
            (format "<allow-http-request-headers-from domain=~s headers=\"~a\" ~a/>"
                    (car rule)
                    (comma-separated (second rule))
                    (secure-attribute rule))))
   (apply string-append
          (for/list ([cert (or certs '())])
            (string-append
             "<allow-access-from-identity><signatory>"
             (format "<certificate fingerprint=~s fingerprint-algorithm=~s/>"
                     (car cert) (second cert))
             "</signatory></allow-access-from-identity>")))
   "</cross-domain-policy>"))

;; run-flash-domain-policy-server
;; Listens on `port` (default 843) and answers every "<policy-file-request"
;; with the cross-domain policy described by the arguments.  Never returns.
;;
;;   config            rest arguments, each '(from-domain to-ports [secure])
;;   #:notify-proc     (lambda (event message) ...) called on connection,
;;                     sending, closing and timeout events
;;   #:site-control    value for permitted-cross-domain-policies, or #f
;;   #:http-reqs       list of '(from-domain headers [secure]), or #f
;;   #:identities      list of '(fingerprint fingerprint-algorithm), or #f
;;   #:max-allow-wait, #:hostname, #:port   passed on to tcp-listen
(define (run-flash-domain-policy-server
         #:notify-proc [notproc void]
         #:site-control [sitectrl #f]
         #:http-reqs [httpreqs #f]
         #:identities [certs #f]
         #:max-allow-wait [maxallow 4]
         #:hostname [hostname #f]
         #:port [port 843]
         . config)
  ;; The policy only depends on the arguments, so build it once up front.
  ;; This also makes a malformed rule fail at startup rather than inside
  ;; every connection handler.
  (define policy (policy-file-xml sitectrl config httpreqs certs))
  (define acc (tcp-listen port maxallow #f hostname))
  (let loop ()
    ;; Each connection gets its own custodian so that the timeout thread
    ;; can reclaim the ports and the handler thread in one step.
    (define cust (make-custodian))
    (define done #f)
    (parameterize ([current-custodian cust])
      (let*-values ([(in out) (tcp-accept acc)]
                    [(here there) (tcp-addresses in)])
        (notproc (flash-domain-policy-server-event-connection there here)
                 (format "Connection from ~s" there))
        (thread
         (lambda ()
           ;; Match on the raw bytes: the request comes from an untrusted
           ;; peer and need not be valid UTF-8.
           (when (regexp-match? #rx#"^<policy-file-request"
                                (read-policy-request in))
             (notproc (flash-domain-policy-server-event-sending)
                      "Sending policy file")
             ; Reply with the cross-domain policy file
             (display policy out))
           (notproc (flash-domain-policy-server-event-closing)
                    "Closing connection")
           (close-input-port in)
           (close-output-port out)
           (set! done #t))))

      ; Timeout after 3 seconds
      (thread
       (lambda ()
         (sleep 3)
         (unless done
           (notproc (flash-domain-policy-server-event-timeout)
                    "Connection timeout"))
         (custodian-shutdown-all cust))))

    ; Loop to wait for the next incoming connection
    (loop)))

#|
(module+ test
  (run-flash-domain-policy-server
   #:notify-proc (lambda (a b) 
                   (printf "~a~n" b))
   #:site-control 'master-only
   #:http-reqs (list '("*" "Jack" #t))
   #:identities (list '("ABC" "DEF"))
   '("*" 2000 #t)
   '("*.example.com" (123 455 200) #t))
  )
|#
;;
;; Submodule for running this from the command line.
;;
(module+ main
  ;; Name used as the prefix of command-line error messages.
  (define program 'flash-domain-policy-server)

  ;; Reads a single Racket datum from a command-line argument string.
  (define (read-datum str)
    (read (open-input-string str)))

  ;; Returns the value that follows the flag at the head of `args`, or
  ;; reports a usage error when the flag is the last argument.
  (define (flag-value args)
    (when (null? (cdr args))
      (raise-user-error program "missing value for ~a" (car args)))
    (cadr args))

  ;; Like flag-value, but the value must be a non-negative integer.
  (define (flag-number args)
    (define n (string->number (flag-value args)))
    (unless (exact-nonnegative-integer? n)
      (raise-user-error program
                        "~a expects a non-negative integer, got ~a"
                        (car args) (cadr args)))
    n)

  ;; Turns the value of --notify-proc into a procedure.
  ;; NOTE: the expression is evaluated with `eval` in a fresh racket/base
  ;; namespace.  It is supplied by whoever launches the server, so it runs
  ;; with that user's privileges; never forward untrusted text to this flag.
  (define (flag-notify-proc args)
    (define proc (eval (read-datum (flag-value args)) (make-base-namespace)))
    (unless (and (procedure? proc) (procedure-arity-includes? proc 2))
      (raise-user-error program
                        "--notify-proc expects a procedure of two arguments, got ~a"
                        (cadr args)))
    proc)

  (let ([notifiers '()]   ; notify procedures, most recent first
        [sitectrl #f]
        [httpreqs #f]     ; #f, or list of rules, most recent first
        [identsct #f]     ; #f, or list of identities, most recent first
        [maxwait  4]
        [hostname #f]
        [hostport 843]
        [sockets '()])    ; positional access rules, most recent first
    (let loop ([lst (vector->list (current-command-line-arguments))])
      (cond
        [(null? lst) (void)]
        [(equal? "--notify-proc" (car lst))
         (set! notifiers (cons (flag-notify-proc lst) notifiers))
         (loop (cddr lst))]
        [(equal? "--site-control" (car lst))
         ;; A single value, not a list: it is written verbatim into the
         ;; permitted-cross-domain-policies attribute.
         (set! sitectrl (read-datum (flag-value lst)))
         (loop (cddr lst))]
        [(equal? "--http-reqs" (car lst))
         (set! httpreqs (cons (read-datum (flag-value lst))
                              (or httpreqs '())))
         (loop (cddr lst))]
        [(equal? "--identities" (car lst))
         (set! identsct (cons (read-datum (flag-value lst))
                              (or identsct '())))
         (loop (cddr lst))]
        [(equal? "--max-allow-wait" (car lst))
         (set! maxwait (flag-number lst))
         (loop (cddr lst))]
        [(equal? "--hostname" (car lst))
         (set! hostname (flag-value lst))
         (loop (cddr lst))]
        [(equal? "--port" (car lst))
         (set! hostport (flag-number lst))
         (loop (cddr lst))]
        [(regexp-match? #rx"^--" (car lst))
         (raise-user-error program "unknown flag ~a" (car lst))]
        [else
         ;; A positional rule occupies exactly one argument.
         (set! sockets (cons (read-datum (car lst)) sockets))
         (loop (cdr lst))]))

    (if (and (not sitectrl)
             (not httpreqs)
             (not identsct)
             (= maxwait 4)
             (not hostname)
             (= hostport 843)
             (null? sockets))
        ;; Nothing was configured: show the usage text instead of serving.
        (let ()
          (printf "Flash Domain Policy Server~n")
          (printf "Usage:~n")
          (printf "flash-domain-policy-server [flags ...] \"(from-domain to-port [#t])\" ...~n")
          (printf "Where from-domain is a string, to-port is a number and #t specifies secure~n")
          (printf "Each set of rules should be in quotes and parenthesis~n")
          (printf "Flags:~n")
          (printf "  --notify-proc \"(lambda (event message) ...)\"~n")
          (printf "  --site-control none|master-only|by-content-type|by-ftp-filename|all~n")
          (printf "  --http-reqs (from-domain http-header [#t])~n")
          (printf "  --identities (signature signature-algorithm)~n")
          (printf "  --max-allow-wait number of pending connections (def. 4)~n")
          (printf "  --hostname hostname~n")
          (printf "  --port n overrides port 843~n"))
        (apply run-flash-domain-policy-server
               ;; Several --notify-proc flags are all honoured, in the order
               ;; they were given.
               #:notify-proc (if (null? notifiers)
                                 void
                                 (let ([procs (reverse notifiers)])
                                   (lambda (event message)
                                     (for ([proc (in-list procs)])
                                       (proc event message)))))
               #:site-control sitectrl
               ;; Accumulated lists are reversed to restore command-line order.
               #:http-reqs (and httpreqs (reverse httpreqs))
               #:identities (and identsct (reverse identsct))
               #:max-allow-wait maxwait
               #:hostname hostname
               #:port hostport
               (reverse sockets)))))