pkey.ss
;; mzcrypto: crypto library for mzscheme
;; Copyright (C) 2007 Dimitris Vyzovitis <vyzo@media.mit.edu>
;;
;; This library is free software; you can redistribute it and/or
;; modify it under the terms of the GNU Lesser General Public
;; License as published by the Free Software Foundation; either
;; version 2.1 of the License, or (at your option) any later version.
;;
;; This library is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;; Lesser General Public License for more details.
;;
;; You should have received a copy of the GNU Lesser General Public
;; License along with this library; if not, write to the Free Software
;; Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
;; USA

(module pkey mzscheme
  (require-for-syntax "stx-util.ss")
  (require (lib "foreign.ss")
           (lib "plt-match.ss")
           (lib "kw.ss")
           (only (lib "etc.ss") compose)
           (lib "and-let.ss" "srfi" "2"))
  (require "libcrypto.ss" "error.ss" "util.ss" "digest.ss" "cipher.ss" "bn.ss")
  (provide (all-defined))
  
  ;; Allocation helpers for the libcrypto objects used in this module.  The
  ;; code below calls EVP_PKEY_new/EVP_PKEY_free, RSA_new/RSA_free and
  ;; DSA_new/DSA_free, presumably the bindings that define/alloc generates
  ;; -- NOTE(review): confirm against define/alloc's definition.
  (define/alloc EVP_PKEY)
  (define/alloc RSA)
  (define/alloc DSA)
  
  ;; FFI bindings for the generic EVP_PKEY API.  The symbol after `:' is a
  ;; result-checking wrapper (int/error, check-error, bool/error, ...)
  ;; defined outside this module.
  (define/ffi (EVP_PKEY_type _int) -> _int : int/error)
  (define/ffi (EVP_PKEY_size _pointer) -> _int : int/error)
  (define/ffi (EVP_PKEY_bits  _pointer) -> _int : int/error)
  (define/ffi (EVP_PKEY_assign _pointer _int _pointer) -> _int : check-error)
  (define/ffi (EVP_PKEY_set1_RSA _pointer _pointer) -> _int : check-error)
  (define/ffi (EVP_PKEY_set1_DSA _pointer _pointer) -> _int : check-error)
  ;; EVP_SignFinal reports the signature length through an out-parameter
  ;; (COUNT).  The wrapper checks the status code and then returns COUNT,
  ;; so callers get the signature length.
  (define/ffi 
    (EVP_SignFinal _pointer _pointer (count : (_ptr o _uint)) _pointer)
    -> _int : (lambda (f r) (check-error f r) count))
  (define/ffi (EVP_VerifyFinal _pointer _pointer _uint _pointer)
    -> _int : bool/error)
  (define/ffi (EVP_PKEY_cmp _pointer _pointer) -> _int : bool/error)
  ;; Argument order is (out in inlen pkey); this appears to be the legacy
  ;; (pre-1.0) libcrypto signature -- NOTE(review): confirm for the
  ;; libcrypto version in use.
  (define/ffi (EVP_PKEY_encrypt _pointer _pointer _int _pointer)
    -> _int : int/error*)
  (define/ffi (EVP_PKEY_decrypt _pointer _pointer _int _pointer)
    -> _int : int/error*)
  
  ;; Key-generation bindings.  Arguments written `(type = value)' are fixed
  ;; by the binding and are not supplied by callers.  RSA_generate_key_ex is
  ;; therefore called as (rsap bits e), with a NULL callback.
  ;; DSA_generate_parameters_ex is called as (dsap bits), with the remaining
  ;; seed/length/counter/h/callback arguments fixed to NULL/0.
  (define/ffi (RSA_generate_key_ex _pointer _int _pointer (_pointer = #f))
    -> _int : check-error)
  (define/ffi 
    (DSA_generate_parameters_ex _pointer _int 
       (_pointer = #f) (_int = 0) (_pointer = #f) (_pointer = #f) 
       (_pointer = #f))
    -> _int : check-error)
  (define/ffi (DSA_generate_key _pointer) -> _int : check-error)

  ;; pkey:algo -- public-key algorithm descriptor:
  ;;   type   : EVP_PKEY type code (obtained with pk->type)
  ;;   keygen : procedure (bits . args) -> pkey, invoked by generate-pkey
  (define-struct pkey:algo (type keygen))
  ;; pkey -- key handle:
  ;;   algo     : the key's pkey:algo
  ;;   evp      : pointer to the underlying libcrypto EVP_PKEY
  ;;   private? : #t when the key carries private material; signing and
  ;;              private-key decryption require it
  (define-struct pkey (algo evp private?))
  
  ;; DER decoding: (d2i type ibs ilen) -> EVP_PKEY pointer.  The EVP_PKEY**
  ;; reuse argument is fixed to NULL.  IBS is passed by reference (_ptr i)
  ;; as libcrypto's `const unsigned char **' input cursor.
  (define/ffi (d2i_PublicKey _int (_pointer = #f) (_ptr i _pointer) _long)
    -> _pointer : pointer/error)
  (define/ffi (d2i_PrivateKey _int (_pointer = #f) (_ptr i _pointer) _long)
    -> _pointer : pointer/error)

  ;; DER encoding: (i2d evp obs) writes the encoding into OBS and returns
  ;; the number of bytes written.
  (define/ffi (i2d_PublicKey _pointer (_ptr i _pointer)) -> _int : int/error)
  (define/ffi (i2d_PrivateKey _pointer (_ptr i _pointer)) -> _int : int/error)

  ;; Length-only variants: the output pointer is fixed to NULL, which asks
  ;; i2d for the encoded size without writing anything.  write-pkey uses
  ;; these to size its output buffer.
  (define i2d_PublicKey-length
    (lambda/ffi (i2d_PublicKey _pointer (_pointer = #f)) 
       -> _int : int/error))
  (define i2d_PrivateKey-length
    (lambda/ffi (i2d_PrivateKey _pointer (_pointer = #f)) 
       -> _int : int/error))
  
  ;; Byte size reported by EVP_PKEY_size for PK's key.  This module uses it
  ;; to size signature and pkey-encryption output buffers.
  (define pkey-size
    (lambda (pk)
      (let ((evp (pkey-evp pk)))
        (EVP_PKEY_size evp))))
  
  ;; Key length in bits, as reported by EVP_PKEY_bits for PK's key.
  (define pkey-bits
    (lambda (pk)
      (let ((evp (pkey-evp pk)))
        (EVP_PKEY_bits evp))))

  ;; #t iff every key in K2 KLST ... compares equal to K1 under
  ;; EVP_PKEY_cmp.  Comparison stops at the first mismatch, which yields #f.
  (define (pkey=? k1 k2 . klst)
    (let ((ref (pkey-evp k1)))
      (and (andmap (lambda (k) (EVP_PKEY_cmp ref (pkey-evp k)))
                   (cons k2 klst))
           #t)))

  ;; Decode ILEN bytes of DER at IBS into a pkey of algorithm ALGO.
  ;; PUBLIC? selects d2i_PublicKey (result marked non-private) or
  ;; d2i_PrivateKey (result marked private).  The EVP_PKEY is freed by a
  ;; finalizer registered on the returned wrapper.
  (define (read-pkey algo public? ibs ilen)
    (let* ((decode (if public? d2i_PublicKey d2i_PrivateKey))
           (raw (decode (pkey:algo-type algo) ibs ilen))
           (key (make-pkey algo raw (not public?))))
      (register-finalizer key (lambda (k) (EVP_PKEY_free (pkey-evp k))))
      key))
  
  ;; Encode PK as DER.  PUBLIC? selects i2d_PublicKey or i2d_PrivateKey.
  ;; ARGS is either
  ;;   ()          -- allocate a fresh buffer of exactly the encoded size, or
  ;;   (obs olen)  -- write into the caller's buffer OBS of usable length OLEN.
  ;; Returns two values: the output buffer and the number of bytes i2d wrote.
  ;; Raises "buffer too small" if the supplied buffer cannot hold the
  ;; encoding.
  (define (write-pkey pk public? . args)
    (let*-values
        (((i2d i2d-len)
          (if public?
              (values i2d_PublicKey i2d_PublicKey-length)
              (values i2d_PrivateKey i2d_PrivateKey-length)))
         ((evp) (pkey-evp pk))
         ;; encoded size, queried once via the NULL-output i2d variant
         ((need) (i2d-len evp))
         ((obs)
          (match args
            ((list obs olen)
             ;; i2d writes NEED bytes unconditionally.  Check the claimed
             ;; length and, when OBS is a byte string, its real length too,
             ;; so that an overstated OLEN cannot cause an overrun.
             (if (or (< olen need)
                     (and (bytes? obs) (< (bytes-length obs) need)))
                 (error 'write-pkey "buffer too small")
                 obs))
            ((list)
             (make-bytes need)))))
      (values obs (i2d evp obs))))

  ;; (define-bytes->pkey name public?) defines (name algo bs).  It decodes the
  ;; whole byte string BS (DER, via read-pkey) as a public key when PUBLIC?
  ;; is #t, or as a private key when it is #f.
  (define-syntax define-bytes->pkey
    (syntax-rules ()
      ((_ name pub?)
       (define (name algo bs)
         (let ((len (bytes-length bs)))
           (read-pkey algo pub? bs len))))))
  
  ;; (define-pkey->bytes name public?) defines (name pk).  It encodes PK with
  ;; write-pkey into a freshly sized buffer and hands both results (buffer,
  ;; written length) to shrink-bytes.
  (define-syntax define-pkey->bytes
    (syntax-rules ()
      ((_ name pub?)
       (define (name pk)
         (call/values
           (lambda () (write-pkey pk pub?))
           shrink-bytes)))))
  
  ;; DER converters for both key kinds: X->bytes encodes, bytes->X decodes.
  (define-pkey->bytes public-key->bytes #t)
  (define-bytes->pkey bytes->public-key #t)
  (define-pkey->bytes private-key->bytes #f)
  (define-bytes->pkey bytes->private-key #f)
  ;; Return a public-only pkey for PK.  A key that is already non-private is
  ;; returned unchanged.  Otherwise its public half is round-tripped through
  ;; DER to obtain a separate public key.
  (define (pkey->public-key pk)
    (cond
     ((not (pkey-private? pk)) pk)
     (else
      (let ((der (public-key->bytes pk)))
        (bytes->public-key (pkey-algo pk) der)))))

  ;; The EVP_PKEY_* type codes are C #defines in the (auto-generated)
  ;; libcrypto headers, so they cannot be looked up through the FFI.
  ;; Instead, read the leading `type' field of an EVP_PKEY
  ;; (struct evp_pkey_st {int type; ...}) and pass it through EVP_PKEY_type.
  ;; NOTE(review): this relies on the public struct layout of the libcrypto
  ;; version this was written for; the struct is opaque in later releases.
  (define (pk->type evp)
    (let* ((head (ptr-ref evp (_list-struct _int)))
           (raw-type (car head)))
      (EVP_PKEY_type raw-type)))

  ;; Attach the algorithm-specific key PKP to the EVP_PKEY EVP with
  ;; EVP_PKEY_assign.  Then wrap EVP as a private pkey of algorithm PKT whose
  ;; finalizer frees EVP, which in turn releases PKP.
  (define (evp->pkey evp pkt pkp)
    (let ((type (pkey:algo-type pkt)))
      (EVP_PKEY_assign evp type pkp))
    (let ((key (make-pkey pkt evp #t)))
      (register-finalizer key (lambda (k) (EVP_PKEY_free (pkey-evp k))))
      key))

  ;; Generate an RSA key of BITS bits with public exponent EXPONENT
  ;; (default 65537).  Returns a private pkey of algorithm pkey:rsa.
  (define/kw (rsa-keygen bits #:key (exponent 65537))
    (let/fini ((e-bn (BN_new) BN_free))
      ;; load EXPONENT into the fresh bignum
      (BN_add_word e-bn exponent)
      (let/error-fini ((rsa-p (RSA_new) RSA_free)
                       (evp-p (EVP_PKEY_new) EVP_PKEY_free))
        (RSA_generate_key_ex rsa-p bits e-bn)
        ;; evp->pkey assigns rsa-p into evp-p, which then owns it
        (evp->pkey evp-p pkey:rsa rsa-p))))
 
  ;; RSA algorithm descriptor.  Its type code is discovered by wrapping a
  ;; temporary, empty RSA object in a temporary EVP_PKEY and reading the type
  ;; back.  If any step raises exn:fail, pkey:rsa is #f.
  (define pkey:rsa
    (with-handlers* ((exn:fail? (lambda (exn) #f)))
      (let/fini ((probe-rsa (RSA_new) RSA_free)
                 (probe-evp (EVP_PKEY_new) EVP_PKEY_free))
        (EVP_PKEY_set1_RSA probe-evp probe-rsa)
        (make-pkey:algo (pk->type probe-evp) rsa-keygen))))

  ;; Generate DSA parameters of BITS bits, then a key pair for them.
  ;; Returns a private pkey of algorithm pkey:dsa.
  (define (dsa-keygen bits)
    (let/error-fini ((dsa-p (DSA_new) DSA_free)
                     (evp-p (EVP_PKEY_new) EVP_PKEY_free))
      (DSA_generate_parameters_ex dsa-p bits)
      (DSA_generate_key dsa-p)
      ;; evp->pkey assigns dsa-p into evp-p, which then owns it
      (evp->pkey evp-p pkey:dsa dsa-p)))

  ;; DSA algorithm descriptor, discovered the same way as pkey:rsa:
  ;; a temporary DSA object is wrapped in a temporary EVP_PKEY and its type
  ;; is read back.  If any step raises exn:fail, pkey:dsa is #f.
  (define pkey:dsa
    (with-handlers* ((exn:fail? (lambda (exn) #f)))
      (let/fini ((probe-dsa (DSA_new) DSA_free)
                 (probe-evp (EVP_PKEY_new) EVP_PKEY_free))
        (EVP_PKEY_set1_DSA probe-evp probe-dsa)
        (make-pkey:algo (pk->type probe-evp) dsa-keygen))))

  ;; Generate a new key for ALGO by delegating to its keygen procedure with
  ;; BITS and any extra ARGS (e.g. rsa-keygen's #:exponent).
  (define (generate-pkey algo bits . args)
    (let ((keygen (pkey:algo-keygen algo)))
      (apply keygen bits args)))

  ;; Sign the data accumulated in DG's digest context with the private key
  ;; PK, writing the signature into BS.  Returns the signature length that
  ;; EVP_SignFinal reports.
  ;; Raises if PK is not private or if DG's context has been finalized.
  ;; Both errors are attributed to pkey-sign, because this procedure is
  ;; reached from digest-sign as well as from sign.
  (define (pkey-sign dg pk bs)
    (unless (pkey-private? pk)
      (error 'pkey-sign "not a private key"))
    (cond
     ((digest-ctx dg) =>
      (lambda (ctx)
        (EVP_SignFinal ctx bs (pkey-evp pk))))
     (else (error 'pkey-sign "finalized context"))))

  ;; Check the LEN-byte signature at BS against the data accumulated in DG,
  ;; using key PK.  Returns EVP_VerifyFinal's result as filtered by
  ;; bool/error.  Raises if DG's context has been finalized.
  (define (pkey-verify dg pk bs len)
    (let ((ctx (digest-ctx dg)))
      (if ctx
          (EVP_VerifyFinal ctx bs len (pkey-evp pk))
          (error 'pkey-verify "finalized context"))))

  ;; Sign digest DG with private key PK.
  ;;   (digest-sign dg pk)              -- into a new (pkey-size pk) buffer
  ;;   (digest-sign dg pk bs)           -- into BS
  ;;   (digest-sign dg pk bs start end) -- into BS from offset START
  ;; The output range is validated with check-output-range against
  ;; (pkey-size pk).  Returns two values: the buffer and the signature length.
  (define digest-sign
    (case-lambda
      ((dg pk)
       (let ((buf (make-bytes (pkey-size pk))))
         (digest-sign dg pk buf)))
      ((dg pk bs)
       (check-output-range 'digest-sign bs (pkey-size pk))
       (let ((n (pkey-sign dg pk bs)))
         (values bs n)))
      ((dg pk bs start end)
       (check-output-range 'digest-sign bs start end (pkey-size pk))
       (let* ((dst (ptr-add bs start))
              (n (pkey-sign dg pk dst)))
         (values bs n)))))
  
  ;; Verify the signature in BS (the whole string, or the [START, END)
  ;; slice) against digest DG using key PK.
  (define digest-verify
    (case-lambda
      ((dg pk bs)
       (let ((n (bytes-length bs)))
         (pkey-verify dg pk bs n)))
      ((dg pk bs start end)
       (check-input-range 'digest-verify bs start end)
       (let ((sig (ptr-add bs start))
             (n (- end start)))
         (pkey-verify dg pk sig n)))))
  
  ;; Digest BS with DGALGO and sign the result with PK.  The signature is
  ;; returned through shrink-bytes.
  (define (sign-bytes dgalgo pk bs)
    (define dg (digest-new dgalgo))
    (digest-update! dg bs)
    (call/values (lambda () (digest-sign dg pk)) shrink-bytes))
  
  ;; Digest BS with DGALGO and check the signature SIGBS against it with PK.
  (define (verify-bytes dgalgo pk sigbs bs)
    (define dg (digest-new dgalgo))
    (digest-update! dg bs)
    (digest-verify dg pk sigbs))

  ;; Digest the contents of input port INP with DGALGO and sign with PK.
  ;; The signature is returned through shrink-bytes.
  (define (sign-port dgalgo pk inp)
    (define dg (port->digest dgalgo inp))
    (call/values (lambda () (digest-sign dg pk)) shrink-bytes))
  
  ;; Digest the contents of input port INP with DGALGO and check SIGBS.
  (define (verify-port dgalgo pk sigbs inp)
    (let ((dg (port->digest dgalgo inp)))
      (digest-verify dg pk sigbs)))
  
  ;; Sign INP (a byte string or an input port) with PK using digest DGALGO.
  ;; DGALGO must be registered for PK's algorithm (see pkey-digest?).
  (define (sign pk dgalgo inp)
    (cond
     ((not (pkey-digest? pk dgalgo))
      (error 'sign "incompatible digest type"))
     ((bytes? inp) (sign-bytes dgalgo pk inp))
     ((input-port? inp) (sign-port dgalgo pk inp))
     (else (raise-type-error 'sign "bytes or input-port" inp))))

  ;; Verify signature SIGBS over INP (a byte string or an input port) with
  ;; PK and DGALGO.  DGALGO must be registered for PK's algorithm.
  (define (verify pk dgalgo sigbs inp)
    (cond
     ((not (pkey-digest? pk dgalgo))
      (error 'verify "incompatible digest type"))
     ((bytes? inp) (verify-bytes dgalgo pk sigbs inp))
     ((input-port? inp) (verify-port dgalgo pk sigbs inp))
     (else (raise-type-error 'verify "bytes or input-port" inp))))

  ;; (define-pkey-crypt crypt op evp-op public?) defines two procedures:
  ;;   (op pk ibs ilen) -- core: run EVP-OP on ILEN input bytes at IBS with
  ;;                       key PK into a fresh (pkey-size pk) buffer, and
  ;;                       return it via shrink-bytes with the length EVP-OP
  ;;                       reports.  PK must be private unless PUBLIC? is
  ;;                       true.
  ;;   crypt            -- user entry: (crypt pk ibs) or
  ;;                       (crypt pk ibs istart iend), input range-checked
  ;;                       against (pkey-size pk).
  (define-syntax define-pkey-crypt
    (syntax-rules ()
      ((_ crypt op evp-op public?)
       (begin
         (define (op key in in-len)
           (unless (or public? (pkey-private? key))
             (error 'crypt "not a private key"))
           (let ((out (make-bytes (pkey-size key))))
             (shrink-bytes out (evp-op out in in-len (pkey-evp key)))))
         (define crypt
           (case-lambda
             ((key in)
              (check-input-range 'crypt in (pkey-size key))
              (op key in (bytes-length in)))
             ((key in start end)
              (check-input-range 'crypt in start end (pkey-size key))
              (op key (ptr-add in start) (- end start)))))))))
  
  ;; decrypt/pkey requires a private key; encrypt/pkey accepts any key.
  ;; pkey-decrypt/pkey-encrypt are the cores, which skip the range checks.
  (define-pkey-crypt decrypt/pkey pkey-decrypt EVP_PKEY_decrypt #f)
  (define-pkey-crypt encrypt/pkey pkey-encrypt EVP_PKEY_encrypt #t)
  
  ;; Envelope encryption.  Generate a session key and IV for CIPHER, seal the
  ;; key with PK, then encrypt with (encrypt cipher key iv . cargs).
  ;; Returns the sealed key, the IV, then whatever values encrypt returned.
  (define (encrypt/envelope pk cipher . cargs)
    (let-values (((key iv) (generate-cipher-key cipher)))
      (let ((sealed (encrypt/pkey pk key)))
        (call/values
          (lambda () (apply encrypt cipher key iv cargs))
          (lambda outs (apply values sealed iv outs))))))
  
  ;; Open an envelope made by encrypt/envelope.  SK is the sealed session
  ;; key, which is unsealed with the private key PK.  IV and CARGS are
  ;; passed on to decrypt.
  (define (decrypt/envelope pk cipher sk iv . cargs)
    (let ((key (decrypt/pkey pk sk)))
      (apply decrypt cipher key iv cargs)))

  ;; Registry of (algo . digest-list) pairs, populated by define-pkey.
  (define *pkey-digests* null)

  ;; #t iff DGALGO is registered for PK, which may be a pkey (its algorithm
  ;; is used) or a pkey:algo.  Any other PK raises a type error.
  (define (pkey-digest? pk dgalgo)
    (cond
     ((pkey? pk) (pkey-digest? (pkey-algo pk) dgalgo))
     ((pkey:algo? pk)
      (let ((entry (assq pk *pkey-digests*)))
        (and entry (memq dgalgo (cdr entry)) #t)))
     (else (raise-type-error 'pkey-digest? "pkey or pkey algorithm" pk))))

  ;; (define-pkey name (dg ...)) registers the digests usable with the key
  ;; algorithm pkey:NAME.  For example, (define-pkey rsa (sha1 sha256))
  ;; expands (roughly) to:
  ;;   (define pkey:rsa:digests (list digest:sha1 digest:sha256))
  ;;   (define-syntax provide:pkey:rsa
  ;;     (syntax-rules () ((_) (provide pkey:rsa pkey:rsa:digests))))
  ;;   (push! *pkey-digests* (cons pkey:rsa pkey:rsa:digests))
  ;; pkey:NAME is referenced at module-initialization time, so it must be
  ;; defined before the define-pkey form.  Names are assembled with
  ;; make-symbol and turned into identifiers with ->stx (from stx-util.ss);
  ;; presumably ->stx uses STX's lexical context so the generated names are
  ;; visible to the rest of this module -- TODO confirm in stx-util.ss.
  (define-syntax (define-pkey stx)
    ;; Build the syntax (list digest:dg ...) for the digest names DGS.
    (define (make-digests dgs)
      (let ((dgts (map (lambda (dg) (make-symbol "digest:" dg)) dgs)))
        (->stx stx (list* 'list dgts))))
    
    (syntax-case stx ()
      ((_ pk (dg ...))
       (let ((pkt (->datum #'pk))
             (dgts (map ->datum (syntax->list #'(dg ...)))))
         (with-syntax
             ((algo (->stx stx (make-symbol "pkey:" pkt)))
              (digests (->stx stx (make-symbol "pkey:" pkt ":digests")))
              (provider (->stx stx (make-symbol "provide:pkey:" pkt))))
           #`(begin
               (define digests #,(make-digests dgts))
               ;; (provide:pkey:NAME) exports the descriptor and its list
               (define-syntax provider
                 (syntax-rules ()
                   ((_) (provide algo digests))))
               (push! *pkey-digests* (cons algo digests))))))))
  
  ;; Digests accepted by sign/verify for each key algorithm.
  (define-pkey rsa (ripemd160 sha1 sha224 sha256 sha384 sha512))
  ;; libcrypto 0.9.8 can pair DSA keys only with the dss/dss1 digests, so
  ;; only dss1 is registered here.  The restriction was expected to be
  ;; lifted in 0.9.9.
  (define-pkey dsa (dss1))
  
  ;; (provide:pkey) expands into provide forms for the public pkey API: the
  ;; procedures listed below, plus (through provide:pkey:rsa and
  ;; provide:pkey:dsa, which define-pkey generates) the pkey:rsa/pkey:dsa
  ;; descriptors and their digest lists.  This module itself already uses
  ;; (provide (all-defined)), so this macro is presumably meant for a module
  ;; that requires this one and re-exports a curated interface.
  (define-syntax provide:pkey
    (syntax-rules ()
      ((_)
       (begin
         (provide pkey? pkey-private? pkey-size pkey-bits pkey=? pkey-digest?
                  pkey->public-key public-key->bytes bytes->public-key
                  private-key->bytes bytes->private-key
                  digest-sign digest-verify
                  sign verify
                  encrypt/pkey decrypt/pkey
                  encrypt/envelope decrypt/envelope)
         (provide:pkey:rsa)
         (provide:pkey:dsa)))))
  
)