#lang racket
(require (only-in "dimacs.rkt"
dimacs-lit->literal
dimacs-lits->clause)
"data-structures.rkt"
"learned-clauses.rkt"
"statistics.rkt"
"sat-heuristics.rkt"
"smt-interface.rkt"
"debug.rkt"
rackunit)
(provide propagate-assignment
resolve-conflict!
initial-bcp
obliterate-lit!)
;; Initial-BCP step for a single clause.
;; When both watched literals of CLAUSE are the same literal (presumably a
;; unit clause) and that literal is still unassigned, assert it with CLAUSE
;; as its reason via propagate-assignment. Otherwise SMT is returned as is.
(define (bcp-clause smt clause)
  (define w1 (clause-watched1 clause))
  (cond
    [(and (literal-eq? w1 (clause-watched2 clause))
          (literal-unassigned? w1))
     (propagate-assignment smt w1 clause)]
    [else smt]))
;; Run bcp-clause once over every clause in (SMT-clauses smt), in vector
;; order, threading the SMT state through each call.
;; Returns the final SMT state.
(define (initial-bcp smt)
  (for/fold ([state smt])
            ([c (in-vector (SMT-clauses smt))])
    (bcp-clause state c)))
;; Assert the first watched literal of a freshly learned clause.
;; The reason recorded is LEARNED itself (the learned-clause wrapper),
;; not the inner clause it wraps.
(define (learned-bcp smt learned)
  (define inner (learned-clause-clause learned))
  (propagate-assignment smt (clause-watched1 inner) learned))
;; Make LIT true with CLAUSE as its reason, then propagate the consequences.
;;  1. Satisfy LIT in both the SAT and theory states (SMT-satisfy-literal!)
;;     and attach an implication-graph node (current level, CLAUSE) to LIT.
;;  2. Let the theory propagate whatever LIT implies.
;;  3. Walk the watch list of the negation of LIT:
;;     - forgotten clauses are dropped from the list;
;;     - every other clause is handed to update-watchedness. That may
;;       propagate further or raise on a conflict, and it reports whether
;;       this watch entry should be removed.
;; Returns the resulting SMT state.
(define (propagate-assignment smt lit clause)
  (define assigned (SMT-satisfy-literal! smt lit))
  (set-literal-igraph-node! lit (node (SMT-decision-level assigned) clause))
  (define theory-done (propagate-T-implications assigned lit))
  (define falsified (negate-literal lit))
  (let walk ([cursor (watched-iterate-first falsified)]
             [state theory-done])
    (if (not cursor)
        state
        (let ([watcher (watched-iterate-clause cursor)])
          (if (clause-forgotten? watcher)
              (walk (watched-iterate-remove cursor) state)
              (let-values ([(next-state drop?)
                            (update-watchedness state watcher falsified)])
                (walk (if drop?
                          (watched-iterate-remove cursor)
                          (watched-iterate-next cursor))
                      next-state)))))))
;; Ask the theory solver which DIMACS literals are implied by LIT and assert
;; each of them in order via propagate-assignment.
;; Each implied literal's reason is a procedure of an SMT state. When called,
;; it asks the theory for an explanation (T-Explain) and converts it to clause
;; literals, so explanations are only built if that procedure is invoked.
;; Returns the SMT state after all implied literals are asserted.
(define (propagate-T-implications smt lit)
  (define-values (t-state implied)
    ((T-Propagate) (SMT-T-State smt) (SMT-strength smt) (literal->dimacs lit)))
  (for/fold ([acc (new-T-State smt t-state)])
            ([d-lit (in-list implied)])
    (propagate-assignment
     acc
     ((dimacs-lit->literal (SMT-variables acc)) d-lit)
     (lambda (s)
       (clause-literals
        ((dimacs-lits->clause (SMT-variables s))
         ((T-Explain) (SMT-T-State s) (SMT-strength s) d-lit)))))))
;; Make LITERAL true in the SAT state SAT. In order, this:
;;  - sets its variable's value to the literal's polarity;
;;  - stamps the variable with the current assigned-order counter;
;;  - bumps the counter and records LITERAL in the current decision level.
;; Returns whatever add-to-current-decision-level returns (used by callers
;; as the new SAT state).
(define (satisfy-literal! sat literal)
  (let ([v (literal-var literal)])
    (set-var-value! v (literal-polarity literal))
    (set-var-timestamp! v (SAT-Stats-assigned-order (SAT-statistics sat)))
    (add-to-current-decision-level (SAT-inc-assigned-order sat) literal)))
;; Satisfy LITERAL in the SAT component first, then notify the theory
;; (T-Satisfy). Strength and seed are carried over unchanged.
;; Returns a fresh SMT record.
(define (SMT-satisfy-literal! smt literal)
  (let* ([sat* (satisfy-literal! (SMT-sat smt) literal)]
         [t-state* ((T-Satisfy) (SMT-T-State smt) (literal->dimacs literal))])
    (SMT sat* t-state* (SMT-strength smt) (SMT-seed smt))))
;; Undo every decision level above ABSOLUTE-LEVEL, then assert the learned
;; clause LEARNED via learned-bcp. In detail:
;;  - run the on-conflict hook (SMT-on-conflict);
;;  - unassign the variables of the discarded levels;
;;  - lower the assigned-order counter by the number of variables undone;
;;  - tell the theory how many assignments were undone (T-Backjump).
(define (backjump! smt absolute-level learned)
  (define conflicted (SMT-on-conflict smt))
  (define-values (remaining-pa undone)
    (obliterate-decision-levels! (- (SMT-decision-level conflicted) absolute-level)
                                 (SMT-partial-assignment conflicted)
                                 0))
  (define rewound
    (let* ([s (SMT-set-decision-level conflicted absolute-level)]
           [s (SMT-set-assigned-order s (- (SMT-assigned-order s) undone))]
           [s (SMT-set-partial-assignment s remaining-pa)])
      (new-T-State s ((T-Backjump) (SMT-T-State s) undone))))
  (learned-bcp rewound learned))
;; Unassign every literal in the first NUM-LEVELS-TO-OBLITERATE entries of
;; PARTIAL-ASSIGNMENT, a list of per-level literal lists. Judging by
;; backjump!, the most recent level comes first — TODO confirm.
;; Returns two values:
;;  - the remaining partial assignment;
;;  - TOTAL-VARS-OBLITERATED plus the number of literals unassigned.
(define (obliterate-decision-levels! num-levels-to-obliterate partial-assignment [total-vars-obliterated 0])
  (let loop ([levels-left num-levels-to-obliterate]
             [pa partial-assignment]
             [count total-vars-obliterated])
    (if (zero? levels-left)
        (values pa count)
        (loop (sub1 levels-left)
              (rest pa)
              (obliterate-decision-level! (first pa) count)))))
;; Unassign each literal of LITS, in list order, with obliterate-lit!.
;; Returns TOTAL-VARS-OBLITERATED plus (length LITS).
(define (obliterate-decision-level! lits total-vars-obliterated)
  (for/fold ([count total-vars-obliterated])
            ([lit (in-list lits)])
    (obliterate-lit! lit)
    (add1 count)))
;; Reset the variable underlying LIT: its value becomes 'unassigned, and its
;; implication-graph node and timestamp are cleared to #f.
(define (obliterate-lit! lit)
  (define v (literal-var lit))
  (set-var-value! v 'unassigned)
  (set-var-igraph-node! v #f)
  (set-var-timestamp! v #f))
;; React to DECIDED-LITERAL becoming false in CLAUSE, which watches it.
;; Returns two values (new SMT state, remove-watch?). The outcome of
;; bcp-4cases decides what happens:
;;  'skip          -> keep the watch, nothing else to do
;;  'remove        -> the clause moved its watch; drop this watch entry
;;  'contradiction -> raise unsat-exn at level 0, otherwise run conflict
;;                    analysis (resolve-conflict! always raises)
;;  a literal      -> it is unit; assert it with CLAUSE as the reason
(define (update-watchedness smt clause decided-literal)
  (define outcome (bcp-4cases (some-clause->clause clause) decided-literal))
  (case outcome
    [(skip) (values smt #f)]
    [(remove) (values smt #t)]
    [(contradiction)
     (if (zero? (SMT-decision-level smt))
         (raise (unsat-exn smt))
         (resolve-conflict! smt clause))]
    [else (values (propagate-assignment smt outcome clause) #f)]))
;; Binary resolution of the literal vectors C and D on the variable of RES-LIT.
;; Both polarities of the pivot are removed from C and D. The result is
;; (C minus the pivot) followed by (D minus the pivot). A literal of C is
;; dropped when an equal literal already occurs in D, so shared literals
;; appear once.
;; Returns a fresh vector; C and D are not modified.
;; All comparisons use literal-eq?. The previous version removed the pivot
;; with remove*'s default equal?, which need not agree with literal-eq?.
(define (resolve-on-lit C D res-lit)
  (define pivots (list (negate-literal res-lit) res-lit))
  ;; Remove both polarities of the pivot using literal equality.
  (define (strip-pivot lits)
    (remove* pivots (vector->list lits) literal-eq?))
  (define C-lits (strip-pivot C))
  (define D-lits (strip-pivot D))
  ;; #t when LIT is already present among D's surviving literals.
  (define (in-D? lit)
    (ormap (lambda (d) (literal-eq? lit d)) D-lits))
  (list->vector
   (append (filter (lambda (c) (not (in-D? c))) C-lits)
           D-lits)))
;; Conflict analysis for the falsified clause C. This never returns normally.
;; Steps:
;;  1. Starting from the literals of C, repeatedly resolve the latest literal
;;     against its explanation until asserting-literals? holds.
;;  2. Build a learned clause from the result. It watches the latest literal
;;     plus one other: the first literal, or the last if the first is the
;;     latest. Register it via SMT-learn-clause.
;;  3. Raise. If the chosen backjump level is negative, raise unsat-exn;
;;     otherwise backjump and raise bail-exn carrying the resulting state.
(define (resolve-conflict! smt C)
  (define lemma
    (let resolve ([resolvent (lemma->lits smt C)])
      (if (asserting-literals? smt resolvent)
          resolvent
          (let* ([pivot (choose-latest-literal resolvent)]
                 [reason (literal-explanation smt pivot)])
            (resolve (resolve-on-lit resolvent reason pivot))))))
  (define backjump-level (asserting-level smt lemma))
  (define watch1 (choose-latest-literal lemma))
  (define watch2
    (if (literal-eq? (vector-ref lemma 0) watch1)
        (vector-ref lemma (sub1 (vector-length lemma)))
        (vector-ref lemma 0)))
  (define learned (clause lemma watch1 watch2))
  (add-literal-watched! learned watch1)
  (add-literal-watched! learned watch2)
  (let-values ([(smt* learned-clause) (SMT-learn-clause smt learned)])
    (if (negative? backjump-level)
        (raise (unsat-exn smt*))
        (raise (bail-exn (backjump! smt* backjump-level learned-clause))))))
;; #t when at most one literal of LITS is at the current decision level.
;; Scanning stops as soon as a second such literal is found.
(define (asserting-literals? smt lits)
  (define current (SMT-decision-level smt))
  (let scan ([idx 0] [seen-current? #f])
    (cond
      [(= idx (vector-length lits)) #t]
      [(= current (literal-dec-lev (vector-ref lits idx)))
       (if seen-current?
           #f
           (scan (add1 idx) #t))]
      [else (scan (add1 idx) seen-current?)])))
;; Pick the decision level to backjump to for the learned clause LITS.
;; `candidate` is the highest decision level seen among literals NOT at the
;; current decision level (SMT-decision-level smt); it starts at -1.
;; `all-same-level?` starts #t and turns #f once a literal's level differs
;; from the running candidate (only after a candidate has been set).
;; Result:
;;  - if all-same-level? is still #t at the end: one below the decision level
;;    of lits[0];
;;  - otherwise: the candidate.
;; NOTE(review): because the comparison is against a candidate that changes
;; during the scan, the result depends on literal order. With current level d
;; and a lower level L:
;;  - [a@d b@L] yields d-1, while [b@L a@d] yields L;
;;  - a clause containing only a current-level literal yields d-1, not 0.
;; d-1 is never below L, so this looks like a missed non-chronological
;; backjump rather than an unsound one. Confirm whether this is intended.
(define (asserting-level smt lits)
;; idx: scan position; candidate: -1 = no lower-level literal seen yet
(let recur ([idx 0] [candidate -1] [all-same-level? #t])
(if (= idx (vector-length lits))
(if all-same-level?
(+ -1 (literal-dec-lev (vector-ref lits 0)))
candidate)
(let* ((this-declev (literal-dec-lev (vector-ref lits idx)))
;; compares with the running candidate, not with the previous literal's level
(same-level-as-last? (or (0 . > . candidate)
(= this-declev candidate)))
(all-same-level?* (and all-same-level?
same-level-as-last?)))
;; literals at the current decision level never become the candidate
(if (and (this-declev . > . candidate)
(not (= this-declev (SMT-decision-level smt))))
(recur (+ 1 idx) this-declev all-same-level?*)
(recur (+ 1 idx) candidate all-same-level?*))))))
;; Classify CLAUSE after its watched literal P became false.
;; Scans literals from IDX, skipping P and every literal whose valuation is #f.
;;  FOUND   - the most recently seen non-false literal.
;;  SEVERAL - truthy once there is a non-false literal other than the second
;;            watched literal (clause-other-watched clause p) to move the
;;            watch to.
;; Outcomes:
;;  'remove        - watch FOUND instead of P (add-literal-watched! and
;;                   clause-watched-swap! are called here)
;;  FOUND          - the single non-false literal, still unassigned: it is unit
;;  'skip          - the single non-false literal is already non-false
;;                   (assigned): nothing to do
;;  'contradiction - no non-false literal remains
;; The optional arguments seed the scan state and default to a fresh scan.
(define (bcp-4cases clause p [nonfalse-literal #f] [multiple #f] [idx 0])
  (define (finish found several)
    (cond
      [several
       (add-literal-watched! clause found)
       (clause-watched-swap! clause p found)
       'remove]
      [(not found) 'contradiction]
      [(literal-unassigned? found) found]
      [else 'skip]))
  (let scan ([found nonfalse-literal] [several multiple] [i idx])
    (if (= i (clause-size clause))
        (finish found several)
        (let* ([lit (nth-literal clause i)]
               [val (literal-valuation lit)])
          (cond
            ;; P itself and false literals can never be the new watch
            [(or (literal-eq? lit p) (false? val))
             (scan found several (+ 1 i))]
            ;; the other watch, seen after some candidate: keep that candidate
            [(and (literal-eq? (clause-other-watched clause p) lit) found)
             (scan found #t (+ 1 i))]
            ;; new candidate; any previous candidate makes the result "several"
            [else
             (scan lit found (+ 1 i))])))))