#lang racket
(require "bcp.rkt")
(require "data-structures.rkt")
(require "smt-interface.rkt")
(require "dimacs.rkt")
(require "debug.rkt")
(require rackunit)
(provide smt-solve
smt-assign
smt-decide
sat-solve
sat-assign
sat-decide)
;; Build a clause that forbids every variable in VARS from simultaneously
;; keeping its current value: for each var v it contains the literal asserting
;; (not (var-value v)), i.e. the clause is the negation of the conjunction of
;; the current assignments.
;; The two extra arguments given to `clause` are the first and last literals
;; (identical when VARS has a single element) -- presumably the clause's two
;; watched literals; confirm against data-structures.rkt.
;; VARS must be a non-empty list; an empty list fails in vector-ref.
(define (vars->clause-of-not-and vars)
  (define lits
    (for/vector ([v (in-list vars)])
      (literal v (not (var-value v)))))
  (define last-index (sub1 (vector-length lits)))
  (clause lits
          (vector-ref lits 0)
          (vector-ref lits last-index)))
;; Clause ruling out the current partial assignment of SMT: the negation of
;; the conjunction of every assigned variable's current value. Variables are
;; taken in the order they appear in (SMT-variables smt).
(define (not-partial-assignment smt)
  (vars->clause-of-not-and
   (for/list ([v (in-vector (SMT-variables smt))]
              #:unless (var-unassigned? v))
     v)))
;; Called once no unassigned variable remains to decide on: ask the theory
;; solver (the procedure held in the T-Consistent? parameter) whether the
;; current theory state is consistent, passing +inf.0 as the strength
;; argument (presumably "full check" -- confirm against smt-interface.rkt).
;;   #t          -> the search is done: raises (sat-exn smt).
;;   #f          -> no explanation given: learn the clause negating the whole
;;                  current partial assignment and resolve the conflict.
;;   explanation -> a list of vars; learn the clause negating just those vars.
;; Does not return normally in the #t case; otherwise returns whatever
;; resolve-conflict! returns (or propagates what it raises).
(define (get-T-solver-blessing smt)
  (match ((T-Consistent?) (SMT-T-State smt) +inf.0)
    [#t (raise (sat-exn smt))]
    [#f
     ;; Build the conflict clause once and reuse it for both the debug trace
     ;; and conflict resolution (previously it was constructed twice).
     (let ([conflict (not-partial-assignment smt)])
       (debug "inconsistent" conflict)
       (resolve-conflict! smt conflict))]
    [explanation
     (resolve-conflict! smt (vars->clause-of-not-and explanation))]))
;; Make a decision on LIT: open a new decision level (incrementing the level
;; and pushing an empty frame onto the partial-assignment stack), then hand
;; the literal to propagate-assignment with #f as its third argument.
(define (decide smt lit)
  (define deeper
    (SMT-set-decision-level smt (add1 (SMT-decision-level smt))))
  (define with-fresh-frame
    (SMT-set-partial-assignment deeper
                                (cons '() (SMT-partial-assignment deeper))))
  (propagate-assignment with-fresh-frame lit #f))
;; Decision heuristic: pick the first unassigned variable (in vector order)
;; and decide it positively. When every variable is assigned, defer to the
;; theory solver via get-T-solver-blessing instead of returning a literal.
(define (choose-in-order smt)
  (define first-open
    (for/first ([v (in-vector (SMT-variables smt))]
                #:when (var-unassigned? v))
      v))
  (if first-open
      (literal first-open #t)
      (get-T-solver-blessing smt)))
;; VSIDS-style decision heuristic: among unassigned variables, pick the one
;; with the greatest total activation (positive + negative); ties go to the
;; earliest variable in vector order. The chosen variable is decided true
;; when its positive activation is at least its negative activation, false
;; otherwise. With no unassigned variable left, defer to the theory solver
;; via get-T-solver-blessing.
(define (vsids smt)
  (define-values (winner top-score)
    (for/fold ([winner #f] [top-score -1])
              ([v (in-vector (SMT-variables smt))])
      (define activity (+ (var-pos-activation v) (var-neg-activation v)))
      (if (and (var-unassigned? v) (> activity top-score))
          (values v activity)
          (values winner top-score))))
  (cond
    [winner
     (literal winner
              (>= (var-pos-activation winner) (var-neg-activation winner)))]
    [else (get-T-solver-blessing smt)]))
;; Main search loop. Runs initial-bcp once, then repeatedly decides on the
;; literal picked by CHOOSE-LITERAL. A bail-exn raised during a step carries
;; the state to continue from (its `sat` field). The loop has no normal exit:
;; it ends only when some other exception (e.g. sat-exn / unsat-exn) escapes.
(define (smt-search smt [choose-literal vsids])
  (define (step state)
    (with-handlers ([bail-exn? (λ (e) (bail-exn-sat e))])
      (decide state (choose-literal state))))
  (let loop ([state (initial-bcp smt)])
    (loop (step state))))
;; Solve CNF under theory state T-STATE with the given STRENGTH.
;; Initialization and search both run under the handlers, so a sat-exn or
;; unsat-exn raised at any point is caught, its partial assignment is traced
;; with `debug`, and the exception struct itself is returned to the caller.
(define (smt-solve cnf t-state strength [choose-literal vsids])
  (define (on-sat e)
    (debug "SAT" (SMT-partial-assignment (sat-exn-sat e)))
    e)
  (define (on-unsat e)
    (debug "UNSAT" (SMT-partial-assignment (unsat-exn-sat e)))
    e)
  (with-handlers ([sat-exn? on-sat]
                  [unsat-exn? on-unsat])
    (smt-search (initialize cnf t-state strength) choose-literal)))
;; Pure SAT solving: run smt-solve with the theory hooks bound to the
;; propositional (sat-*) implementations and no theory state (#f).
;; Returns the sat-exn or unsat-exn struct produced by smt-solve.
;;
;; smt-solve's signature is (cnf t-state strength [choose-literal]). The
;; previous call omitted STRENGTH, so CHOOSE-LITERAL was bound to strength
;; and the caller's heuristic was silently replaced by the default (vsids).
;; STRENGTH is passed explicitly now; 0 is used on the assumption that the
;; sat-* hooks ignore theory-propagation strength -- confirm against the
;; sat-* definitions if that changes.
(define (sat-solve cnf [choose-literal vsids])
  (parameterize ([T-Satisfy sat-satisfy]
                 [T-Propagate sat-propagate]
                 [T-Explain sat-explain]
                 [T-Consistent? sat-consistent?]
                 [T-Backjump sat-backjump])
    (smt-solve cnf #f 0 choose-literal)))
;; Decide satisfiability of CNF under theory state T-STATE.
;; Returns 'SAT or 'UNSAT.
;; Fix: STRENGTH is now forwarded. The previous call omitted it, so
;; CHOOSE-LITERAL was passed in the strength position and the caller's
;; heuristic was dropped in favour of the default.
(define (smt-decide cnf t-state strength [choose-literal vsids])
  (match (smt-solve cnf t-state strength choose-literal)
    [(? sat-exn?) 'SAT]
    [(? unsat-exn?) 'UNSAT]))
;; Decide propositional satisfiability of CNF; returns 'SAT or 'UNSAT.
(define (sat-decide cnf [choose-literal vsids])
  (define outcome (sat-solve cnf choose-literal))
  (match outcome
    [(? sat-exn?) 'SAT]
    [(? unsat-exn?) 'UNSAT]))
;; Given a sat-exn OUTCOME, list the DIMACS form of every assigned variable
;; in its final state, in vector order. Variables for which var->dimacs
;; yields #f are skipped.
(define (extract-public-partial-assignment outcome)
  (define vars (vector->list (SMT-variables (sat-exn-sat outcome))))
  (for*/list ([v (in-list vars)]
              #:unless (var-unassigned? v)
              [lit (in-value (var->dimacs v))]
              #:when lit)
    lit))
;; Solve CNF under theory state T-STATE. Returns the satisfying partial
;; assignment as a list of DIMACS literals, or 'UNSAT.
(define (smt-assign cnf t-state strength [choose-literal vsids])
  (define outcome (smt-solve cnf t-state strength choose-literal))
  (match outcome
    [(? sat-exn?) (extract-public-partial-assignment outcome)]
    [(? unsat-exn?) 'UNSAT]))
;; Propositional counterpart of smt-assign. Returns the satisfying partial
;; assignment as a list of DIMACS literals, or 'UNSAT.
(define (sat-assign cnf [choose-literal vsids])
  (define outcome (sat-solve cnf choose-literal))
  (match outcome
    [(? sat-exn?) (extract-public-partial-assignment outcome)]
    [(? unsat-exn?) 'UNSAT]))
;; Top-level rackunit checks. They run every time this module is
;; instantiated, not only under `raco test`.
;; Inputs are DIMACS-style lists: (num-vars num-clauses (clause ...)).
;; x1 and (not x1) together: trivially unsatisfiable.
(check equal?
(sat-decide '(1 2 ((1) (-1))))
'UNSAT)
;; Five clauses over five variables; expected satisfiable.
(check equal?
(sat-decide '(5 5 ((-1 2) (-1 3) (-2 4) (-3 -4) (1 -3 5))))
'SAT)
;; Seven clauses over six variables; expected satisfiable.
(check equal?
(sat-decide '(6 7 ((1 2) (2 3) (-1 -4 5) (-1 4 6) (-1 -5 6) (-1 4 -6) (-1 -5 -6))))
'SAT)