sockets.ss
#lang scheme
(require
 srfi/26 openssl net/head "url.ss")

;; A connected WebSocket endpoint, built by ws-connect and ws-accept.
;;   url      - for ws-connect, the URL that was dialed; for ws-accept, the
;;              URL advertised to the client as WebSocket-Location
;;   protocol - the subprotocol string in use, or #f when none was requested
;;   in, out  - the connection's underlying input and output ports
(define-struct ws-socket
  (url protocol in out))

;; This provide form exports only the predicate and the url/protocol
;; accessors; the port accessors are used internally by send/receive/close.
(provide
 ws-socket?
 ws-socket-url ws-socket-protocol)

;; ws-connect : ws-url [protocol] [ssl-context] -> ws-socket
;; Opens a connection to the host and port named by URL (through ssl-connect
;; when (ws-url-secure? url) is true, otherwise through tcp-connect). It then
;; sends the client WebSocket handshake, validates the server's reply, and
;; returns a ws-socket that wraps the connection's ports.
;;   protocol    - the subprotocol string to request, or #f for none. When it
;;                 is given, the server must echo it back in WebSocket-Protocol.
;;   ssl-context - passed to ssl-connect for secure URLs (default 'tls).
;; Handshake problems detected below raise exn:fail:network with a message
;; prefixed by "ws-connect: ". In that case both ports are closed before the
;; exception continues to the enclosing handler.
(define (ws-connect url [protocol #f] [ssl-context 'tls])
  ;; NOTE(review): the ws-url-* helpers come from "url.ss". Presumably `port`
  ;; is the effective port (explicit or default). Confirm this, because it is
  ;; compared with `=` below and that comparison would fail on #f.
  (let ([secure? (ws-url-secure? url)]
        [host (url-host url)]
        [default-port (ws-url-default-port url)]
        [port (ws-url-port url)]
        [resource (ws-url-resource url)])
    ;; Writes the GET request line followed by the header fields Upgrade,
    ;; Connection, Host, Origin ("null"), and WebSocket-Protocol (only when
    ;; PROTOCOL is given), then flushes OUT. The Host field omits the port
    ;; when it equals the scheme's default port.
    ;; Assumes `nest` wraps each listed form around the ones after it, so that
    ;; the insert-field calls build on empty-header. Confirm against its source.
    (define (send-client-handshake out)
      (display
       (nest
        [(string-append (format "GET ~a HTTP/1.1\r\n" resource))
         (insert-field "Upgrade" "WebSocket")
         (insert-field "Connection" "Upgrade")
         (insert-field "Host" (if (= port default-port)
                                  host
                                  (format "~a:~a" host port)))
         (insert-field "Origin" "null")
         ;; Adds the protocol field, or passes the header through unchanged.
         ((if protocol
              (cut insert-field "WebSocket-Protocol" protocol <>)
              values))]
        empty-header)
       out)
      (flush-output out))
    ;; Reads the server's response from IN and checks it. Each successful
    ;; regexp-match on the port consumes the matched bytes. Any mismatch
    ;; raises exn:fail:network.
    (define (validate-server-handshake in)
      ;; Local replacement for `error`: raises exn:fail:network with the
      ;; formatted message prefixed by "ws-connect: ".
      (define (error template . arguments)
        (raise (make-exn:fail:network
                (apply format (string-append "~s: " template) 'ws-connect arguments)
                (current-continuation-marks))))
      ;; Status line: requires status 101 with this exact reason phrase.
      ;; NOTE(review): the "." in "1.1" is unescaped and so matches any byte.
      (cond
        [(regexp-match #rx"^HTTP/1.1 ([0-9]+) ([^\r\n]*)\r\n" in)
         => (λ (groups)
              (let ([status (bytes->string/utf-8 (second groups) #\?)]
                    [message (bytes->string/utf-8 (third groups) #\?)])
                (unless (and (equal? status "101")
                             (equal? message "Web Socket Protocol Handshake"))
                  (error "server error: ~a (~a)" message status))))]
        [else
         (error "bad response header from server")])
      ;; Header block: every non-empty line, up to and including the blank line.
      (cond
        [(regexp-match #rx"^(?:[^\r\n]+\r\n)*\r\n" in)
         => (λ (groups)
              (let ([header (bytes->string/utf-8 (first groups) #\?)])
                ;; validate-header (net/head) rejects malformed header text.
                (validate-header header)
                ;; Connection and Upgrade must have exactly these values.
                (unless (and (equal? (extract-field "Connection" header)
                                     "Upgrade")
                             (equal? (extract-field "Upgrade" header)
                                     "WebSocket"))
                  (error "missing upgrade to WebSocket connection"))
                ;; WebSocket-Location must equal the dialed URL, rebuilt here
                ;; without user or fragment and without the port when that
                ;; port is the default.
                (unless (equal? (extract-field "WebSocket-Location" header)
                                (url->string
                                 (make-url
                                  (url-scheme url)
                                  #f
                                  host (and (not (= port default-port)) port)
                                  #t (url-path url) (url-query url) #f)))
                  (error "location from server does not match specified value"))
                ;; When a protocol was requested, the server must echo it.
                (unless (or (not protocol)
                            (equal? (extract-field "WebSocket-Protocol" header)
                                    protocol))
                  (error "protocol from server does not match specified value"))
                #f))]
        [else
         (error "bad response header from server")]))
    ;; Connects, then runs the handshake under a handler that closes both
    ;; ports and returns the exception. A handler's return value is passed on
    ;; to the enclosing exception handler, so the error still reaches the
    ;; caller once the ports are closed.
    (let-values ([(in out) (if secure?
                               (ssl-connect host port ssl-context)
                               (tcp-connect host port))])
      (call-with-exception-handler
       (λ (exn)
         (close-input-port in)
         (close-output-port out)
         exn)
       (λ ()
         (send-client-handshake out)
         (validate-server-handshake in)
         (make-ws-socket url protocol in out))))))

;; ws-connect takes a URL satisfying ws-url?, an optional protocol (a string
;; or #f), and an optional SSL argument (an SSL client context or a symbol,
;; which is handed to ssl-connect). It returns a ws-socket.
(provide/contract
 [ws-connect (->* (ws-url?)
                  ((or/c string? #f) (or/c ssl-client-context? symbol?))
                  ws-socket?)])

;; A listening WebSocket endpoint, built by ws-listen.
;;   ear       - the listener returned by ssl-listen or tcp-listen
;;   secure?   - #t when `ear` came from ssl-listen, otherwise #f
;;   protocols - immutable hash whose keys are the accepted subprotocol
;;               names (each mapped to #t)
(define-struct ws-listener
  (ear secure? protocols))

;; This provide form exports only the predicate.
(provide
 ws-listener?)

;; ws-listen : port [protocols] [ssl-context] [max-queue] [reuse?] [host]
;;             -> ws-listener
;; Starts listening on PORT, using SSL when an SSL server context is supplied
;; and plain TCP otherwise. MAX-QUEUE, REUSE?, and HOST are forwarded to the
;; underlying listen call. PROTOCOLS lists the subprotocol names that
;; ws-accept will allow; they are stored as keys of an immutable hash.
(define (ws-listen port [protocols '()] [ssl-context #f] [max-queue 4] [reuse? #f] [host #f])
  (let* ([ear (if ssl-context
                  (ssl-listen port max-queue reuse? host ssl-context)
                  (tcp-listen port max-queue reuse? host))]
         ;; Normalize the context argument to a plain boolean.
         [secure? (and ssl-context #t)]
         [allowed (make-immutable-hash
                   (map (λ (name) (cons name #t)) protocols))])
    (make-ws-listener ear secure? allowed)))

;; ws-accept : ws-listener -> ws-socket
;; Accepts one connection from LISTENER's underlying TCP or SSL listener. It
;; reads and checks the client's WebSocket handshake, writes the server
;; handshake in reply, and returns a ws-socket that wraps the connection's
;; ports. Handshake problems raise exn:fail:network with a message prefixed by
;; "ws-accept: ". In that case both ports are closed before the exception
;; continues to the enclosing handler.
(define (ws-accept listener)
  (let ([ear (ws-listener-ear listener)]
        [secure? (ws-listener-secure? listener)]
        [protocols (ws-listener-protocols listener)])
    ;; Writes the 101 status line followed by the Upgrade, Connection,
    ;; WebSocket-Origin, and WebSocket-Location fields, plus WebSocket-Protocol
    ;; when PROTOCOL is not #f, then flushes OUT. It returns URL and PROTOCOL
    ;; as two values, which become the first two arguments of make-ws-socket.
    ;; Assumes `nest` wraps each listed form around the ones after it, so that
    ;; the insert-field calls build on empty-header. Confirm against its source.
    (define (send-server-handshake out url origin protocol)
      (display
       (nest
        [(string-append "HTTP/1.1 101 Web Socket Protocol Handshake\r\n")
         (insert-field "Upgrade" "WebSocket")
         (insert-field "Connection" "Upgrade")
         (insert-field "WebSocket-Origin" origin)
         (insert-field "WebSocket-Location" (url->string url))
         ((if protocol
              (cut insert-field "WebSocket-Protocol" protocol <>)
              values))]
        empty-header)
       out)
      (flush-output out)
      (values url protocol))
    ;; Reads the client's request from IN and returns three values:
    ;;   - the URL to advertise as WebSocket-Location: scheme "wss" or "ws"
    ;;     according to secure?, host and optional port taken from the
    ;;     client's Host field, and path and query taken from the request line;
    ;;   - the client's Origin field;
    ;;   - the requested WebSocket-Protocol, or #f.
    ;; Any problem raises exn:fail:network.
    (define (validate-client-handshake in)
      ;; Local replacement for `error`: raises exn:fail:network with the
      ;; formatted message prefixed by "ws-accept: ".
      (define (error template . arguments)
        (raise (make-exn:fail:network
                (apply format (string-append "~s: " template) 'ws-accept arguments)
                (current-continuation-marks))))
      ;; Both regexp-match calls consume the matched bytes from IN.
      ;; NOTE(review): the "." in "1.1" is unescaped and so matches any byte.
      (let* ([url (cond
                    [(regexp-match #rx"^GET ([^ \r\n]+) HTTP/1.1\r\n" in)
                     => (λ (groups)
                          (string->url (bytes->string/utf-8 (second groups) #\?)))]
                    [else
                     (error "bad request header from client")])]
             [header (cond
                       [(regexp-match #rx"^(?:[^\r\n]+\r\n)*\r\n" in)
                        => (λ (groups)
                             (bytes->string/utf-8 (first groups) #\?))]
                       [else
                        (error "bad request header from client")])])
        ;; validate-header (net/head) rejects malformed header text.
        (validate-header header)
        (unless (and (equal? (extract-field "Connection" header)
                             "Upgrade")
                     (equal? (extract-field "Upgrade" header)
                             "WebSocket"))
          (error "missing upgrade to WebSocket connection"))
        ;; Host is parsed as "name" or "name:digits". Without a port, the
        ;; third group is #f and the advertised URL carries no port.
        ;; NOTE(review): Origin is echoed back without any allow-list check.
        (let ([host+port (or (regexp-match #rx"^([^:]+)(?::([0-9]+))?$"
                                           (or (extract-field "Host" header)
                                               (error "client specified no host")))
                             (error "client specified invalid host"))]
              [origin (or (extract-field "Origin" header)
                          (error "client specified no origin"))]
              [protocol (extract-field "WebSocket-Protocol" header)])
          ;; A requested protocol must be one the listener was created with.
          (unless (or (not protocol) (hash-ref protocols protocol #f))
            (error "client requested unsupported protocol ~e" protocol))
          (values
           (make-url
            (if secure? "wss" "ws")
            #f
            (second host+port) (cond [(third host+port) => string->number] [else #f])
            #t (url-path url) (url-query url) #f)
           origin
           protocol))))
    ;; Accepts one connection, then runs the pipeline validate -> reply ->
    ;; construct as a single thunk. compose forwards the multiple values from
    ;; each stage to the next. The handler closes both ports and returns the
    ;; exception, which passes it on to the enclosing handler.
    (let-values ([(in out) ((if secure? ssl-accept tcp-accept) ear)])
      (call-with-exception-handler
       (λ (exn)
         (close-input-port in)
         (close-output-port out)
         exn)
       (compose
        (cut make-ws-socket <> <> in out)
        (cut send-server-handshake out <> <> <>)
        (cut validate-client-handshake in))))))

;; ws-listen takes a port number in 0..65535 and the following optional
;; arguments: accepted subprotocol names, an SSL server context (or #f for
;; plain TCP), the listen queue size, a reuse? flag (any value), and the host
;; to bind (or #f). It returns a ws-listener.
;; NOTE(review): whether port 0 is accepted depends on tcp-listen/ssl-listen.
;; ws-accept takes a ws-listener and returns a ws-socket for one connection.
(provide/contract
 [ws-listen (->* ((and/c exact-nonnegative-integer? (integer-in 0 65535)))
                 ((listof string?) (or/c ssl-server-context? #f) exact-nonnegative-integer? any/c (or/c string? #f))
                 ws-listener?)]
 [ws-accept (-> ws-listener? ws-socket?)])

;; ws-send* : output-port (or/c bytes? string?) -> void
;; Writes one frame to OUT and then flushes OUT.
;;  - A byte string becomes a binary frame. The frame is the type byte #x80,
;;    then the payload length as big-endian base-128 digits (every digit
;;    except the last has its high bit set), then the raw payload bytes.
;;  - A string becomes a text frame. The frame is #x00, then the string's
;;    UTF-8 encoding, then #xFF.
;;  - Any other value writes nothing; OUT is only flushed.
(define (ws-send* out frame)
  ;; Encodes a non-negative length as a list of base-128 digits, most
  ;; significant first, with the continuation bit #x80 on every digit except
  ;; the last. A length of zero still yields one digit (#x00). Emitting no
  ;; digits would make the receiver read the next frame's first byte as the
  ;; length.
  (define (length-digits len)
    (let loop ([rest (arithmetic-shift len -7)]
               [digits (list (bitwise-bit-field len 0 7))])
      (if (positive? rest)
          (loop (arithmetic-shift rest -7)
                (cons (bitwise-ior #b10000000 (bitwise-bit-field rest 0 7))
                      digits))
          digits)))
  (cond
    [(bytes? frame)
     (write-byte #b10000000 out)
     (for-each (λ (digit) (write-byte digit out))
               (length-digits (bytes-length frame)))
     (write-bytes frame out)]
    [(string? frame)
     (write-byte #b00000000 out)
     (display frame out)
     (write-byte #b11111111 out)])
  (flush-output out))

;; ws-send : ws-socket (or/c bytes? string?) -> void
;; Sends FRAME over SOCKET by handing its output port to ws-send*.
(define (ws-send socket frame)
  (let ([out (ws-socket-out socket)])
    (ws-send* out frame)))

;; ws-send-ready-evt : ws-socket -> evt
;; The socket's output port serves as the event. When the event is selected,
;; its result (the port) is replaced by SOCKET itself.
(define (ws-send-ready-evt socket)
  (let* ([out (ws-socket-out socket)]
         [->socket (λ (ready-port) socket)])
    (wrap-evt out ->socket)))

;; ws-send-evt : ws-socket (or/c bytes? string?) -> evt
;; Becomes ready with the socket's output port. When it is selected, the
;; wrapper writes FRAME to that port with ws-send*, and the write's result is
;; the event's result.
(define (ws-send-evt socket frame)
  (wrap-evt (ws-socket-out socket)
            (λ (ready-out)
              (ws-send* ready-out frame))))

;; ws-receive* : input-port -> (or/c bytes? string? eof-object?)
;; Reads one frame from IN.
;;  - EOF before the frame-type byte returns eof.
;;  - Type #x80 is a binary frame. The type byte is followed by a big-endian
;;    base-128 length (a high bit means more digits follow) and then that many
;;    payload bytes, which are returned as a byte string.
;;  - Type #x00 is a text frame. Its bytes run up to the #xFF terminator and
;;    are decoded as UTF-8; bytes->string/utf-8 raises on invalid UTF-8.
;; Raises exn:fail:network (with message prefix "ws-receive: ") for unknown
;; frame types and for frames that end before they are complete.
(define (ws-receive* in)
  ;; Local replacement for `error`: raises exn:fail:network. The constructor
  ;; needs both a message and continuation marks; omitting the marks turned
  ;; every protocol error into an arity error.
  (define (error template . args)
    (raise (make-exn:fail:network
            (apply format (string-append "~s: " template) 'ws-receive args)
            (current-continuation-marks))))
  ;; EOF inside a frame header is a broken frame, not a clean close.
  (define (read-length-digit)
    (let ([digit (read-byte in)])
      (if (eof-object? digit)
          (error "incoming frame seems to be broken")
          digit)))
  ;; Accumulates base-128 digits until one arrives without the #x80 bit.
  (define (read-frame-length)
    (let loop ([count 0])
      (let* ([bits (read-length-digit)]
             [count (bitwise-ior (arithmetic-shift count 7)
                                 (bitwise-bit-field bits 0 7))])
        (if (bitwise-bit-set? bits 7)
            (loop count)
            count))))
  (let ([frame-type (read-byte in)])
    (cond
      [(eof-object? frame-type)
       eof]
      [(bitwise-bit-set? frame-type 7)
       (unless (zero? (bitwise-bit-field frame-type 0 7))
         (error "unknown binary frame type ~a" frame-type))
       (let* ([len (read-frame-length)]
              [payload (read-bytes len in)])
         ;; read-bytes returns eof or a shorter string when the stream ends
         ;; early. Report that as a broken frame instead of returning a value
         ;; that looks like a clean close or a complete payload.
         (if (and (bytes? payload) (= (bytes-length payload) len))
             payload
             (error "incoming frame seems to be broken")))]
      [(not (zero? frame-type))
       (error "unknown textual frame type ~a" frame-type)]
      [(regexp-match #rx#"^([^\xff]*)\xff" in)
       => (λ (groups) (bytes->string/utf-8 (second groups)))]
      [else
       (error "incoming frame seems to be broken")])))

;; ws-receive : ws-socket -> (or/c bytes? string? eof-object?)
;; Reads the next frame from SOCKET by handing its input port to ws-receive*.
(define (ws-receive socket)
  (let ([in (ws-socket-in socket)])
    (ws-receive* in)))

;; ws-receive-ready-evt : ws-socket -> evt
;; The socket's input port serves as the event. When the event is selected,
;; its result (the port) is replaced by SOCKET itself.
(define (ws-receive-ready-evt socket)
  (let* ([in (ws-socket-in socket)]
         [->socket (λ (ready-port) socket)])
    (wrap-evt in ->socket)))

;; ws-receive-evt : ws-socket -> evt
;; Becomes ready with the socket's input port. When it is selected, the
;; wrapper reads one frame with ws-receive*, and that frame is the event's
;; result.
(define (ws-receive-evt socket)
  (wrap-evt (ws-socket-in socket)
            (λ (ready-in)
              (ws-receive* ready-in))))

;; The send and receive entry points. Frames are byte strings (binary frames)
;; or strings (text frames). ws-receive may also return eof. Each *-evt
;; function returns a synchronizable event built over the socket's port.
(provide/contract
 [ws-send (-> ws-socket? (or/c bytes? string?) any)]
 [ws-send-ready-evt (-> ws-socket? evt?)]
 [ws-send-evt (-> ws-socket? (or/c bytes? string?) evt?)]
 [ws-receive (-> ws-socket? (or/c bytes? string? eof-object?))]
 [ws-receive-ready-evt (-> ws-socket? evt?)]
 [ws-receive-evt (-> ws-socket? evt?)])

;; ws-close : (or/c ws-listener? ws-socket?) -> void
;; For a socket, closes its input port and then its output port. For a
;; listener, closes the underlying listener with ssl-close or tcp-close,
;; matching the way it was created. No closing frame is sent.
(define (ws-close ws)
  (cond
    [(ws-socket? ws)
     (close-input-port (ws-socket-in ws))
     (close-output-port (ws-socket-out ws))]
    [(ws-listener? ws)
     (let ([close-listener (if (ws-listener-secure? ws) ssl-close tcp-close)])
       (close-listener (ws-listener-ear ws)))]))

;; ws-close accepts either a listener or a connected socket.
(provide/contract
 [ws-close (-> (or/c ws-listener? ws-socket?) any)])