(module pkey mzscheme
(require-for-syntax "stx-util.ss")
(require (lib "foreign.ss")
(lib "plt-match.ss")
(lib "kw.ss")
(only (lib "etc.ss") compose)
(lib "and-let.ss" "srfi" "2"))
(require "libcrypto.ss" "error.ss" "util.ss" "digest.ss" "cipher.ss" "bn.ss")
(provide (all-defined))
;; Allocation helpers for the OpenSSL objects used in this module.
;; NOTE(review): define/alloc comes from a required helper module; it
;; presumably defines TYPE_new / TYPE_free (EVP_PKEY_new, RSA_new, RSA_free,
;; DSA_new, ... are all used below) -- confirm against its definition.
(define/alloc EVP_PKEY)
(define/alloc RSA)
(define/alloc DSA)
;; Basic EVP_PKEY bindings.  The procedure after ':' post-processes the C
;; return value (int/error, check-error, ...); these checkers come from the
;; required helper modules and presumably raise on an OpenSSL failure code.
;; EVP_PKEY_type is applied to the raw type field read by pk->type.
(define/ffi (EVP_PKEY_type _int) -> _int : int/error)
;; Output-buffer size for a key (wrapped by pkey-size).
(define/ffi (EVP_PKEY_size _pointer) -> _int : int/error)
;; Key size in bits (wrapped by pkey-bits).
(define/ffi (EVP_PKEY_bits _pointer) -> _int : int/error)
;; Attach a raw key of the given type id to an EVP_PKEY (used by evp->pkey).
(define/ffi (EVP_PKEY_assign _pointer _int _pointer) -> _int : check-error)
;; Attach an RSA / DSA object to an EVP_PKEY; used only to probe the EVP
;; type ids when building pkey:rsa and pkey:dsa.
(define/ffi (EVP_PKEY_set1_RSA _pointer _pointer) -> _int : check-error)
(define/ffi (EVP_PKEY_set1_DSA _pointer _pointer) -> _int : check-error)
;; EVP_SignFinal(ctx, sig, &count, pkey).  COUNT is an output argument
;; ((_ptr o _uint)); the result wrapper runs check-error on the return code
;; and then returns COUNT, the number of signature bytes written.
(define/ffi
(EVP_SignFinal _pointer _pointer (count : (_ptr o _uint)) _pointer)
-> _int : (lambda (f r) (check-error f r) count))
;; EVP_VerifyFinal(ctx, sig, siglen, pkey); result converted by bool/error.
(define/ffi (EVP_VerifyFinal _pointer _pointer _uint _pointer)
-> _int : bool/error)
;; Key comparison, used by pkey=?.
(define/ffi (EVP_PKEY_cmp _pointer _pointer) -> _int : bool/error)
;; (out in inlen pkey) -> number of bytes written to OUT.  Instantiated as
;; the public-key encrypt / private-key decrypt primitives by
;; define-pkey-crypt below.
(define/ffi (EVP_PKEY_encrypt _pointer _pointer _int _pointer)
-> _int : int/error*)
(define/ffi (EVP_PKEY_decrypt _pointer _pointer _int _pointer)
-> _int : int/error*)
;; RSA key generation: (rsa bits exponent-bignum).  The trailing argument
;; (presumably OpenSSL's BN_GENCB progress callback) is always NULL.
(define/ffi (RSA_generate_key_ex _pointer _int _pointer (_pointer = #f))
-> _int : check-error)
;; DSA domain-parameter generation for a given bit size; every argument
;; after the size is fixed to NULL or 0.
(define/ffi
(DSA_generate_parameters_ex _pointer _int
(_pointer = #f) (_int = 0) (_pointer = #f) (_pointer = #f)
(_pointer = #f))
-> _int : check-error)
;; Generate a DSA key pair from previously generated parameters.
(define/ffi (DSA_generate_key _pointer) -> _int : check-error)
;; A key algorithm descriptor.  TYPE is the EVP key type id (computed with
;; pk->type) and KEYGEN is the procedure generate-pkey applies to
;; (bits . args).
(define-struct pkey:algo (type keygen))
;; A key.  ALGO is its pkey:algo, EVP the underlying EVP_PKEY pointer, and
;; PRIVATE? whether it holds private-key material (checked before signing
;; and before private-key decryption).
(define-struct pkey (algo evp private?))
;; DER decoding: (d2i_XXX type buf len) -> EVP_PKEY pointer.  The optional
;; reuse argument is fixed to NULL, and BUF is passed through a
;; (_ptr i _pointer) pointer-to-pointer.  pointer/error presumably raises
;; when the C call returns NULL -- confirm in the helper module.
(define/ffi (d2i_PublicKey _int (_pointer = #f) (_ptr i _pointer) _long)
-> _pointer : pointer/error)
(define/ffi (d2i_PrivateKey _int (_pointer = #f) (_ptr i _pointer) _long)
-> _pointer : pointer/error)
;; DER encoding of a key into a caller-supplied buffer; returns the number
;; of bytes written.
(define/ffi (i2d_PublicKey _pointer (_ptr i _pointer)) -> _int : int/error)
(define/ffi (i2d_PrivateKey _pointer (_ptr i _pointer)) -> _int : int/error)
;; Length queries: the same C functions called with a NULL output pointer;
;; write-pkey uses these to size its output buffer.
(define i2d_PublicKey-length
(lambda/ffi (i2d_PublicKey _pointer (_pointer = #f))
-> _int : int/error))
(define i2d_PrivateKey-length
(lambda/ffi (i2d_PrivateKey _pointer (_pointer = #f))
-> _int : int/error))
;; Output-buffer size for PK as reported by EVP_PKEY_size; used below to
;; size signature and encryption buffers.
(define pkey-size
  (lambda (pk)
    (let ((evp (pkey-evp pk)))
      (EVP_PKEY_size evp))))
;; Key size of PK in bits, as reported by EVP_PKEY_bits.
(define pkey-bits
  (lambda (pk)
    (let ((evp (pkey-evp pk)))
      (EVP_PKEY_bits evp))))
;; #t when K1 compares equal (via EVP_PKEY_cmp) to K2 and to every key in
;; KLST, #f otherwise.  Comparison stops at the first mismatch.
(define (pkey=? k1 k2 . klst)
  (let ((evp1 (pkey-evp k1)))
    (and (andmap (lambda (other)
                   (EVP_PKEY_cmp evp1 (pkey-evp other)))
                 (cons k2 klst))
         #t)))
;; Decode ILEN bytes of DER data at IBS into a pkey of algorithm ALGO.
;; PUBLIC? selects d2i_PublicKey or d2i_PrivateKey.  The resulting pkey is
;; flagged private exactly when PUBLIC? is false, and its EVP_PKEY is
;; released with EVP_PKEY_free when the pkey is finalized.
(define (read-pkey algo public? ibs ilen)
  (let* ((decode (cond (public? d2i_PublicKey)
                       (else d2i_PrivateKey)))
         (key (make-pkey algo
                         (decode (pkey:algo-type algo) ibs ilen)
                         (not public?))))
    (register-finalizer key (lambda (k) (EVP_PKEY_free (pkey-evp k))))
    key))
;; Serialize PK to DER.  PUBLIC? selects the public-key encoding
;; (i2d_PublicKey) or the private-key encoding (i2d_PrivateKey).
;; ARGS is either
;;   ()          -> a fresh byte string of exactly the encoded length is used;
;;   (OBS OLEN)  -> the caller's buffer OBS, of which at most OLEN bytes may
;;                  be written; raises "buffer too small" when that is not
;;                  enough for the encoding.
;; Returns two values: the output buffer and the number of bytes written.
(define (write-pkey pk public? . args)
  (let*-values
      (((i2d i2d-len)
        (if public?
            (values i2d_PublicKey i2d_PublicKey-length)
            (values i2d_PrivateKey i2d_PrivateKey-length)))
       ((evp) (pkey-evp pk))
       ;; Encoded length, obtained from i2d with a NULL output pointer.
       ((need) (i2d-len evp))
       ((obs)
        (match args
          ((list obs olen)
           ;; OLEN is only the caller's claim.  When OBS is a byte string,
           ;; never trust it beyond (bytes-length OBS), or i2d could write
           ;; past the end of the buffer.  Raw cpointers are still accepted
           ;; and checked against OLEN alone, as before.
           (let ((capacity (if (bytes? obs)
                               (min olen (bytes-length obs))
                               olen)))
             (if (< capacity need)
                 (error 'write-pkey "buffer too small")
                 obs)))
          ((list)
           (make-bytes need)))))
    (values obs (i2d evp obs))))
;; (define-bytes->pkey NAME PUBLIC?) defines (NAME algo bs), which decodes
;; the whole byte string BS as a public (PUBLIC? = #t) or private key of
;; algorithm ALGO via read-pkey.
(define-syntax define-bytes->pkey
  (syntax-rules ()
    ((_ name public?)
     (define name
       (lambda (algo bs)
         (let ((len (bytes-length bs)))
           (read-pkey algo public? bs len)))))))
;; (define-pkey->bytes NAME PUBLIC?) defines (NAME pk), which serializes PK
;; with write-pkey into a fresh buffer and trims it to the written length
;; with shrink-bytes.
(define-syntax define-pkey->bytes
  (syntax-rules ()
    ((_ name public?)
     (define name
       (lambda (pk)
         (let-values (((buf len) (write-pkey pk public?)))
           (shrink-bytes buf len)))))))
;; Concrete DER converters:
;;   (private-key->bytes pk)   /  (bytes->private-key algo bs)
;;   (public-key->bytes pk)    /  (bytes->public-key algo bs)
(define-pkey->bytes private-key->bytes #f)
(define-bytes->pkey bytes->private-key #f)
(define-pkey->bytes public-key->bytes #t)
(define-bytes->pkey bytes->public-key #t)
;; Return a public-only version of PK.  A private key is round-tripped
;; through its public DER encoding; a key that is already public is
;; returned unchanged.
(define (pkey->public-key pk)
  (cond
    ((not (pkey-private? pk)) pk)
    (else
     (let ((der (public-key->bytes pk)))
       (bytes->public-key (pkey-algo pk) der)))))
;; Read the first int field of the EVP_PKEY struct at EVP and pass it
;; through EVP_PKEY_type.
;; NOTE(review): relies on the raw OpenSSL struct layout having the key
;; type as its first field -- confirm for the OpenSSL version in use.
(define (pk->type evp)
  (let* ((fields (ptr-ref evp (_list-struct _int)))
         (raw-type (car fields)))
    (EVP_PKEY_type raw-type)))
;; Attach the raw key PKP to the EVP_PKEY EVP using the type id of
;; algorithm PKT, then wrap EVP in a private pkey whose EVP_PKEY is freed on
;; finalization.
;; NOTE(review): callers do not free PKP on success, presumably because
;; EVP_PKEY_assign hands ownership of PKP to EVP -- confirm.
(define (evp->pkey evp pkt pkp)
  (let ((type (pkey:algo-type pkt)))
    (EVP_PKEY_assign evp type pkp))
  (let ((result (make-pkey pkt evp #t)))
    (register-finalizer result (lambda (p) (EVP_PKEY_free (pkey-evp p))))
    result))
;; Generate a new RSA pkey of BITS bits with public exponent EXPONENT
;; (keyword #:exponent, default 65537) and wrap it with evp->pkey.
;; NOTE(review): let/fini and let/error-fini come from a helper module; by
;; name, the exponent BIGNUM EP is presumably always freed, while RSAP and
;; EVP are freed only if an error escapes (on success they end up inside
;; the returned pkey) -- confirm.
(define/kw (rsa-keygen bits #:key (exponent 65537))
(let/fini ((ep (BN_new) BN_free))
(BN_add_word ep exponent)
(let/error-fini ((rsap (RSA_new) RSA_free)
(evp (EVP_PKEY_new) EVP_PKEY_free))
(RSA_generate_key_ex rsap bits ep)
(evp->pkey evp pkey:rsa rsap))))
;; RSA algorithm descriptor.  The EVP type id is discovered at load time:
;; a freshly allocated RSA object is attached to a temporary EVP_PKEY and
;; the type is read back with pk->type.  Both temporaries are handed to
;; let/fini with their free functions.  If any step raises exn:fail,
;; pkey:rsa is #f.
(define pkey:rsa
(with-handlers* ((exn:fail? (lambda x #f)))
(let/fini ((rsap (RSA_new) RSA_free)
(evp (EVP_PKEY_new) EVP_PKEY_free))
(EVP_PKEY_set1_RSA evp rsap)
(make-pkey:algo (pk->type evp) rsa-keygen))))
;; Generate a new DSA pkey: generate BITS-bit domain parameters, then a key
;; pair, and wrap the result with evp->pkey.
;; NOTE(review): let/error-fini presumably frees DSAP/EVP only if an error
;; escapes; on success they end up inside the returned pkey -- confirm.
(define (dsa-keygen bits)
(let/error-fini ((dsap (DSA_new) DSA_free)
(evp (EVP_PKEY_new) EVP_PKEY_free))
(DSA_generate_parameters_ex dsap bits)
(DSA_generate_key dsap)
(evp->pkey evp pkey:dsa dsap)))
;; DSA algorithm descriptor.  The EVP type id is discovered at load time:
;; a freshly allocated DSA object is attached to a temporary EVP_PKEY and
;; the type is read back with pk->type.  Both temporaries are handed to
;; let/fini with their free functions.  If any step raises exn:fail,
;; pkey:dsa is #f.
(define pkey:dsa
(with-handlers* ((exn:fail? (lambda x #f)))
(let/fini ((dsap (DSA_new) DSA_free)
(evp (EVP_PKEY_new) EVP_PKEY_free))
(EVP_PKEY_set1_DSA evp dsap)
(make-pkey:algo (pk->type evp) dsa-keygen))))
;; Generate a new key of algorithm ALGO with BITS bits by invoking the
;; algorithm's keygen procedure.  Any extra ARGS (e.g. rsa-keygen's
;; #:exponent keyword) are passed through unchanged.
(define generate-pkey
  (lambda (algo bits . args)
    (let ((keygen (pkey:algo-keygen algo)))
      (apply keygen bits args))))
;; Finish digest DG with private key PK via EVP_SignFinal, writing the
;; signature into BS, and return the signature length.
;; Errors if PK is not private, or if digest-ctx yields #f for DG
;; (reported as a finalized context).
(define (pkey-sign dg pk bs)
  (if (pkey-private? pk)
      (let ((ctx (digest-ctx dg)))
        (if ctx
            (EVP_SignFinal ctx bs (pkey-evp pk))
            (error 'pkey-sign "finalized context")))
      (error 'sign "not a private key")))
;; Check the LEN-byte signature at BS against digest DG using key PK
;; (EVP_VerifyFinal).  Errors if digest-ctx yields #f for DG (reported as a
;; finalized context).
(define (pkey-verify dg pk bs len)
  (let ((ctx (digest-ctx dg)))
    (unless ctx
      (error 'pkey-verify "finalized context"))
    (EVP_VerifyFinal ctx bs len (pkey-evp pk))))
;; Sign the data accumulated in digest DG with private key PK.
;;   (digest-sign dg pk)            use a fresh (pkey-size pk)-byte buffer
;;   (digest-sign dg pk bs)         write the signature into BS
;;   (digest-sign dg pk bs s e)     write it into BS starting at offset S
;; Output ranges are validated with check-output-range.  Returns two
;; values: the whole buffer and the signature length.
(define digest-sign
  (case-lambda
    ((dg pk)
     (let ((sig (make-bytes (pkey-size pk))))
       (digest-sign dg pk sig)))
    ((dg pk out)
     (check-output-range 'digest-sign out (pkey-size pk))
     (let ((n (pkey-sign dg pk out)))
       (values out n)))
    ((dg pk out start end)
     (check-output-range 'digest-sign out start end (pkey-size pk))
     (let ((n (pkey-sign dg pk (ptr-add out start))))
       (values out n)))))
;; Verify a signature against digest DG using key PK.
;;   (digest-verify dg pk sig)          the whole byte string SIG
;;   (digest-verify dg pk sig s e)      the slice [S, E) of SIG, validated
;;                                      with check-input-range
(define digest-verify
  (case-lambda
    ((dg pk sig)
     (let ((n (bytes-length sig)))
       (pkey-verify dg pk sig n)))
    ((dg pk sig start end)
     (check-input-range 'digest-verify sig start end)
     (let ((window (ptr-add sig start))
           (n (- end start)))
       (pkey-verify dg pk window n)))))
;; Hash BS with digest algorithm DGALGO and sign it with PK; returns the
;; signature as a byte string trimmed to its actual length.
(define (sign-bytes dgalgo pk bs)
  (let ((dg (digest-new dgalgo)))
    (digest-update! dg bs)
    (let-values (((sig len) (digest-sign dg pk)))
      (shrink-bytes sig len))))
;; Hash BS with digest algorithm DGALGO and check signature SIGBS with PK.
(define verify-bytes
  (lambda (dgalgo pk sigbs bs)
    (define dg (digest-new dgalgo))
    (digest-update! dg bs)
    (digest-verify dg pk sigbs)))
;; Hash everything read from port INP with DGALGO (via port->digest) and
;; sign it with PK; returns the trimmed signature bytes.
(define (sign-port dgalgo pk inp)
  (let-values (((sig len) (digest-sign (port->digest dgalgo inp) pk)))
    (shrink-bytes sig len)))
;; Hash port INP with DGALGO (via port->digest) and check signature SIGBS
;; with PK.
(define (verify-port dgalgo pk sigbs inp)
  (let ((dg (port->digest dgalgo inp)))
    (digest-verify dg pk sigbs)))
;; Sign INP (a byte string or an input port) with PK using digest DGALGO.
;; DGALGO must be one of the digests registered for PK's algorithm.
(define (sign pk dgalgo inp)
  (if (not (pkey-digest? pk dgalgo))
      (error 'sign "incompatible digest type")
      (cond
        ((input-port? inp) (sign-port dgalgo pk inp))
        ((bytes? inp) (sign-bytes dgalgo pk inp))
        (else (raise-type-error 'sign "bytes or input-port" inp)))))
;; Check signature SIGBS over INP (a byte string or an input port) with PK
;; using digest DGALGO, which must be registered for PK's algorithm.
(define (verify pk dgalgo sigbs inp)
  (if (not (pkey-digest? pk dgalgo))
      (error 'verify "incompatible digest type")
      (cond
        ((input-port? inp) (verify-port dgalgo pk sigbs inp))
        ((bytes? inp) (verify-bytes dgalgo pk sigbs inp))
        (else (raise-type-error 'verify "bytes or input-port" inp)))))
;; (define-pkey-crypt CRYPT OP EVP-OP PUBLIC?) defines two procedures:
;;   (OP pk ibs ilen)  runs EVP-OP on ILEN bytes at IBS into a fresh buffer
;;                     of (pkey-size pk) bytes and returns it trimmed to the
;;                     length EVP-OP reports.  When PUBLIC? is #f, PK must
;;                     be a private key.
;;   (CRYPT pk ibs [istart iend])
;;                     user entry point: validates the input with
;;                     check-input-range, then calls OP on the whole byte
;;                     string or on the slice [ISTART, IEND).
;; syntax-rules also substitutes the pattern variable inside 'crypt, so
;; error messages name the generated CRYPT procedure.
(define-syntax define-pkey-crypt
(syntax-rules ()
((_ crypt op evp-op public?)
(begin
(define (op pk ibs ilen)
(unless (or public? (pkey-private? pk))
(error 'crypt "not a private key"))
(let* ((obs (make-bytes (pkey-size pk)))
(olen (evp-op obs ibs ilen (pkey-evp pk))))
(shrink-bytes obs olen)))
(define crypt
(case-lambda
((pk ibs)
(check-input-range 'crypt ibs (pkey-size pk))
(op pk ibs (bytes-length ibs)))
((pk ibs istart iend)
(check-input-range 'crypt ibs istart iend (pkey-size pk))
(op pk (ptr-add ibs istart) (- iend istart)))))))))
;; encrypt/pkey works with any key (PUBLIC? = #t); decrypt/pkey requires a
;; private key (PUBLIC? = #f).
(define-pkey-crypt encrypt/pkey pkey-encrypt EVP_PKEY_encrypt #t)
(define-pkey-crypt decrypt/pkey pkey-decrypt EVP_PKEY_decrypt #f)
;; Envelope encryption.  Generates a fresh key and IV for CIPHER, encrypts
;; the key with PK, then encrypts with the symmetric cipher (extra CARGS
;; are passed to encrypt).  Returns the PK-encrypted key, the IV, and then
;; every value returned by encrypt.
(define (encrypt/envelope pk cipher . cargs)
  (let-values (((key iv) (generate-cipher-key cipher)))
    (let ((sealed-key (encrypt/pkey pk key)))
      (call-with-values
          (lambda () (apply encrypt cipher key iv cargs))
        (lambda results (apply values sealed-key iv results))))))
;; Inverse of encrypt/envelope: recover the symmetric key from SK with the
;; private key PK, then decrypt with CIPHER and IV (extra CARGS are passed
;; to decrypt).
(define (decrypt/envelope pk cipher sk iv . cargs)
  (let ((key (decrypt/pkey pk sk)))
    (apply decrypt cipher key iv cargs)))
;; Registry of (algo . digest-list) pairs, one per algorithm declared with
;; define-pkey; consulted by pkey-digest?.
(define *pkey-digests* null)
;; #t when digest DGALGO is registered for PK, which may be a pkey (its
;; algorithm is used) or a pkey:algo itself; #f otherwise.  Any other kind
;; of PK raises a type error.
(define (pkey-digest? pk dgalgo)
  (cond
    ((pkey:algo? pk)
     (let ((entry (assq pk *pkey-digests*)))
       (if (and entry (memq dgalgo (cdr entry))) #t #f)))
    ((pkey? pk) (pkey-digest? (pkey-algo pk) dgalgo))
    (else (raise-type-error 'pkey-digest? "pkey or pkey algorithm" pk))))
;; (define-pkey NAME (DG ...)) expands to
;;   (define pkey:NAME:digests (list digest:DG ...))
;;   a macro (provide:pkey:NAME) expanding to
;;     (provide pkey:NAME pkey:NAME:digests)
;;   (push! *pkey-digests* (cons pkey:NAME pkey:NAME:digests))
;; which lets pkey-digest? look up the digests allowed for each algorithm.
;; make-symbol, ->stx and ->datum are compile-time helpers from
;; stx-util.ss; make-symbol presumably concatenates its arguments into one
;; symbol -- confirm there.
(define-syntax (define-pkey stx)
;; Build the syntax (list digest:DG ...) from a list of digest names.
(define (make-digests dgs)
(let ((dgts (map (lambda (dg) (make-symbol "digest:" dg)) dgs)))
(->stx stx (list* 'list dgts))))
(syntax-case stx ()
((_ pk (dg ...))
(let ((pkt (->datum #'pk))
(dgts (map ->datum (syntax->list #'(dg ...)))))
;; Identifiers are created in the context of STX so the generated
;; definitions are visible at the macro's use site.
(with-syntax
((algo (->stx stx (make-symbol "pkey:" pkt)))
(digests (->stx stx (make-symbol "pkey:" pkt ":digests")))
(provider (->stx stx (make-symbol "provide:pkey:" pkt))))
#`(begin
(define digests #,(make-digests dgts))
(define-syntax provider
(syntax-rules ()
((_) (provide algo digests))))
(push! *pkey-digests* (cons algo digests))))))))
;; Register the supported algorithms together with the digests each one
;; accepts for sign/verify.
(define-pkey rsa (ripemd160 sha1 sha224 sha256 sha384 sha512))
(define-pkey dsa (dss1))
;; (provide:pkey) exports the public pkey API of this module plus the
;; per-algorithm descriptors and digest lists, via the provide:pkey:rsa and
;; provide:pkey:dsa macros generated by define-pkey.
(define-syntax provide:pkey
(syntax-rules ()
((_)
(begin
(provide pkey? pkey-private? pkey-size pkey-bits pkey=? pkey-digest?
pkey->public-key public-key->bytes bytes->public-key
private-key->bytes bytes->private-key
digest-sign digest-verify
sign verify
encrypt/pkey decrypt/pkey
encrypt/envelope decrypt/envelope)
(provide:pkey:rsa)
(provide:pkey:dsa)))))
)