proxy.ss
(module proxy mzscheme
  
  (require (lib "thread.ss")
           (lib "unit.ss")
           (lib "plt-match.ss")
           (lib "port.ss")
           (lib "tcp-sig.ss" "net")
           (lib "tcp-unit.ss" "net")
           (lib "uri-codec.ss" "net")
           (lib "url.ss" "net")
           (lib "list.ss" "srfi" "1")
           (lib "time.ss" "srfi" "19")
           (lib "cut.ss" "srfi" "26")
           (lib "connection-manager.ss" "web-server" "private")
           (lib "request.ss" "web-server" "private")
           (lib "request-structs.ss" "web-server" "private"))

  (provide kill)
  
  ; Top level connection handling ----------------
  
  ;; run-proxy
  ;;     : (bytes bytes (hash-table-of bytes bytes) -> (U path #f))
  ;;       (U string #f)
  ;;       integer
  ;;       integer
  ;;       tcp^
  ;;       [string]     ; upstream-host, default "webcache.cs.bham.ac.uk"
  ;;       [integer]    ; upstream-port, default 3128
  ;;    -> (-> void)
  ;;
  ;; Runs an HTTP proxy listening on `hostname`:`port` with a listen backlog
  ;; of `max-waiting`, using the TCP implementation provided by the unit
  ;; `tcp@`. For each request, `request->path` is called with the method,
  ;; URL and header table. If it returns a path, that file is served
  ;; directly. Otherwise the request is forwarded to the upstream proxy and
  ;; its response is relayed back to the client.
  ;;
  ;; Server resources are created under a fresh custodian, and the returned
  ;; thunk shuts that custodian down. NOTE(review): mzlib's run-server keeps
  ;; looping while it accepts connections, so this procedure probably never
  ;; returns in practice; callers stop it by killing its thread -- confirm.
  (define (run-proxy request->path hostname port max-waiting tcp@ . upstream)
    (define-values/invoke-unit tcp@ (import) (export tcp^))
    
    (define custodian (make-custodian))
    
    ;; Upstream proxy that non-local requests are forwarded to.
    (define upstream-host
      (if (pair? upstream) (car upstream) "webcache.cs.bham.ac.uk"))
    (define upstream-port
      (if (and (pair? upstream) (pair? (cdr upstream))) (cadr upstream) 3128))
    
    ; Top level connection handling ----------------
    
    ;; handle-connection : input-port output-port -> any
    ;; Serves requests on one client connection until it is marked closed.
    ;; Network errors kill the connection and are then re-raised.
    (define (handle-connection ip op)
      (define conn
        (new-connection 30 ip op (current-custodian) #f))
      (with-handlers ([exn:fail:network?
                       (lambda (e)
                         (kill-connection! conn)
                         (raise e))])
        (parameterize ([current-id (begin0 next-id (set! next-id (add1 next-id)))])
          (let connection-loop ()
            (debug "request: start")
            (dispatch conn)
            (if (connection-close? conn)
                (begin (kill-connection! conn)
                       (debug "connection closed"))
                (connection-loop))))))
    
    ; Logging --------------------------------------
    
    ;; Each connection gets a sequential id so interleaved debug lines can be
    ;; told apart. NOTE(review): the increment is not synchronised across
    ;; connection threads.
    (define next-id 0)
    (define current-id (make-parameter #f))
    
    ;; debug : format-string any ... -> void
    ;; Prints "<connection id>: <formatted message>" and a newline.
    (define (debug . args)
      (printf "~a: " (current-id))
      (apply printf args)
      (newline))
    
    ; HTTP message helpers -------------------------
    
    ;; write-crlf-line : bytes output-port -> void
    (define (write-crlf-line line out)
      (write-bytes line out)
      (write-bytes #"\r\n" out))
    
    ;; read-headers : input-port string -> (values (listof bytes) hash-table boolean)
    ;; Reads header lines up to the blank line that ends the header section.
    ;; Returns three values:
    ;;   - the raw lines in order, without the blank line;
    ;;   - a table from header name to value (a later duplicate wins);
    ;;   - #t if the blank line was seen, or #f if EOF came first.
    ;; `tag` prefixes the debug output.
    (define (read-headers in tag)
      (let ([table (make-hash-table 'equal)])
        (let loop ([lines null])
          (let ([line (read-bytes-line in 'return-linefeed)])
            (debug "~a: read line: ~a" tag line)
            (cond
              [(eof-object? line) (values (reverse lines) table #f)]
              [(bytes=? line #"") (values (reverse lines) table #t)]
              [else
               (match (regexp-match #rx#"^([^:]+):[ \t]*(.*)$" line)
                 [(list _ name value) (hash-table-put! table name value)]
                 [_ (void)])
               (loop (cons line lines))])))))
    
    ;; header-ref : hash-table bytes -> (U bytes #f)
    ;; HTTP header names are case-insensitive (RFC 2616 sec. 4.2). Try the
    ;; exact spelling first, then fall back to a case-insensitive search.
    (define (header-ref headers name)
      (hash-table-get
       headers name
       (lambda ()
         (let* ([wanted (string-downcase (bytes->string/latin-1 name))]
                [hit    (find (lambda (entry)
                                (string=? (string-downcase
                                           (bytes->string/latin-1 (car entry)))
                                          wanted))
                              (hash-table-map headers cons))])
           (and hit (cdr hit))))))
    
    ;; close-token? : (U bytes #f) -> boolean
    ;; True when a (Proxy-)Connection header value is "close", in any case.
    (define (close-token? value)
      (and value (string-ci=? (bytes->string/latin-1 value) "close")))
    
    ;; content-length : hash-table -> (U integer #f)
    ;; The declared body length, or #f if it is absent or not a plain
    ;; decimal number.
    (define (content-length headers)
      (let* ([raw    (header-ref headers #"Content-Length")]
             [digits (and raw (regexp-match #rx#"^[ \t]*([0-9]+)[ \t]*$" raw))])
        (and digits (string->number (bytes->string/latin-1 (cadr digits))))))
    
    ;; chunked? : hash-table -> boolean
    (define (chunked? headers)
      (let ([coding (header-ref headers #"Transfer-Encoding")])
        (and coding
             (regexp-match #rx"chunked"
                           (string-downcase (bytes->string/latin-1 coding)))
             #t)))
    
    ;; relay-exact : input-port output-port integer -> void
    ;; Copies `count` bytes from `in` to `out`, in pieces. Copying stops
    ;; early if `in` reaches EOF.
    (define (relay-exact in out count)
      (let loop ([remaining count])
        (when (positive? remaining)
          (let ([piece (read-bytes (min remaining 4096) in)])
            (unless (eof-object? piece)
              (write-bytes piece out)
              (loop (- remaining (bytes-length piece))))))))
    
    ;; relay-trailers : input-port output-port -> void
    ;; Copies lines up to and including the next blank line, or up to EOF.
    (define (relay-trailers in out)
      (let loop ()
        (let ([line (read-bytes-line in 'return-linefeed)])
          (unless (eof-object? line)
            (write-crlf-line line out)
            (unless (bytes=? line #"")
              (loop))))))
    
    ;; relay-chunked : input-port output-port -> void
    ;; Forwards a chunked body unchanged: each size line, the chunk data and
    ;; its CRLF, then the trailer section after the zero-size chunk. Stops
    ;; quietly on EOF or on a size line it cannot parse.
    (define (relay-chunked in out)
      (let loop ()
        (let ([size-line (read-bytes-line in 'return-linefeed)])
          (unless (eof-object? size-line)
            (write-crlf-line size-line out)
            (let ([size (match (regexp-match #rx#"^[ \t]*([0-9a-fA-F]+)" size-line)
                          [(list _ hex) (string->number (bytes->string/latin-1 hex) 16)]
                          [_ #f])])
              (cond
                [(not size)  (void)]
                [(zero? size) (relay-trailers in out)]
                [else
                 (relay-exact in out size)
                 ;; Chunk data is followed by a CRLF before the next size line.
                 (let ([terminator (read-bytes-line in 'return-linefeed)])
                   (unless (eof-object? terminator)
                     (write-crlf-line terminator out)
                     (loop)))]))))))
    
    ; Request handling -----------------------------
    
    ;; dispatch : connection -> void
    ;; Reads one request from the client and either answers it from a local
    ;; file or forwards it upstream. The connection is marked for closing in
    ;; three cases: EOF, a malformed request line, or headers cut short by EOF.
    (define (dispatch connection)
      (let* ([client-in    (connection-i-port connection)]
             [request-line (read-bytes-line client-in 'return-linefeed)])
        (debug "request: read line: ~a" request-line)
        (match (and (bytes? request-line)
                    (regexp-match #rx#"([A-Z]+) (.+) HTTP/([0-9])\\.([0-9])" request-line))
          [(list _ method url _ _)
           (let-values ([(header-lines headers complete?)
                         (read-headers client-in "request")])
             (if complete?
                 (route-request connection method url request-line header-lines headers)
                 ;; The client hung up mid-request, so there is nothing to answer.
                 (set-connection-close?! connection #t)))]
          [_ (set-connection-close?! connection #t)])))
    
    ;; route-request : connection bytes bytes bytes (listof bytes) hash-table -> void
    ;; Applies the client's Proxy-Connection preference. Then it serves the
    ;; request locally if `request->path` names a file, or forwards it otherwise.
    (define (route-request connection method url request-line header-lines headers)
      (debug "Received request: ~a" url)
      (let ([local-path  (request->path method url headers)]
            [conn-header (header-ref headers #"Proxy-Connection")])
        (debug "request: proxy-connection: ~a" conn-header)
        (when (close-token? conn-header)
          (set-connection-close?! connection #t))
        (if local-path
            (send-local-response connection local-path)
            (forward-request connection method request-line header-lines headers))))
    
    ;; forward-request : connection bytes bytes (listof bytes) hash-table -> void
    ;; Opens a fresh connection to the upstream proxy and sends the request
    ;; line, the client's headers, two added no-cache headers, and any body.
    ;; It then relays the response and closes the upstream connection, even
    ;; if an error is raised.
    (define (forward-request connection method request-line header-lines headers)
      (let-values ([(remote-input remote-output)
                    (tcp-connect upstream-host upstream-port)])
        (debug "remote ports: ~a ~a" remote-input remote-output)
        (dynamic-wind
         void
         (lambda ()
           (write-crlf-line request-line remote-output)
           (for-each (cut write-crlf-line <> remote-output) header-lines)
           ;; These must come before the blank line so upstream parses them
           ;; as headers rather than as body bytes.
           (write-crlf-line #"Pragma: no-cache" remote-output)
           (write-crlf-line #"Cache-Control: no-cache" remote-output)
           (write-crlf-line #"" remote-output)
           (flush-output remote-output)
           (relay-request-body (connection-i-port connection) headers remote-output)
           (flush-output remote-output)
           (handle-response connection method remote-input))
         (lambda ()
           (close-input-port remote-input)
           (close-output-port remote-output)))))
    
    ;; relay-request-body : input-port hash-table output-port -> void
    ;; A request has a body only if it declares one (RFC 2616 sec. 4.3).
    ;; A chunked transfer-coding takes precedence over Content-Length.
    (define (relay-request-body client-in headers remote-output)
      (let ([declared (content-length headers)])
        (debug "request: content-length: ~a" declared)
        (cond [(chunked? headers) (relay-chunked client-in remote-output)]
              [declared           (relay-exact client-in remote-output declared)]
              [else               (void)])))
    
    ; Response handling ----------------------------
    
    ;; handle-response : connection bytes input-port -> void
    ;; Relays one upstream response to the client: the status line, the
    ;; headers, then the body. The client connection is marked for closing
    ;; if the response cannot be parsed or its body ends only when upstream
    ;; closes.
    (define (handle-response connection method remote-input)
      (let* ([out         (connection-o-port connection)]
             [status-line (read-bytes-line remote-input 'return-linefeed)])
        (debug "response: read line: ~a" status-line)
        (match (and (bytes? status-line)
                    (regexp-match #rx#"HTTP/([0-9])\\.([0-9]) ([0-9]+) (.*)" status-line))
          [(list _ _ _ code _)
           (let-values ([(header-lines headers complete?)
                         (read-headers remote-input "response")])
             (let ([conn-header (header-ref headers #"Proxy-Connection")])
               (debug "response: proxy-connection: ~a" conn-header)
               (when (close-token? conn-header)
                 (set-connection-close?! connection #t)))
             (write-crlf-line status-line out)
             (for-each (cut write-crlf-line <> out) header-lines)
             (cond
               [complete?
                (write-crlf-line #"" out)
                (flush-output out)
                (debug "response: finished headers")
                (relay-response-body connection method code headers remote-input out)]
               [else
                ;; Upstream hung up inside the headers, so the client
                ;; connection cannot be reused.
                (set-connection-close?! connection #t)]))]
          [_
           ;; EOF or garbage instead of a status line: there is nothing to relay.
           (set-connection-close?! connection #t)])
        (flush-output out)))
    
    ;; relay-response-body
    ;;   : connection bytes bytes hash-table input-port output-port -> void
    ;; Chooses body framing as in RFC 2616 sec. 4.4.
    (define (relay-response-body connection method code headers in out)
      (let ([declared (content-length headers)])
        (debug "response: content-length: ~a" declared)
        (cond
          ;; No body for HEAD requests or for 1xx/204/304 statuses.
          [(or (equal? method #"HEAD")
               (regexp-match #rx#"^(1..|204|304)$" code))
           (void)]
          [(chunked? headers) (relay-chunked in out)]
          [declared           (relay-exact in out declared)]
          [else
           ;; The body ends when upstream closes, so the client must see a close too.
           (copy-port in out)
           (set-connection-close?! connection #t)])
        (flush-output out)
        (debug "response: finished body")))
    
    ; Local files ----------------------------------
    
    ;; send-local-response : connection path -> void
    ;; Answers the request with a 200 response carrying the contents of
    ;; `path`. The Content-Type is chosen from the file's last extension.
    (define (send-local-response connection path)
      (let* ([out  (connection-o-port connection)]
             ;; RFC 1123 date. (current-date 0) is UTC, which matches the "GMT" label.
             [date (date->string (current-date 0) "~a, ~d ~b ~Y ~H:~M:~S GMT")]
             [type (match (regexp-match #rx"\\.([^.]+)$" (path->string path))
                     [(list _ extension)
                      (cond [(assoc extension '(("html" . "text/html")
                                                ("js"   . "text/javascript")))
                             => cdr]
                            [else "text/plain"])]
                     [_ "text/plain"])]
             [size (file-size path)])
        (debug "local response")
        (write-bytes #"HTTP/1.1 200 Okay\r\n" out)
        (write-bytes (string->bytes/utf-8 (format "Date: ~a\r\n" date)) out)
        (write-bytes #"Server: Untyped testing proxy\r\n" out)
        (write-bytes (string->bytes/utf-8 (format "Last-Modified: ~a\r\n" date)) out)
        (write-bytes (string->bytes/utf-8 (format "Content-Type: ~a\r\n" type)) out)
        (write-bytes (string->bytes/utf-8 (format "Content-Length: ~a\r\n" size)) out)
        (write-bytes #"Via: 1.1 www2.sbcs.qmul.ac.uk\r\n" out)
        (write-bytes #"\r\n" out)
        (let ([in (open-input-file path)])
          (dynamic-wind
           void
           (lambda () (copy-port in out))
           (lambda () (close-input-port in))))
        ;; Without a flush the response may stay buffered while the
        ;; connection loop blocks reading the next request.
        (flush-output out)))
    
    (parameterize ([current-custodian custodian])
      (run-server port
                  handle-connection
                  #f
                  ;; Log server errors instead of silently dropping them.
                  (lambda (exn)
                    (fprintf (current-error-port) "proxy: ~a~n"
                             (if (exn? exn) (exn-message exn) exn)))
                  ;; run-server supplies its own backlog; use `max-waiting` instead.
                  (lambda (port-no backlog reuse?)
                    (tcp-listen port-no max-waiting reuse? hostname))
                  tcp-close
                  tcp-accept
                  tcp-accept/enable-break)
      (cut custodian-shutdown-all custodian)))
  
  ; SBCS specific stuff --------------------------
  
  ;; request->path : bytes bytes (hash-table-of bytes bytes) -> (U path #f)
  ;;
  ;; Chooses which requests the proxy answers from local files (SBCS test
  ;; setup). A request whose URL ends in "stress.html" or "stress.js" is
  ;; served from the relative path of the same name. Any other URL yields #f
  ;; and is forwarded upstream. `method` and `headers` are accepted for
  ;; interface compatibility but are not consulted.
  (define (request->path method url headers)
    ;; The dots are escaped so that only a literal "." matches (an
    ;; unescaped "." would also accept e.g. "stressXhtml").
    (cond [(regexp-match #rx#"stress\\.html$" url)
           (string->path "stress.html")]
          [(regexp-match #rx#"stress\\.js$" url)
           (string->path "stress.js")]
          [else #f]))
  
  ;; kill : thread
  ;;
  ;; The thread that runs the proxy. It is started as soon as this module is
  ;; instantiated and calls run-proxy with the SBCS `request->path` routing,
  ;; host "localhost", port 7654 and max-waiting 10. Presumably clients stop
  ;; the proxy with (kill-thread kill) -- confirm with callers.
  (define kill
    (let ([serve
           (lambda ()
             ;; Print structs and hash tables in full inside this thread,
             ;; presumably so the proxy's debug output shows their contents.
             (parameterize ([print-struct #t]
                            [print-hash-table #t])
               (run-proxy request->path "localhost" 7654 10 tcp@)))])
      (thread serve)))
  
  )