set.ss
#lang scheme
(require "private/common.ss")
(provide set? list->set set->list empty-set set-empty? set-count
         set-intersection set-difference set-partition set-union set-xor
         set-intersections set-differences set-partitions set-unions set-xors
         set-adjoin set-add set-contains?
         subset? set=?
         for/set for*/set
         in-set)

;; in-set : set -> sequence
;; Yields the members of `s' (the keys of its backing hash table) so a
;; set can drive a `for' clause, e.g. (for ([x (in-set s)]) ...).
(define (in-set s)
  (let ([table (set-elts s)])
    (in-hash-keys table)))

;; A set wraps a hash table, `elts', whose keys are the set's members.
;; Printing delegates to `write-hash' (from private/common.ss) with the
;; label "set".  A set can also be used directly as a sequence; it then
;; yields its members through `in-set'.
(define-struct set (elts)
  #:property prop:custom-write
  (lambda (s out write-mode?)
    (write-hash "set" (set-elts s) out write-mode?))
  #:property prop:sequence
  in-set)

;; list->set : sequence -> set
;; Builds a set holding every element of `ls' (normally a list, but any
;; sequence accepted by a `for' clause works).  Each element is mapped to
;; #t in an immutable hash, so duplicates collapse into one member.
(define (list->set ls)
  (make-set
   (for/fold ([table #hash()])
             ([x ls])
     (hash-set table x #t))))

;; set->list : set -> list
;; Returns the members of `s' as a list.  The order is the hash table's
;; iteration order, so callers should not rely on it.
(define (set->list s)
  (for/list ([(member flag) (in-hash (set-elts s))])
    member))

;; set-intersection : set set ... -> set
;; Intersects `base' with every set in `others'.  The work is delegated
;; to `hash-intersection' (private/common.ss) over the backing hash
;; tables, with `for/hash' used to build the result table.
(define (set-intersection base . others)
  (make-set
   (hash-intersection (set-elts base)
                      (map set-elts others)
                      for/hash)))

;; set-intersections : (listof set) -> set
;; List-taking form of `set-intersection': intersects the first set in
;; `sets' with all the remaining ones.
;; Raises exn:fail:contract (via raise-type-error) when `sets' is empty.
;; The intersection of zero sets is not defined here, and the previous
;; unguarded `car' produced a confusing error that did not name this
;; function.
(define (set-intersections sets)
  (if (pair? sets)
      (apply set-intersection sets)
      (raise-type-error 'set-intersections "non-empty list of sets" sets)))

;; set-difference : set set ... -> set
;; Removes from `base' the contents of every set in `others'.  The work is
;; delegated to `hash-difference' (private/common.ss) over the backing
;; hash tables, with `for/hash' used to build the result table.
(define (set-difference base . others)
  (make-set
   (hash-difference (set-elts base)
                    (map set-elts others)
                    for/hash)))

;; set-differences : (listof set) -> set
;; List-taking form of `set-difference': the first set in `sets' minus
;; all the remaining ones.
;; Raises exn:fail:contract (via raise-type-error) when `sets' is empty.
;; There is no base set to subtract from in that case, and the previous
;; unguarded `car' produced a confusing error that did not name this
;; function.
(define (set-differences sets)
  (if (pair? sets)
      (apply set-difference sets)
      (raise-type-error 'set-differences "non-empty list of sets" sets)))

;; set-partition : set set ... -> (values set set)
;; Splits `base' against `others' using `hash-partition' (private/common.ss)
;; and returns two sets.  Going by the original binding names, the first
;; is the difference and the second the intersection.  NOTE(review):
;; confirm against hash-partition's definition.
(define (set-partition base . others)
  (call-with-values
   (lambda ()
     ((hash-partition #hash()) (set-elts base) (map set-elts others)))
   (lambda (diff-table common-table)
     (values (make-set diff-table) (make-set common-table)))))

;; set-partitions : (listof set) -> (values set set)
;; List-taking form of `set-partition': partitions the first set in
;; `sets' against all the remaining ones and returns the same two values
;; as `set-partition'.
;; Raises exn:fail:contract (via raise-type-error) when `sets' is empty,
;; instead of failing inside an unguarded `car'.
(define (set-partitions sets)
  (if (pair? sets)
      (apply set-partition sets)
      (raise-type-error 'set-partitions "non-empty list of sets" sets)))

;; empty-set : set
;; The set with no members, backed by an empty immutable equal?-based hash.
(define empty-set
  (make-set (make-immutable-hash '())))

;; set-empty? : set -> boolean
;; #t exactly when `s' has no members.
(define (set-empty? s)
  (= 0 (hash-count (set-elts s))))

;; set-count : set -> exact-nonnegative-integer
;; Number of members in `s' (the key count of its backing hash).
(define (set-count s)
  (let ([table (set-elts s)])
    (hash-count table)))

;; set-unions : (listof set) -> set
;; Union of all sets in `sets'.  The backing tables are folded right to
;; left with `union1' (private/common.ss), starting from an empty hash,
;; so an empty list yields an empty set.
(define (set-unions sets)
  (let ([tables (map set-elts sets)])
    (make-set (foldr union1 #hash() tables))))

;; set-union : set ... -> set
;; Variadic front end for `set-unions'.
(define (set-union . all)
  (set-unions all))

;; set-xor : set ... -> set
;; Variadic front end for `set-xors'.
(define (set-xor . all)
  (set-xors all))

;; set-xors : (listof set) -> set
;; Folds the backing tables of `sets' right to left with the combiner
;; returned by (xor1 #hash()) (private/common.ss), starting from an empty
;; hash.  Presumably this computes the symmetric difference; confirm
;; against xor1.
(define (set-xors sets)
  (let* ([combine (xor1 #hash())]
         [tables (map set-elts sets)])
    (make-set (foldr combine #hash() tables))))

;; set-adjoin : set any ... -> set
;; Returns a set containing the members of `s' plus every extra argument.
(define (set-adjoin s . elts)
  (set-unions (list s (list->set elts))))

;; set-add : any set -> set
;; Returns a set containing the members of `s' plus `elt'.
;; Note the element-first argument order, as with `cons'.
(define (set-add elt s)
  (set-union s (list->set (list elt))))

;; set-contains? : set any -> any
;; Looks `elt' up in the backing hash of `s'.  Returns the stored value,
;; which is #t for sets built by list->set / for/set, or #f when `elt'
;; is absent.
(define (set-contains? s elt)
  (hash-ref (set-elts s) elt #f))

;; (for/set (for-clause ...) body ...+)
;; Like `for/list', but collects the value of the body on each iteration
;; into a set.  Each result becomes a key mapped to #t.
(define-syntax-rule (for/set (for-clause ...) body0 body ...)
  (make-set
   (for/hash (for-clause ...)
     (let ([member (let () body0 body ...)])
       (values member #t)))))
  
;; (for*/set (for-clause ...) body ...+)
;; Nested-iteration variant of `for/set' (built on `for*/hash'); every
;; body result becomes a member of the resulting set.
(define-syntax-rule (for*/set (for-clause ...) body0 body ...)
  (make-set
   (for*/hash (for-clause ...)
     (let ([member (let () body0 body ...)])
       (values member #t)))))

;; subset? : set ... -> boolean
;; Checks each adjacent pair of arguments, left to right, with `<=?1'
;; (private/common.ss) on the backing tables.  Presumably this means each
;; set is a subset of the next; confirm against <=?1.  Stops at the first
;; failing pair.  Zero or one argument yields #t.
(define (subset? . sets)
  (let check ([remaining (map set-elts sets)])
    (or (null? remaining)
        (null? (cdr remaining))
        (and (<=?1 (car remaining) (cadr remaining))
             (check (cdr remaining))))))

;; set=? : set ... -> boolean
;; Checks each adjacent pair of arguments, left to right, with `=?1'
;; (private/common.ss) on the backing tables, stopping at the first
;; failing pair.  Zero or one argument yields #t.
(define (set=? . sets)
  (let check ([remaining (map set-elts sets)])
    (or (null? remaining)
        (null? (cdr remaining))
        (and (=?1 (car remaining) (cadr remaining))
             (check (cdr remaining))))))