basic-blocks.rkt
#lang racket/base

(require racket/match 
         racket/set
         racket/list)


(provide NEXT
         DYNAMIC
         (struct-out bblock)
         fracture
         jump-type?)


;; bblock: one basic block produced by fracture.
;;   name      : the label name that starts the block
;;   entry?    : #t when name is an entry point (the first label of the
;;               program or one of the #:entry-names given to fracture)
;;   stmts     : (listof stmt) the block body in program order; labels
;;               are not included
;;   succs     : (setof (U symbol DYNAMIC)) names of the blocks control
;;               may transfer to; DYNAMIC stands for a non-label target
;;   next-succ : (U #f symbol) the block reached by falling off the end
;;               of this one, or #f when control cannot fall through
(struct bblock (name entry? stmts succs next-succ)
  #:transparent
  #:extra-constructor-name make-bblock)


;; Jump-target markers that jump-targets hooks may return besides label
;; names.  NEXT marks a jump that may fall through to the following
;; statement.  DYNAMIC marks a destination that is not a static label name
;; (the default hooks use it for a non-symbol goto target).  Both markers
;; are compared with eq?, so each is a single shared instance.
(struct jump-type () #:extra-constructor-name make-jump-type)
(struct next jump-type () #:extra-constructor-name make-next)
(struct dynamic jump-type () #:extra-constructor-name make-dynamic)
(define NEXT (make-next))
(define DYNAMIC (make-dynamic))


;; Label: a label created by fracture itself, wrapping a name produced by
;; the #:fresh-block-name hook.  fracture recognizes these in addition to
;; the caller's own labels.  This type is not exported from the module.
(struct Label (name) #:extra-constructor-name make-Label)




;; fracture: (listof (U stmt label)) -> (listof bblock)
;; Splits a sequence of statements and labels into basic blocks.
;;
;; stmts must be a nonempty list whose first element is a label.  That
;; label names the first block and is always an entry point.
;;
;; Keyword arguments (hooks for a custom statement representation):
;;   #:entry-names      extra label names to mark as entry points.
;;   #:fresh-block-name thunk returning a fresh block name.  It labels the
;;                      fall-through point after a jump that may go to
;;                      NEXT when no label follows that jump.
;;   #:label?           recognizes label elements of stmts.
;;   #:label-name       extracts a label's name.
;;   #:jump?            recognizes jump statements.
;;   #:jump-targets     returns a jump's list of targets: label names,
;;                      NEXT (fall through), or DYNAMIC.
;;
;; Returns the blocks reachable from an entry block, in program order.
;; Labels that do not start a block are dropped from block bodies.
;; Statements between a jump that cannot fall through and the next leader
;; are discarded as dead code.  A malformed stmts raises a contract error
;; (see check-good-stmts!).
(define (fracture stmts
                  #:entry-names (entry-names '())
                  #:fresh-block-name (fresh-block-name default-fresh-block-name)
                  #:label? (external-label? default-label?)
                  #:label-name (external-label-name default-label-name)
                  #:jump? (jump? default-jump?)
                  #:jump-targets (jump-targets default-jump-targets))

  ;; Labels injected by find/inject-leaders are Label structs; everything
  ;; else is delegated to the caller's hooks.
  (define (label? x)
    (or (Label? x)
        (external-label? x)))

  (define (label-name x)
    (if (Label? x)
        (Label-name x)
        (external-label-name x)))

  (check-good-stmts! stmts label?)

  (let-values ([(entry-names-set) (list->set (cons (label-name (first stmts)) entry-names))]
               [(leaders stmts)
                (find/inject-leaders stmts entry-names jump? jump-targets
                                     label? label-name fresh-block-name)])

    ;; leader?: any -> boolean
    ;; True when stmt is a label whose name starts a basic block.
    (define (leader? stmt)
      (and (label? stmt) (set-member? leaders (label-name stmt))))

    ;; Drops statements up to (not including) the next leader.
    (define (skip-till-leader stmts)
      (cond
        [(empty? stmts) '()]
        [(leader? (first stmts)) stmts]
        [else (skip-till-leader (rest stmts))]))

    ;; Builds the bblock called `name` from the loop's accumulated state.
    (define (finish-block name stmts/rev succs next-succ)
      (make-bblock name
                   (set-member? entry-names-set name)
                   (reverse stmts/rev)
                   succs
                   next-succ))

    (filter-reachable
     (let loop ([bblocks '()]
                [pending-block-name (label-name (first stmts))]
                [pending-stmts/rev (list)]
                [pending-jump-targets (set)]
                [pending-next-succ #f]
                [stmts (rest stmts)])
       (cond
         [(empty? stmts)
          (reverse (cons (finish-block pending-block-name
                                       pending-stmts/rev
                                       pending-jump-targets
                                       pending-next-succ)
                         bblocks))]

         [(leader? (first stmts))
          (define leader-name (label-name (first stmts)))
          ;; A block whose body is still empty here falls through to this
          ;; leader (e.g. two leaders in a row).  Without this edge the
          ;; following blocks would look unreachable.  Non-empty blocks have
          ;; already recorded their fall-through edge in the else clause,
          ;; or ended in a jump that cannot fall through.  Re-adding an
          ;; edge already recorded for a block of dead labels is harmless.
          (define empty-body? (null? pending-stmts/rev))
          (loop (cons (finish-block pending-block-name
                                    pending-stmts/rev
                                    (if empty-body?
                                        (set-add pending-jump-targets leader-name)
                                        pending-jump-targets)
                                    (if empty-body?
                                        leader-name
                                        pending-next-succ))
                      bblocks)
                leader-name
                (list)
                (set)
                #f
                (rest stmts))]

         [else
          (define stmt (first stmts))
          ;; The jump's target list, or #f when stmt is not a jump.
          (define targets (and (jump? stmt) (jump-targets stmt)))
          (define may-fall-through? (and targets (memq NEXT targets) #t))
          ;; A non-jump immediately followed by a leader falls into it.
          (define falls-into-leader?
            (and (not targets)
                 (pair? (rest stmts))
                 (leader? (second stmts))))
          ;; Name of the following label.  This is only called when that
          ;; label is stmt's fall-through target.  find/inject-leaders puts
          ;; a label after every jump that may go to NEXT.
          (define (following-name) (label-name (second stmts)))
          (loop bblocks
                pending-block-name
                ;; Omit dead labels.
                (if (label? stmt)
                    pending-stmts/rev
                    (cons stmt pending-stmts/rev))
                (cond [targets
                       (set-union (list->set (for/list ([t (in-list targets)])
                                               (if (eq? t NEXT) (following-name) t)))
                                  pending-jump-targets)]
                      [falls-into-leader?
                       (set-add pending-jump-targets (following-name))]
                      [else
                       pending-jump-targets])
                (if (or may-fall-through? falls-into-leader?)
                    (following-name)
                    #f)
                (if (and targets (not may-fall-through?))
                    ;; Code after a jump that cannot fall through is dead
                    ;; up to the next leader.
                    (skip-till-leader (rest stmts))
                    (rest stmts)))])))))


;; check-good-stmts!: any (any -> boolean) -> void
;; Validates fracture's input.  Returns (void) when stmts is a nonempty
;; proper list whose first element satisfies label?.  Otherwise it raises
;; a contract error attributed to 'fracture.
(define (check-good-stmts! stmts label?)
  (unless (and (pair? stmts)
               (list? stmts)
               (label? (car stmts)))
    (raise-type-error 'fracture "nonempty list of statements beginning with a label" stmts)))



;; find/inject-leaders:
;;   stmts entry-names jump? jump-targets label? label-name fresh-block-name
;;   -> (values (setof label-name) (listof (U stmt label)))
;; Computes the names of all leaders (labels that start a basic block).
;; Returns them as a set, together with stmts after label injection.
;;
;; The leaders are:
;;   - the first label of stmts,
;;   - every name in entry-names,
;;   - every named target of a jump (NEXT and DYNAMIC excluded),
;;   - the label that follows any jump whose targets include NEXT.
;; When such a NEXT jump is not followed by a label, a fresh Label named
;; by (fresh-block-name) is inserted right after it.  That way fracture
;; can always name the fall-through block.  label? and label-name must
;; therefore accept Label structs; fracture's wrappers do.
(define (find/inject-leaders stmts entry-names jump? jump-targets label? label-name fresh-block-name)
  (let loop ([leaders (cons (label-name (first stmts)) entry-names)]
             [stmts-seen/rev (list (first stmts))]
             [stmts-to-see (rest stmts)])
    (cond
      [(empty? stmts-to-see)
       (values (list->set leaders) (reverse stmts-seen/rev))]
      [(jump? (first stmts-to-see))
       (define targets (jump-targets (first stmts-to-see)))
       (define named-targets
         (filter (lambda (t)
                   (and (not (eq? t NEXT))
                        (not (eq? t DYNAMIC))))
                 targets))
       (cond [(memq NEXT targets)
              (cond
                [(or (empty? (rest stmts-to-see))
                     (not (label? (second stmts-to-see))))
                 ;; No label follows: inject a fresh one to name the
                 ;; fall-through block.
                 (define fresh-stmt (make-Label (fresh-block-name)))
                 (loop (append named-targets (cons (label-name fresh-stmt) leaders))
                       (cons fresh-stmt (cons (first stmts-to-see) stmts-seen/rev))
                       (rest stmts-to-see))]
                [else
                 ;; Record the following label's *name*.  Leaders are
                 ;; looked up by name, not by label object.
                 (loop (cons (label-name (second stmts-to-see))
                             (append named-targets leaders))
                       (cons (first stmts-to-see) stmts-seen/rev)
                       (rest stmts-to-see))])]
             [else
              (loop (append named-targets leaders)
                    (cons (first stmts-to-see) stmts-seen/rev)
                    (rest stmts-to-see))])]
      [else
       (loop leaders
             (cons (first stmts-to-see) stmts-seen/rev)
             (rest stmts-to-see))])))




;; filter-reachable: (listof bblock) -> (listof bblock)
;; Returns the blocks, in their original order, that are reachable from
;; some entry block (bblock-entry?) by following bblock-succs edges.
;; DYNAMIC successors are not followed.  The search is a depth-first
;; traversal over an explicit stack of block names.
;;
;; Blocks are keyed by name in equal?-based tables.  This keeps it
;; consistent with the equal?-based leader set, so custom label names
;; need not be symbols.
;; Raises exn:fail:contract if a successor names no block.
(define (filter-reachable bblocks)
  (define blocks-by-name (make-hash))
  (for ([b (in-list bblocks)])
    (hash-set! blocks-by-name (bblock-name b) b))

  (define (lookup name)
    (hash-ref blocks-by-name name
              (lambda ()
                (raise-arguments-error 'fracture
                                       "jump target does not name a basic block"
                                       "label" name))))

  (define visited (make-hash))
  (define (dfs stack)
    (cond
      [(empty? stack)
       (void)]
      [(hash-has-key? visited (first stack))
       (dfs (rest stack))]
      [else
       (hash-set! visited (first stack) #t)
       (dfs (append (for/list ([neighbor (bblock-succs (lookup (first stack)))]
                               #:unless (eq? neighbor DYNAMIC)
                               #:unless (hash-has-key? visited neighbor))
                      neighbor)
                    (rest stack)))]))
  (dfs (for/list ([b (in-list bblocks)]
                  #:when (bblock-entry? b))
         (bblock-name b)))

  (for/list ([b (in-list bblocks)]
             #:when (hash-has-key? visited (bblock-name b)))
    b))




;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; default-label?: any -> boolean
;; Default label recognizer: a label is a bare symbol.
(define (default-label? candidate)
  (symbol? candidate))


;; default-label-name: symbol -> symbol
;; Default label-name hook: a default (symbol) label is its own name.
;; Raises a contract error for anything that is not a symbol.
(define (default-label-name a-label)
  (unless (symbol? a-label)
    (raise-type-error 'default-label-name "symbol" a-label))
  a-label)

;; default-fresh-block-name: -> symbol
;; Default fresh-name hook: each call returns a new uninterned symbol
;; generated from the base name 'label.
(define default-fresh-block-name
  (lambda ()
    (gensym 'label)))

;; default-jump?: any -> boolean
;; Default jump recognizer.  A jump is either (goto target) or
;; (if condition goto target); anything else is not a jump.
(define (default-jump? x)
  (match x
    [(or (list 'goto _)
         (list 'if _ 'goto _))
     #t]
    [_ #f]))

;; default-jump-targets: stmt -> (listof (U symbol DYNAMIC NEXT))
;; Targets of a default-format jump.  A symbol target is returned as-is;
;; any other target expression is reported as DYNAMIC.
;;   (goto t)       -> one target
;;   (if c goto t)  -> the target followed by NEXT (may also fall through)
;; Raises a contract error for statements that are not default jumps.
(define (default-jump-targets x)
  (define (classify target)
    (if (symbol? target) target DYNAMIC))
  (match x
    [(list 'goto target)
     (list (classify target))]
    [(list 'if _ 'goto target)
     (list (classify target) NEXT)]
    [_
     (raise-type-error 'default-jump-targets "Statement with jump targets" x)]))