cipher.ss
;; mzcrypto: crypto library for mzscheme
;; Copyright (C) 2007 Dimitris Vyzovitis <vyzo@media.mit.edu>
;;
;; This library is free software; you can redistribute it and/or
;; modify it under the terms of the GNU Lesser General Public
;; License as published by the Free Software Foundation; either
;; version 2.1 of the License, or (at your option) any later version.
;;
;; This library is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;; Lesser General Public License for more details.
;;
;; You should have received a copy of the GNU Lesser General Public
;; License along with this library; if not, write to the Free Software
;; Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
;; USA

(module cipher mzscheme
  (require-for-syntax "stx-util.ss")
  (require (lib "foreign.ss")
           (lib "kw.ss")
           (lib "plt-match.ss"))
  (require "libcrypto.ss" "error.ss" "rand.ss" "util.ss")
  (provide (all-defined))
  
  ;; Allocation/release of EVP_CIPHER_CTX contexts, chosen once at load time.
  ;; libcrypto < 0.9.8.d doesn't have EVP_CIPHER_CTX_new/free, so:
  ;;  - when EVP_CIPHER_CTX_new is exported, bind the real new/free pair; the
  ;;    result of new is passed through pointer/error (error.ss), presumably
  ;;    raising on a NULL context -- confirm;
  ;;  - otherwise "new" allocates a 192-byte Scheme byte string to hold the
  ;;    context, and "free" runs EVP_CIPHER_CTX_cleanup on it, discarding the
  ;;    int result; the byte string itself is left to the garbage collector.
  ;; NOTE(review): the fallback assumes 192 bytes >= sizeof(EVP_CIPHER_CTX)
  ;; for the linked libcrypto, and that the zero fill of make-bytes stands in
  ;; for EVP_CIPHER_CTX_init -- confirm against the supported versions.
  (define-values (EVP_CIPHER_CTX_new EVP_CIPHER_CTX_free)
    (if (ffi-available? 'EVP_CIPHER_CTX_new)
        (values
          (lambda/ffi (EVP_CIPHER_CTX_new) -> _pointer : pointer/error)
          (lambda/ffi (EVP_CIPHER_CTX_free _pointer)))
        (values 
          (lambda () (make-bytes 192))
          (lambda/ffi (EVP_CIPHER_CTX_cleanup _pointer) -> _int : void))))

  ;; Raw EVP cipher calls. The int status of each call is passed to
  ;; check-error (error.ss), which presumably raises on failure.

  ;; EVP_CipherInit_ex(ctx, cipher, engine, key, iv, enc): the ENGINE argument
  ;; is always passed as #f (NULL); enc is #t to encrypt, #f to decrypt.
  (define/ffi (EVP_CipherInit_ex _pointer _pointer (_pointer = #f)
                                 _pointer _pointer _bool) 
    -> _int : check-error)
  ;; EVP_CipherUpdate(ctx, out, &outl, in, inl): returns outl, the number of
  ;; bytes written to out.
  (define/ffi (EVP_CipherUpdate _pointer _pointer 
                                (olen : (_ptr o _int)) _pointer _int)
    -> _int : (lambda (f r) (check-error f r) olen))
  ;; EVP_CipherFinal_ex(ctx, out, &outl): returns outl, the number of final
  ;; bytes written to out.
  (define/ffi (EVP_CipherFinal_ex _pointer _pointer (olen : (_ptr o _int)))
    -> _int : (lambda (f r) (check-error f r) olen))
  ;; Turn block padding on (#t) or off (#f) for ctx.
  (define/ffi (EVP_CIPHER_CTX_set_padding _pointer _bool) 
    -> _int : check-error)

  ;; cipher:algo -- a cipher algorithm exported by libcrypto (see define-cipher):
  ;;   evp    - pointer to the libcrypto EVP_CIPHER
  ;;   size   - block size (EVP_CIPHER block_size)
  ;;   keylen - key length (EVP_CIPHER key_len)
  ;;   ivlen  - IV length, or #f when no iv (0 in the cipher)
  (define-struct cipher:algo (evp size keylen ivlen))
  ;; cipher -- an encryption/decryption in progress (see cipher-init):
  ;;   algo     - the cipher:algo in use
  ;;   ctx      - the EVP_CIPHER_CTX, or #f once cipher-final has run
  ;;   olen     - algo's block size: the extra output room an update may need
  ;;   encrypt? - #t when encrypting, #f when decrypting
  (define-struct cipher (algo ctx olen encrypt?))

  ;; Fresh keying material for algo. Returns (values key iv):
  ;;   key - (cipher:algo-keylen algo) bytes from random-bytes
  ;;   iv  - (cipher:algo-ivlen algo) bytes from pseudo-random-bytes, or #f
  ;;         when the algorithm takes no IV
  (define (generate-cipher-key algo)
    (let* ((key (random-bytes (cipher:algo-keylen algo)))
           (iv-length (cipher:algo-ivlen algo))
           (iv (if iv-length (pseudo-random-bytes iv-length) #f)))
      (values key iv)))

  ;; Build a cipher instance for algo. key and iv (iv may be #f) are passed
  ;; straight to EVP_CipherInit_ex; enc? selects encryption (#t) or
  ;; decryption (#f), and pad? turns block padding on or off.
  ;; let/error-fini (util.ss) presumably frees the new context if a setup
  ;; step raises. The returned cipher carries a finalizer that frees its
  ;; context if cipher-final never ran on it.
  (define (cipher-init algo key iv enc? pad?)
    (let/error-fini ((ctx (EVP_CIPHER_CTX_new) EVP_CIPHER_CTX_free))
      (EVP_CipherInit_ex ctx (cipher:algo-evp algo) key iv enc?)
      (EVP_CIPHER_CTX_set_padding ctx pad?)
      (let ((instance (make-cipher algo ctx (cipher:algo-size algo) enc?))
            (release (lambda (dead)
                       ;; cipher-final clears the ctx field after freeing it
                       (let ((live (cipher-ctx dead)))
                         (when live (EVP_CIPHER_CTX_free live))))))
        (register-finalizer instance release)
        instance)))

  ;; Low-level update: run ilen input bytes at ibs through c's context,
  ;; writing at obs, and return the number of bytes written. No range checks
  ;; here -- the caller (cipher-update!) guarantees obs has room for
  ;; ilen + (cipher-olen c) bytes. Raises once c has been finalized.
  (define (cipher-update c obs ibs ilen)
    (let ((ctx (cipher-ctx c)))
      (if ctx
          (EVP_CipherUpdate ctx obs ibs ilen)
          (error 'cipher-update "finalized context"))))
  
  ;; Low-level finish: write c's final output at obs and return its length,
  ;; then free the context and mark c finalized (ctx = #f). obs must have
  ;; room for (cipher-olen c) bytes. Raises if c is already finalized.
  (define (cipher-final c obs)
    (let ((ctx (cipher-ctx c)))
      (if (not ctx)
          (error 'cipher-final "finalized context")
          (let ((written (EVP_CipherFinal_ex ctx obs)))
            (EVP_CIPHER_CTX_free ctx)
            (set-cipher-ctx! c #f)
            written))))

  ;; Validate key/iv against algo, then create the cipher via cipher-init.
  ;; Raises 'cipher-new "bad key" when key is shorter than the algorithm's
  ;; key length, and "bad iv" when the algorithm uses an IV and iv is #f or
  ;; too short. For algorithms without an IV, iv is ignored (#f is passed on).
  (define (cipher-new algo key iv enc? pad?)
    (when (< (bytes-length key) (cipher:algo-keylen algo))
      (error 'cipher-new "bad key"))
    (let ((need-iv (cipher:algo-ivlen algo)))
      (when (and need-iv
                 (not (and iv (>= (bytes-length iv) need-iv))))
        (error 'cipher-new "bad iv"))
      (cipher-init algo key (and need-iv iv) enc? pad?)))

  ;; Output room cipher-update! demands for ilen input bytes: the input
  ;; length plus one block (cipher-olen).
  (define cipher-maxlen
    (lambda (c ilen)
      (+ ilen (cipher-olen c))))

  ;; api

  ;; (cipher-encrypt algo key iv [#:padding pad?]) -> cipher
  ;; Start an encryption with algo. key must be at least the algorithm's key
  ;; length; iv must be at least its IV length when the algorithm uses one
  ;; and is ignored otherwise (checked by cipher-new). Padding defaults to #t.
  (define/kw (cipher-encrypt algo key iv #:key (pad? #:padding #t))
    (cipher-new algo key iv #t pad?))
  
  ;; (cipher-decrypt algo key iv [#:padding pad?]) -> cipher
  ;; As cipher-encrypt, but starts a decryption.
  (define/kw (cipher-decrypt algo key iv #:key (pad? #:padding #t))
    (cipher-new algo key iv #f pad?))
  
  ;; Feed input bytes to cipher c. Every form returns (values obs n), where n
  ;; is the number of bytes written into obs:
  ;;   (cipher-update! c ibs)          -- obs is freshly allocated
  ;;   (cipher-update! c ibs obs)      -- output starts at obs[0]
  ;;   (cipher-update! c ibs obs istart iend ostart oend)
  ;;                                   -- processes ibs[istart, iend) into
  ;;                                      obs starting at ostart
  ;; The output area must hold the input length plus one block
  ;; (cipher-maxlen); the check-*-range helpers enforce the bounds.
  (define cipher-update!
    (case-lambda
      ((c ibs)
       (let* ((n-in (bytes-length ibs))
              (out (make-bytes (cipher-maxlen c n-in))))
         (cipher-update! c ibs out)))
      ((c ibs obs)
       (let ((n-in (bytes-length ibs)))
         (check-output-range 'cipher-update obs (cipher-maxlen c n-in))
         (values obs (cipher-update c obs ibs n-in))))
      ((c ibs obs istart iend ostart oend)
       ;; validate the input range before doing arithmetic on it
       (check-input-range 'cipher-update ibs istart iend)
       (let ((n-in (- iend istart)))
         (check-output-range 'cipher-update obs ostart oend
                             (cipher-maxlen c n-in))
         (values obs (cipher-update c
                                    (ptr-add obs ostart)
                                    (ptr-add ibs istart)
                                    n-in))))))
  
  ;; Finish cipher c. Every form returns (values obs n), where n is the number
  ;; of final bytes written into obs:
  ;;   (cipher-final! c)                  -- obs is freshly allocated
  ;;   (cipher-final! c obs)              -- output starts at obs[0]
  ;;   (cipher-final! c obs ostart oend)  -- output starts at obs[ostart]
  ;; The output area must hold one block (cipher-olen). After this call, c is
  ;; finalized and further updates raise.
  (define cipher-final!
    (case-lambda
      ((c)
       (let ((out (make-bytes (cipher-olen c))))
         (cipher-final! c out)))
      ((c obs)
       (check-output-range 'cipher-final obs (cipher-olen c))
       (let ((n (cipher-final c obs)))
         (values obs n)))
      ((c obs ostart oend)
       (check-output-range 'cipher-final obs ostart oend (cipher-olen c))
       (let ((n (cipher-final c (ptr-add obs ostart))))
         (values obs n)))))
  
  ;; (define-cipher-prop name accessor) defines (name x), where x is either a
  ;; cipher:algo or a cipher instance. For an instance, accessor is applied
  ;; to its algorithm. Any other x raises a type error naming `name`.
  (define-syntax define-cipher-prop
    (syntax-rules ()
      ((_ prop op)
       (define (prop obj)
         (if (cipher? obj)
             (op (cipher-algo obj))
             (if (cipher:algo? obj)
                 (op obj)
                 (raise-type-error 'prop "cipher or cipher algorithm" obj)))))))
  
  (define-cipher-prop cipher-block-size cipher:algo-size)
  (define-cipher-prop cipher-key-length cipher:algo-keylen)
  ;; #f for algorithms that take no IV
  (define-cipher-prop cipher-iv-length cipher:algo-ivlen)
  
  ;; Stream all of input port inp through cipher (an instance from
  ;; cipher-encrypt/cipher-decrypt) and write the result to output port outp.
  ;; When inp reaches end-of-file, the cipher is finalized with cipher-final!
  ;; and its last bytes are written. Returns void; neither port is closed here.
  ;;
  ;; Input is read in chunks of up to 4096 bytes (never less than one block)
  ;; rather than one block per read. The block size can be 1 (OpenSSL's
  ;; CFB/OFB modes), and reading one block at a time then costs a port read
  ;; and an FFI call per byte. The output buffer holds chunk + block bytes:
  ;; cipher-update! requires room for the input length plus one block
  ;; (cipher-maxlen), and cipher-final! requires one block.
  (define (cipher-port cipher inp outp)
    (let* ((block (cipher-block-size cipher))
           (chunk (max block 4096))
           (ibuf (make-bytes chunk))
           (osize (+ chunk block))
           (obuf (make-bytes osize)))
      (let loop ((icount (read-bytes-avail! ibuf inp)))
        (if (eof-object? icount)
            (let-values (((bs ocount) (cipher-final! cipher obuf)))
              (void (write-bytes obuf outp 0 ocount)))
            (let-values (((bs ocount)
                          (cipher-update! cipher ibuf obuf 0 icount 0 osize)))
              (write-bytes obuf outp 0 ocount)
              (loop (read-bytes-avail! ibuf inp)))))))

  ;; (define/cipher-port op init) defines the procedure `op`, a port-level
  ;; front end to `init`. init is a cipher constructor taking (algo key iv),
  ;; here cipher-encrypt or cipher-decrypt. `op` accepts:
  ;;   (op algo key iv)
  ;;     -> (values rd wr). Bytes written to output port wr are processed by
  ;;        a background thread and can be read from input port rd. Closing
  ;;        wr finalizes the cipher and ends rd.
  ;;   (op algo key iv inp)
  ;;     inp bytes      -> the processed bytes, computed synchronously.
  ;;     inp input-port -> an input port yielding the processed stream,
  ;;                       filled by a background thread.
  ;;   (op algo key iv inp outp)
  ;;     -> processes inp (bytes or input port) into output port outp
  ;;        synchronously; outp is left open.
  ;; A wrong argument type raises a type error naming `op`.
  ;;
  ;; The background threads close the pipe end they write to even when the
  ;; cipher raises (e.g. a padding check failing in cipher-final!). The
  ;; reader then sees end-of-file instead of blocking forever, although the
  ;; data read up to that point may be incomplete. The exception itself
  ;; still propagates out of the worker thread.
  (define-syntax define/cipher-port
    (syntax-rules ()
      ((_ op init)
       (define op
         (case-lambda
           ((algo key iv)
            (let-values (((cipher) (init algo key iv))
                         ((rd1 wr1) (make-pipe))
                         ((rd2 wr2) (make-pipe)))
              (thread (lambda ()
                        (dynamic-wind
                          void
                          (lambda ()
                            (cipher-port cipher rd1 wr2)
                            (close-input-port rd1))
                          (lambda ()
                            ;; always end the reader's stream, even on error
                            (close-output-port wr2)))))
              (values rd2 wr1)))
           ((algo key iv inp)
            (cond 
             ((bytes? inp) 
              (let ((outp (open-output-bytes)))
                (cipher-port (init algo key iv) (open-input-bytes inp) outp)
                (get-output-bytes outp)))
             ((input-port? inp)
              (let-values (((cipher) (init algo key iv))
                           ((rd wr) (make-pipe)))
                (thread (lambda ()
                          (dynamic-wind
                            void
                            (lambda () (cipher-port cipher inp wr))
                            ;; always end the reader's stream, even on error
                            (lambda () (close-output-port wr)))))
                rd))
             (else (raise-type-error 'op "bytes or input-port" inp))))
           ((algo key iv inp outp)
            (unless (output-port? outp)
              (raise-type-error 'op "output-port" outp))
            (cond 
             ((bytes? inp)
              (cipher-port (init algo key iv) (open-input-bytes inp) outp))
             ((input-port? inp)
              (cipher-port (init algo key iv) inp outp))
             (else (raise-type-error 'op "bytes or input-port" inp)))))))))

  (define/cipher-port encrypt cipher-encrypt)
  (define/cipher-port decrypt cipher-decrypt)

  ;; Read the leading int fields of an EVP_CIPHER:
  ;;   struct evp_cipher_st {nid block_size key_len iv_len ...}
  ;; and return (values block-size key-length iv-length). An IV length of 0
  ;; is reported as #f.
  (define (c->props evp)
    (let* ((fields (ptr-ref evp (_list-struct _int _int _int _int)))
           (iv-len (list-ref fields 3)))
      (values (list-ref fields 1)
              (list-ref fields 2)
              (if (> iv-len 0) iv-len #f))))

  ;; Names of the cipher families whose default-mode variant libcrypto
  ;; provides. Each define-cipher expansion pushes onto this list when its
  ;; cipher:<name> alias is available.
  (define *ciphers* '())
  (define available-ciphers
    (lambda () *ciphers*))

  ;; Compile-time configuration for define-cipher: the modes generated for
  ;; every cipher, and the mode that the plain cipher:<name> binding aliases.
  (define-for-syntax cipher-modes '(ecb cbc cfb ofb))
  (define-for-syntax default-mode 'cbc)
   
  ;; (define-cipher name)          e.g. (define-cipher des)
  ;; (define-cipher name (k ...))  e.g. (define-cipher aes (128 192 256))
  ;;
  ;; For each variant V (name itself, or name-k for every k) and each mode M
  ;; in cipher-modes, defines cipher:V-M. Its value is a cipher:algo built
  ;; from libcrypto's EVP_V_M() (hyphens become underscores), or #f when
  ;; libcrypto does not export that function. Also defines, per variant:
  ;;   cipher:V          -- the cipher:V-<default-mode> value; when it is not
  ;;                        #f, V is pushed onto *ciphers*
  ;;   provide:cipher:V  -- a macro expanding to a provide of cipher:V and
  ;;                        all of its cipher:V-M bindings
  ;; In the (k ...) form, provide:cipher:name is also defined; it expands
  ;; into (provide:cipher:name-k) for every k.
  ;; make-symbol and ->stx come from stx-util.ss. They presumably concatenate
  ;; their arguments into a symbol, and wrap a datum as syntax in stx's
  ;; context -- confirm there.
  (define-syntax (define-cipher stx)
    ;; "aes-128-cbc" -> "aes_128_cbc", the spelling of the C function name
    (define (unhyphen what) (regexp-replace* "-" what "_"))
    ;; Phase-1 helper (unrelated to the runtime struct constructor of the
    ;; same name): emit the definition of cipher:<mode> for a mode string
    ;; such as "aes-128-cbc".
    (define (make-cipher mode)
      (with-syntax
          ((evp (->stx stx (make-symbol "EVP_" (unhyphen mode))))
           (cipher (->stx stx (make-symbol "cipher:" mode))))
        #`(define cipher
            (if (ffi-available? 'evp)
                (let ((evpp ((lambda/ffi (evp) -> _pointer))))
                  (call/values 
                    (lambda () (c->props evpp))
                    (lambda (size keylen ivlen)
                      (make-cipher:algo evpp size keylen ivlen))))
                #f))))
    
    ;; Emit all mode definitions for one variant name, the cipher:<name>
    ;; alias of its default mode, and its provide:cipher:<name> macro.
    (define (make name)
      (with-syntax
          ((cipher (->stx stx (make-symbol "cipher:" name)))
           (alias (->stx stx (make-symbol "cipher:" name "-" default-mode)))
           (provider (->stx stx (make-symbol "provide:cipher:" name))))
        (let ((modes (map (lambda (mode) (format "~a-~a" name mode)) 
                          cipher-modes)))
          #`(begin
              #,@(map make-cipher modes)
              (define cipher 
                (begin (when alias (push! *ciphers* '#,(make-symbol name)))
                       alias))
              (define-syntax provider
                (syntax-rules ()
                  ((_) 
                   (provide cipher
                     #,@(map (lambda (mode) 
                               (->stx stx (make-symbol "cipher:" mode))) 
                             modes)))))))))

    ;; Emit provide:cipher:<name>, which expands into
    ;; (provide:cipher:<name>-<k>) for each k in ks.
    (define (make-meta-provider name ks)
      (let* ((base (make-symbol "provide:cipher:" name))
             (provs (map (lambda (k) (make-symbol base "-" k)) ks)))
        (with-syntax ((provider (->stx stx base)))
          #`(define-syntax provider
              (syntax-rules ()
                ((_)
                 (begin #,@(map (lambda (p) (->stx stx (list p))) provs))))))))
        
     (syntax-case stx ()
       ;; single variant, e.g. (define-cipher des)
       ((_ c) 
        (make (->datum #'c)))
       ;; key-length family, e.g. (define-cipher aes (128 192 256))
       ((_ c (klen ...))
        (let ((name (->string (->datum #'c)))
              (ks (map (lambda (k) (->datum k)) (syntax->list #'(klen ...)))))
          #`(begin
              #,@(map (lambda (k) (make (format "~a-~a" name k))) ks)
              #,(make-meta-provider name ks))))))

   ;; Cipher families defined by this module (see define-cipher). aes and
   ;; camellia come in key-size variants, e.g. cipher:aes-128-cbc.
   (define-cipher des)
   (define-cipher des-ede)
   (define-cipher des-ede3)
   (define-cipher idea)
   (define-cipher bf)
   (define-cipher cast5)
   (define-cipher aes (128 192 256))
   (define-cipher camellia (128 192 256))
  
   ;; (provide:cipher) expands into provide forms for the generic cipher API
   ;; and, via the per-family provide:cipher:<name> macros, every algorithm
   ;; binding. It is presumably meant for a module that requires this one
   ;; and re-exports its cipher interface.
   ;; NOTE(review): generate-cipher-key is not in this list -- presumably
   ;; exported through another module; confirm this is intentional.
   (define-syntax provide:cipher
     (syntax-rules ()
       ((_)
        (begin
          (provide available-ciphers cipher? cipher-encrypt?
                   cipher-block-size cipher-key-length cipher-iv-length
                   cipher-encrypt cipher-decrypt cipher-update! cipher-final!
                   encrypt decrypt)
          (provide:cipher:des)
          (provide:cipher:des-ede)
          (provide:cipher:des-ede3)
          (provide:cipher:idea)
          (provide:cipher:bf)
          (provide:cipher:cast5)
          (provide:cipher:aes)
          (provide:cipher:camellia)))))
   
)