hierarchy.ss
#lang scheme
(require
 srfi/26)

;; A hierarchy is an immutable value holding four immutable hash tables.
;; Each table maps a node to a hash whose keys are the related nodes
;; (derive stores #t as every value):
;;   parents     - the direct parents of a node
;;   children    - the direct children of a node
;;   ancestors   - every node reachable through one or more parent edges
;;   descendants - every node reachable through one or more child edges
(define-struct hierarchy (parents children ancestors descendants))

;; make-hierarchy* : -> hierarchy
;; Builds a hierarchy without any edges.  All four tables start as the
;; empty immutable hash.
(define make-hierarchy*
  (lambda ()
    (let ([none #hash()])
      (make-hierarchy none none none none))))

;; global-hierarchy : (parameter hierarchy)
;; The hierarchy consulted (and, for derive/underive, replaced) by the
;; forms of the API that take no explicit hierarchy argument.
;; It initially holds an empty hierarchy.
(define global-hierarchy
  (let ([initial (make-hierarchy*)])
    (make-parameter initial)))

;; Exports the hierarchy predicate, the empty-hierarchy constructor
;; (under the name make-hierarchy) and the global hierarchy parameter.
(provide/contract
 [hierarchy? (-> any/c boolean?)]
 [rename make-hierarchy* make-hierarchy (-> hierarchy?)]
 [global-hierarchy (parameter/c hierarchy?)])

;; Raised by derive and underive when a requested change is rejected.
;; CHILD and PARENT record the arguments of the rejected call.  A class
;; child is recorded after its class->interface conversion.
(define-struct (exn:fail:hierarchy exn:fail) (child parent) #:transparent)

;; Exports the exception struct.  The field list covers the message and
;; continuation-marks fields inherited from exn:fail as well as the
;; child/parent fields added here.
(provide/contract
 [struct (exn:fail:hierarchy exn:fail)
         ([message string?]
          [continuation-marks continuation-mark-set?]
          [child any/c]
          [parent any/c])])

;; parents : [hierarchy] any -> hash
;; Returns the direct parents of V (as hash keys) recorded in H.  When H is
;; omitted, the global hierarchy is used.  Returns an empty hash if V has
;; no recorded parents.
(define parents
  (case-lambda
    [(h v)
     (let ([table (hierarchy-parents h)])
       (hash-ref table v #hash()))]
    [(v)
     (parents (global-hierarchy) v)]))

;; children : [hierarchy] any -> hash
;; Returns the direct children of V (as hash keys) recorded in H.  When H
;; is omitted, the global hierarchy is used.  Returns an empty hash if V
;; has no recorded children.
(define children
  (case-lambda
    [(h v)
     (let ([table (hierarchy-children h)])
       (hash-ref table v #hash()))]
    [(v)
     (children (global-hierarchy) v)]))

;; ancestors : [hierarchy] any -> hash
;; Returns every ancestor of V (direct and indirect, as hash keys) recorded
;; in H.  When H is omitted, the global hierarchy is used.  Returns an
;; empty hash if V has none.
(define ancestors
  (case-lambda
    [(h v)
     (let ([table (hierarchy-ancestors h)])
       (hash-ref table v #hash()))]
    [(v)
     (ancestors (global-hierarchy) v)]))

;; descendants : [hierarchy] any -> hash
;; Returns every descendant of V (direct and indirect, as hash keys)
;; recorded in H.  When H is omitted, the global hierarchy is used.
;; Returns an empty hash if V has none.
(define descendants
  (case-lambda
    [(h v)
     (let ([table (hierarchy-descendants h)])
       (hash-ref table v #hash()))]
    [(v)
     (descendants (global-hierarchy) v)]))

;; derived? : [hierarchy] any any -> boolean
;; Decides whether CHILD derives from PARENT in H (or in the global
;; hierarchy when H is omitted).  Classes are first replaced by their
;; interfaces.  CHILD derives from PARENT when any of these holds:
;;   - the two values are equal?;
;;   - both are interfaces and CHILD extends PARENT;
;;   - PARENT is recorded among CHILD's ancestors in H;
;;   - CHILD is an interface extending some interface recorded among
;;     PARENT's descendants in H;
;;   - both are dicts and, for every key of PARENT, CHILD has that key with
;;     a value that (recursively, in the same hierarchy H) derives from
;;     PARENT's value.
(define derived?
  (case-lambda
    [(child parent)
     (derived? (global-hierarchy) child parent)]
    [(h child parent)
     (let ([parent (if (class? parent) (class->interface parent) parent)]
           [child (if (class? child) (class->interface child) child)])
       (or
        (equal? child parent)
        ;; interface-extension? accepts only interfaces, so both sides are
        ;; checked; otherwise a non-interface CHILD raises a contract error.
        (and (interface? parent)
             (interface? child)
             (interface-extension? child parent))
        (hash-ref (ancestors h child) parent #f)
        ;; Consult H, not the global hierarchy, for PARENT's descendants.
        (and (interface? child)
             (for/or ([candidate (in-hash-keys (descendants h parent))]
                      #:when (interface? candidate))
               (interface-extension? child candidate)))
        (and (dict? child) (dict? parent)
             (let/ec esc
               ;; A key of PARENT missing from CHILD means "not derived".
               (local [(define (fail)
                         (esc #f))]
                 (for/and ([(key value) (in-dict parent)])
                   ;; Recurse in the same hierarchy H.
                   (derived? h (dict-ref child key fail) value)))))))]))

;; hash-merge : hash hash -> hash
;; Returns an immutable hash containing every key of A and of B.  The
;; entries of the smaller table are inserted into the larger one, so only
;; the smaller table is iterated.  On a shared key the smaller table's
;; value wins (B's when the sizes are equal).  The callers visible in this
;; module only ever store #t, so that choice does not matter to them.
(define (hash-merge a b)
  (define-values (big small)
    (if (< (hash-count a) (hash-count b))
        (values b a)
        (values a b)))
  (for/fold ([merged big])
            ([(k v) (in-hash small)])
    (hash-set merged k v)))

;; derive : [hierarchy] any any -> hierarchy or void
;; Records a direct edge saying CHILD derives from PARENT.
;;   (derive child parent)   replaces the value of the global-hierarchy
;;                           parameter with the extended hierarchy.
;;   (derive h child parent) returns a new hierarchy; H is not modified.
;; A class CHILD is first converted with class->interface.
;; Raises exn:fail:hierarchy in three cases:
;;   - CHILD equals PARENT;
;;   - CHILD is already an ancestor of PARENT (the edge would form a cycle);
;;   - PARENT is already an ancestor of CHILD.
(define derive
  (case-lambda
    [(h child parent)
     (let ([child (if (class? child) (class->interface child) child)])
       (local [;; Raise the hierarchy exception; TEMPLATE is formatted with
               ;; CHILD followed by EXTRA.
               (define (reject template . extra)
                 (raise (make-exn:fail:hierarchy
                         (apply format template child extra)
                         (current-continuation-marks)
                         child parent)))
               ;; Is TO recorded under FROM in TABLE?
               (define (related? table from to)
                 (hash-ref (hash-ref table from #hash()) to #f))
               ;; Add TO, plus everything CLOSURE already lists under TO,
               ;; to the entry of FROM and of each node INVERSE lists
               ;; under FROM.
               (define (extend closure inverse from to)
                 (let* ([gained (hash-set (hash-ref closure to #hash()) to #t)]
                        [grow (lambda (acc node)
                                (hash-update acc node
                                             (cut hash-merge <> gained)
                                             #hash()))])
                   (for/fold ([acc (grow closure from)])
                             ([node (in-hash-keys
                                     (hash-ref inverse from #hash()))])
                     (grow acc node))))]
         (match h
           [(struct hierarchy (parents children ancestors descendants))
            (cond
              [(equal? child parent)
               (reject "derive: ~e cannot derive from itself")]
              [(related? ancestors parent child)
               (reject "derive: ~e would derive from itself via ~e" parent)]
              [(related? ancestors child parent)
               (reject "derive: ~e already derives from ~e" parent)]
              [else
               (make-hierarchy
                (hash-update parents child (cut hash-set <> parent #t) #hash())
                (hash-update children parent (cut hash-set <> child #t) #hash())
                ;; CHILD and its descendants gain PARENT and its ancestors.
                (extend ancestors descendants child parent)
                ;; PARENT and its ancestors gain CHILD and its descendants.
                (extend descendants ancestors parent child))])])))]
    [(child parent)
     (global-hierarchy (derive (global-hierarchy) child parent))]))

;; underive : [hierarchy] any any -> hierarchy or void
;; Removes the direct edge CHILD -> PARENT that an earlier derive recorded.
;;   (underive child parent)   replaces the value of the global-hierarchy
;;                             parameter with the reduced hierarchy.
;;   (underive h child parent) returns a new hierarchy; H is not modified.
;; A class CHILD is first converted with class->interface.
;; Raises exn:fail:hierarchy in three cases:
;;   - CHILD equals PARENT;
;;   - CHILD does not derive from PARENT at all;
;;   - CHILD derives from PARENT only through other edges, so there is no
;;     direct edge to remove.
(define underive
  (case-lambda
    [(child parent)
     (global-hierarchy (underive (global-hierarchy) child parent))]
    [(h child parent)
     (let ([child (if (class? child) (class->interface child) child)])
       (match h
         [(struct hierarchy (parents children ancestors descendants))
          (cond
            [(equal? child parent)
             (raise (make-exn:fail:hierarchy
                     (format
                      "underive: cannot underive ~e from itself"
                      child)
                     (current-continuation-marks)
                     child parent))]
            [(not (hash-ref (hash-ref ancestors child #hash()) parent #f))
             (raise (make-exn:fail:hierarchy
                     (format
                      "underive: ~e doesn't derive from ~e"
                      child parent)
                     (current-continuation-marks)
                     child parent))]
            ;; Only a direct edge can be removed.  If PARENT is reached
            ;; solely through other edges, those edges keep CHILD derived
            ;; from PARENT, and the call would otherwise silently do nothing.
            [(not (hash-ref (hash-ref parents child #hash()) parent #f))
             (raise (make-exn:fail:hierarchy
                     (format
                      "underive: ~e derives from ~e only indirectly"
                      child parent)
                     (current-continuation-marks)
                     child parent))]
            [else
             (local [(define new-parents
                       (hash-update parents child
                         (cut hash-remove <> parent)))
                     (define new-children
                       (hash-update children parent
                         (cut hash-remove <> child)))
                     ;; Recomputes the transitive relation REL for every key
                     ;; of AFFECTED from the direct relation VIA:
                     ;;   rel[n] = union over m in via[n] of {m} + rel[m]
                     ;; A node's affected neighbours are recomputed before
                     ;; the node itself.  This ensures no stale closure
                     ;; (one still containing the removed edge) is merged
                     ;; in.  Neighbours outside AFFECTED keep their
                     ;; existing closures.  Relies on the hierarchy being
                     ;; acyclic, which derive enforces.
                     (define (rebuild-relation rel via affected)
                       (define (visit node rel done)
                         (if (hash-ref done node #f)
                             (values rel done)
                             (let ([nexts (hash-ref via node #hash())])
                               (let-values
                                   ([(rel done)
                                     (for/fold ([rel rel] [done done])
                                               ([next (in-hash-keys nexts)]
                                                #:when (hash-ref affected next #f))
                                       (visit next rel done))])
                                 (values
                                  (hash-set
                                   rel node
                                   (for/fold ([acc #hash()])
                                             ([next (in-hash-keys nexts)])
                                     (hash-merge (hash-set acc next #t)
                                                 (hash-ref rel next #hash()))))
                                  (hash-set done node #t))))))
                       (let-values ([(rel done)
                                     (for/fold ([rel rel] [done #hash()])
                                               ([node (in-hash-keys affected)])
                                       (visit node rel done))])
                         rel))]
               (make-hierarchy
                new-parents
                new-children
                ;; Only CHILD and its descendants can lose ancestors.
                (rebuild-relation
                 ancestors new-parents
                 (hash-set (hash-ref descendants child #hash()) child #t))
                ;; Only PARENT and its ancestors can lose descendants.
                (rebuild-relation
                 descendants new-children
                 (hash-set (hash-ref ancestors parent #hash()) parent #t))))])]))]))

;; Flat contract accepting any value that is neither a class nor an
;; interface.  It is used for the PARENT argument of derive and underive.
(define no-class+interface/c
  (let ([class-or-interface/c (or/c class? interface?)])
    (not/c class-or-interface/c)))

;; Exported lookups and updates.  Each lookup returns a hash whose keys are
;; the related nodes.  The forms without a hierarchy argument consult
;; (global-hierarchy).  The two-argument derive/underive replace the
;; global hierarchy and return void; the three-argument forms return a new
;; hierarchy.  The PARENT passed to derive/underive may be neither a class
;; nor an interface.
(provide/contract
 [parents (case-> (-> any/c hash?)
                  (-> hierarchy? any/c hash?))]
 [children (case-> (-> any/c hash?)
                   (-> hierarchy? any/c hash?))]
 [ancestors (case-> (-> any/c hash?)
                    (-> hierarchy? any/c hash?))]
 [descendants (case-> (-> any/c hash?)
                      (-> hierarchy? any/c hash?))]
 [derived? (case-> (-> any/c any/c boolean?)
                   (-> hierarchy? any/c any/c boolean?))]
 [derive (case-> (-> any/c no-class+interface/c void?)
                 (-> hierarchy? any/c no-class+interface/c hierarchy?))]
 [underive (case-> (-> any/c no-class+interface/c void?)
                   (-> hierarchy? any/c no-class+interface/c hierarchy?))])