;; digest.ss
;; mzcrypto: crypto library for mzscheme
;; Copyright (C) 2007 Dimitris Vyzovitis <vyzo@media.mit.edu>
;;
;; This library is free software; you can redistribute it and/or
;; modify it under the terms of the GNU Lesser General Public
;; License as published by the Free Software Foundation; either
;; version 2.1 of the License, or (at your option) any later version.
;;
;; This library is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;; Lesser General Public License for more details.
;;
;; You should have received a copy of the GNU Lesser General Public
;; License along with this library; if not, write to the Free Software
;; Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
;; USA

(module digest mzscheme
  (require-for-syntax "stx-util.ss")
  (require (lib "foreign.ss")
           (only (lib "list.ss" "srfi" "1") last))
  (require "libcrypto.ss" "error.ss")
  (provide (all-defined))
  
  ;; ------------------------------------------------------------------
  ;; Raw libcrypto bindings used by the digest/HMAC code below.
  ;; define/ffi comes from stx-util.ss / libcrypto.ss (not visible here):
  ;; presumably `(_pointer = #f)` is an auto-filled NULL argument and
  ;; `-> T : proc` post-processes the C result with proc -- TODO confirm.
  ;; pointer/error and check-error come from error.ss; assumed to raise
  ;; on a NULL pointer / failing status respectively -- TODO confirm.
  ;; ------------------------------------------------------------------

  ;; EVP_MD_CTX_create() -> fresh EVP_MD_CTX*.
  (define/ffi (EVP_MD_CTX_create) -> _pointer : pointer/error)
  ;; EVP_DigestInit_ex(ctx, md, engine); the ENGINE* is always NULL.
  (define/ffi (EVP_DigestInit_ex _pointer _pointer (_pointer = #f))
    -> _int : check-error)
  ;; EVP_DigestUpdate(ctx, data, len).
  (define/ffi (EVP_DigestUpdate _pointer _pointer _ulong)
    -> _int : check-error)
  ;; EVP_DigestFinal_ex(ctx, out, len*); the length out-parameter is NULL,
  ;; so callers size `out` from digest-size themselves.
  (define/ffi (EVP_DigestFinal_ex _pointer _pointer (_pointer = #f))
    -> _int : check-error)
  ;; EVP_MD_CTX_copy_ex(out, in): copy running state of `in` into `out`.
  (define/ffi (EVP_MD_CTX_copy_ex _pointer _pointer)
    -> _int : check-error)
  ;; Frees a context obtained from EVP_MD_CTX_create.
  (define/ffi (EVP_MD_CTX_destroy _pointer))
  ;; One-shot HMAC(md, key, keylen, data, datalen, out, outlen*).
  ;; The post-processor ignores the C return value and yields the
  ;; output length written through the `r` out-parameter.
  (define/ffi (HMAC _pointer _pointer _int _pointer _int 
                    _pointer (r : (_ptr o _uint)))
    -> _pointer : (lambda x r))
  ;; Incremental HMAC over a caller-allocated HMAC_CTX (see make-hmac-ctx).
  ;; NOTE(review): HMAC_CTX_init/cleanup were removed in OpenSSL 1.1.
  (define/ffi (HMAC_CTX_init _pointer))
  (define/ffi (HMAC_CTX_cleanup _pointer))
  ;; HMAC_Init_ex(ctx, key, keylen, md, engine); ENGINE* is NULL.
  (define/ffi (HMAC_Init_ex _pointer _pointer _uint _pointer (_pointer = #f)))
  ;; HMAC_Update(ctx, data, len).
  (define/ffi (HMAC_Update _pointer _pointer _uint))
  ;; HMAC_Final(ctx, out, outlen*); yields the output length from `r`.
  ;; NOTE(review): newer OpenSSL returns int here; declared _void.
  (define/ffi (HMAC_Final _pointer _pointer (r : (_ptr o _int)))
    -> _void : (lambda x r))

  ;; digest:algo -- descriptor for one digest algorithm.
  ;;   evp  : zero-argument FFI procedure returning the EVP_MD* pointer
  ;;          (it is called as `(evp)` by digest-new and hmac-bytes)
  ;;   size : digest output length in bytes (filled in by define-digest
  ;;          from md->size)
  (define-struct digest:algo (evp size)) 
  ;; digest -- an in-progress digest computation.
  ;;   type : the digest:algo this context was created for
  ;;   ctx  : EVP_MD_CTX* pointer, or #f once digest-final has run
  (define-struct digest (type ctx))
  
  ;; digest-size : (U digest:algo digest) -> integer
  ;; Output length in bytes.  A live digest context reports the size of
  ;; the algorithm it was created with; anything else is a type error.
  (define (digest-size o)
    (if (digest:algo? o)
        (digest:algo-size o)
        (if (digest? o)
            (digest:algo-size (digest-type o))
            (raise-type-error 'digest-size "digest or digest algorithm" o))))
  
  ;; digest-new : digest:algo -> digest
  ;; Allocate an EVP_MD_CTX, wrap it in a digest record whose finalizer
  ;; destroys the context if it was never finalized, and initialise it
  ;; for the given algorithm.
  (define (digest-new type)
    (let* ((md-getter (digest:algo-evp type))
           (ctx (EVP_MD_CTX_create))
           (dg (make-digest type ctx)))
      (register-finalizer dg
        (lambda (d)
          (let ((live (digest-ctx d)))
            (when live (EVP_MD_CTX_destroy live)))))
      (EVP_DigestInit_ex ctx (md-getter))
      dg))

  ;; digest-update : digest cpointer integer -> any
  ;; Feed LEN bytes at BS into the digest; errors if already finalized.
  (define (digest-update dg bs len)
    (let ((ctx (digest-ctx dg)))
      (if ctx
          (EVP_DigestUpdate ctx bs len)
          (error 'digest-update "finalized context"))))

  ;; digest-update! : digest bytes [start end] -> any
  ;; Feed all of DATA, or only the [START, END) slice after validating
  ;; the range, into the digest.
  (define digest-update!
    (case-lambda
      ((dg data start end)
       (check-input-range 'digest-update data start end)
       (let* ((from (ptr-add data start))
              (n (- end start)))
         (digest-update dg from n)))
      ((dg data)
       (let ((n (bytes-length data)))
         (digest-update dg data n)))))

  ;; digest-final : digest cpointer -> void
  ;; Write the final digest value to BS, then release the EVP context and
  ;; mark the digest as finalized (ctx := #f).  Errors if already final.
  (define (digest-final dg bs)
    (let ((ctx (digest-ctx dg)))
      (unless ctx
        (error 'digest-final "finalized context"))
      (EVP_DigestFinal_ex ctx bs)
      (EVP_MD_CTX_destroy ctx)
      (set-digest-ctx! dg #f)))

  ;; digest-final! : digest [bytes [start end]] -> (values bytes integer)
  ;; Finalize the digest into a fresh buffer, into BS, or into BS at
  ;; offset START (after range validation).  Returns the buffer and the
  ;; number of digest bytes written.
  (define digest-final!
    (case-lambda
      ((dg)
       (digest-final! dg (make-bytes (digest-size dg))))
      ((dg bs)
       (let ((n (digest-size dg)))
         (check-output-range 'digest-final bs n)
         (digest-final dg bs)
         (values bs n)))
      ((dg bs start end)
       (let ((n (digest-size dg)))
         (check-output-range 'digest-final bs start end n)
         (digest-final dg (ptr-add bs start))
         (values bs n)))))
  
  ;; digest-copy : digest -> digest
  ;; Clone an unfinalized digest: make a new context of the same type and
  ;; copy the running EVP state into it.  Errors if IDG was finalized.
  (define (digest-copy idg)
    (let ((ictx (digest-ctx idg)))
      (if (not ictx)
          (error 'digest-copy "finalized context")
          (let ((odg (digest-new (digest-type idg))))
            (EVP_MD_CTX_copy_ex (digest-ctx odg) ictx)
            odg))))

  ;; digest->bytes : digest -> bytes
  ;; Current digest value, computed on a copy so DG stays usable.
  (define (digest->bytes dg)
    (let ((snapshot (digest-copy dg)))
      (call-with-values (lambda () (digest-final! snapshot))
        (lambda (out len) out))))

  ;; port->digest : digest:algo input-port -> digest
  ;; Create a digest context for ALGO and feed it every byte read from INP
  ;; until end-of-file.  The context is returned unfinalized, so the
  ;; caller may keep updating it or call digest-final! on it.
  (define (port->digest algo inp)
    (let ((dg (digest-new algo))
          ;; Read in fixed 4 KiB chunks.  The buffer used to be sized to
          ;; the digest output (16-64 bytes), which meant one read and one
          ;; FFI update call per few bytes of input.
          (ibuf (make-bytes 4096)))
      (let lp ((count (read-bytes-avail! ibuf inp)))
        (if (eof-object? count)
            dg
            (begin
              (digest-update! dg ibuf 0 count)
              (lp (read-bytes-avail! ibuf inp)))))))

  ;; hash-port : digest:algo input-port -> bytes
  ;; Digest of everything remaining on INP.
  (define (hash-port algo inp)
    (call-with-values
        (lambda () (digest-final! (port->digest algo inp)))
      (lambda (out len) out)))
  
  ;; hash-bytes : digest:algo bytes -> bytes
  ;; Digest of the whole byte string BS.
  (define (hash-bytes algo bs)
    (define dg (digest-new algo))
    (digest-update! dg bs)
    (call-with-values (lambda () (digest-final! dg))
      (lambda (out len) out)))
  
  ;; hash : digest:algo (U bytes input-port) -> bytes
  ;; Dispatch to hash-bytes or hash-port based on the kind of input.
  (define (hash algo inp)
    (if (bytes? inp)
        (hash-bytes algo inp)
        (if (input-port? inp)
            (hash-port algo inp)
            (raise-type-error 'hash "bytes or input-port" inp))))

  ;; hmac-bytes : digest:algo bytes bytes -> bytes
  ;; One-shot HMAC of IBS under key KBS; the result is digest-size bytes.
  (define (hmac-bytes algo kbs ibs)
    (let* ((md (digest:algo-evp algo))
           (out (make-bytes (digest:algo-size algo))))
      (HMAC (md) kbs (bytes-length kbs) ibs (bytes-length ibs) out)
      out))

  ;; make-hmac-ctx : -> bytes
  ;; Allocate opaque storage for an HMAC_CTX and initialise it with
  ;; HMAC_CTX_init.  libcrypto (pre-1.1) exposes no HMAC_CTX constructor,
  ;; so the buffer must be at least sizeof(HMAC_CTX).  By our reading of
  ;; OpenSSL 1.0.x headers, that is about 288 bytes on LP64 (three 48-byte
  ;; EVP_MD_CTX plus an md pointer, key length and 128-byte key block).
  ;; The previous 256-byte buffer was therefore overrun.  512 leaves
  ;; headroom.
  ;; NOTE(review): this is a movable Scheme byte string handed to C; it
  ;; relies on HMAC_CTX holding no pointers into itself -- confirm for the
  ;; libcrypto version in use.
  (define (make-hmac-ctx)
    (let ((ctx (make-bytes 512 0)))
      (HMAC_CTX_init ctx)
      ctx))
  ;; Release resources held inside a context from make-hmac-ctx.
  (define cleanup-hmac-ctx HMAC_CTX_cleanup)

  ;; hmac-port : digest:algo bytes input-port -> bytes
  ;; HMAC, under key K, of everything remaining on INP.  The HMAC context
  ;; is cleaned up via let/fini.  Input is read in fixed 4 KiB chunks
  ;; rather than digest-size chunks, which previously forced one FFI
  ;; update per 16-64 bytes.  The result goes into a separate buffer of
  ;; exactly digest-size bytes.
  (define (hmac-port algo k inp)
    (let ((evp (digest:algo-evp algo))
          (ibuf (make-bytes 4096))
          (obuf (make-bytes (digest:algo-size algo))))
      (let/fini ((ctx (make-hmac-ctx) cleanup-hmac-ctx))
        (HMAC_Init_ex ctx k (bytes-length k) (evp))
        (let lp ((count (read-bytes-avail! ibuf inp)))
          (if (eof-object? count)
              (begin
                (HMAC_Final ctx obuf)
                obuf)
              (begin
                (HMAC_Update ctx ibuf count)
                (lp (read-bytes-avail! ibuf inp))))))))
  
  ;; hmac : digest:algo bytes (U bytes input-port) -> bytes
  ;; Dispatch to hmac-port or hmac-bytes based on the kind of input.
  (define (hmac algo key inp)
    (cond
      ((input-port? inp) (hmac-port algo key inp))
      ((bytes? inp) (hmac-bytes algo key inp))
      (else (raise-type-error 'hmac "bytes or input-port" inp))))

  ;; specifics
  ;; md->size : cpointer -> integer
  ;; Read md_size from an EVP_MD*, whose C layout begins
  ;;   struct evp_md_st { int type; int pkey_type; int md_size; ... }
  ;; md_size is the third int field.
  (define (md->size evp)
    (let ((fields (ptr-ref evp (_list-struct _int _int _int))))
      (caddr fields)))

  ;; Names of the digest algorithms that libcrypto actually provides;
  ;; define-digest adds each one as the module is instantiated.
  (define *digests* '())
  (define available-digests
    (lambda () *digests*))
  
  ;; (define-digest name) -- bind everything for one libcrypto digest.
  ;; For e.g. (define-digest md5) this defines, in the caller's context:
  ;;   EVP_md5             FFI binding returning the EVP_MD* pointer
  ;;   digest:md5:size     output size in bytes, or #f if calling EVP_md5
  ;;                       or reading its size raised exn:fail (presumably
  ;;                       the symbol is missing from this libcrypto)
  ;;   digest:md5          digest:algo record, or #f when size is #f
  ;;   md5                 (lambda (inp) (hash digest:md5 inp)), or #f
  ;;   provide:digest:md5  macro expanding to (provide digest:md5 md5)
  ;; and registers 'md5 in *digests* when the algorithm is available.
  ;; Identifier construction relies on stx-util.ss helpers (->string,
  ;; ->stx, make-symbol, ...) -- make-symbol presumably concatenates its
  ;; arguments; TODO confirm.
  ;; NOTE(review): `fdigest` (upcased name) and the `df? ...` pattern are
  ;; bound but never used in the expansion.
  (define-syntax (define-digest stx)
    (syntax-case stx ()
      ((_ digest df? ...)
       (let ((name (->string (->datum #'digest))))
         (with-syntax
             ((fdigest (->stx stx (->symbol (string-upcase name))))
              (evp (->stx stx (make-symbol "EVP_" name)))
              (size (->stx stx (make-symbol "digest:" name ":size")))
              (algo (->stx stx (make-symbol "digest:" name)))
              (provider (->stx stx (make-symbol "provide:digest:" name))))
         #`(begin
             (define/ffi (evp) -> _pointer : pointer/error)
             ;; Any exn:fail while probing the algorithm => unavailable (#f).
             (define size 
               (with-handlers* ((exn:fail? (lambda x #f)))
                 (md->size (evp))))
             (define algo (and size (make-digest:algo evp size)))
             (define digest (and algo (lambda (inp) (hash algo inp))))
             ;; 'digest here quotes the pattern variable, i.e. the name symbol.
             (when digest (push! *digests* 'digest))
             (define-syntax provider
               (syntax-rules ()
                 ((_) (provide algo digest))))))))))

  ;; Supported algorithms.  Each one is probed at module instantiation;
  ;; unavailable ones end up bound to #f and are left out of *digests*.
  ;; The order here fixes the order of available-digests (presumably
  ;; reversed, if push! prepends -- TODO confirm).
  (define-digest md5)
  (define-digest ripemd160)
  (define-digest dss1) ; OpenSSL's EVP_dss1: SHA-1 as used for DSA signatures
  (define-digest sha1)
  (define-digest sha224)
  (define-digest sha256)
  (define-digest sha384)
  (define-digest sha512)
 
  ;; public api
  ;; (provide:digest) expands to provide forms for the generic digest/HMAC
  ;; procedures plus, via each provide:digest:<name> macro generated by
  ;; define-digest, the digest:<name> algorithm and <name> hash procedure.
  (define-syntax provide:digest
    (syntax-rules ()
      ((_) 
       (begin
         (provide available-digests digest? digest-new digest-size
                  digest-update! digest-final! digest-copy digest->bytes
                  hash hmac)
         (provide:digest:md5)
         (provide:digest:dss1)
         (provide:digest:sha1)
         (provide:digest:sha224)
         (provide:digest:sha256)
         (provide:digest:sha384)
         (provide:digest:sha512)
         (provide:digest:ripemd160)))))
)