import time
from z3 import *

from smt.solver import Solver, Model

class Z3Solver(Solver):
  """Implementation of the project `Solver` interface on top of the Z3 Python API.

  The wrapped Z3 object is stored in `self.ctx`. Which kind of object is
  created depends on `mode`:

  * "opt": a Z3 `Optimize` instance (required by `minimize`),
  * "qe": a solver built from the tactic chain `Then('qe', 'smt')`,
  * anything else: a plain Z3 solver.

  The instance also accumulates simple statistics: `_checks` (number of
  `check_sat` calls), `_check_time` and `_simp_time` (seconds spent in
  `check_sat` and in the simplification methods), and `t_solve` (duration of
  the last `minimize` check).
  """

  def __init__(self, mode="opt"):
    # Remember the mode so that reset() can rebuild the same kind of context.
    self._mode = mode
    self.ctx = self._new_context(mode)
    self._checks = 0
    self._check_time = 0
    self._simp_time = 0
    self.t_solve = 0

  @staticmethod
  def _new_context(mode):
    """Create a fresh Z3 solving object for the given mode."""
    if mode == "opt":
      return Optimize()
    if mode == "qe":
      return Then('qe', 'smt').solver()
    # The bare name `Solver` in this module is the base class imported from
    # smt.solver (that import follows the z3 wildcard import and shadows Z3's
    # class), so Z3's solver has to be referenced through the z3 module.
    return z3.Solver()

  def destroy(self):
    """Drop the reference to the wrapped Z3 object."""
    del self.ctx

  def are_equal_expr(self, a, b):
    """Return whether `a` and `b` are the same term.

    Two Z3 ASTs are compared with `AstRef.eq`; equal hashes alone do not
    guarantee identical terms. Any other operands (e.g. the Python booleans
    returned by `true()`/`false()`) are compared by hash.
    """
    if isinstance(a, AstRef) and isinstance(b, AstRef):
      return a.eq(b)
    return hash(a) == hash(b)

  def true(self):
    """Return the truth constant (the Python value `True`)."""
    return True

  def false(self):
    """Return the falsity constant (the Python value `False`)."""
    return False

  def num(self, n):
    """Return the integer numeral `n` as a Z3 term."""
    return IntVal(n)

  def real(self, n):
    """Return the real numeral `n` as a Z3 term."""
    return RealVal(n)

  def boolconst(self, n):
    """Return the Boolean constant named `n`."""
    return Bool(n)

  def intconst(self, n):
    """Return the integer constant named `n`."""
    return Int(n)

  def realconst(self, n):
    """Return the real constant named `n`."""
    return Real(n)

  def boolvar(self, n):
    """Return the Boolean variable named `n` (might be used for quantification)."""
    return Bool(n)

  def intvar(self, n):
    """Return the integer variable named `n` (might be used for quantification)."""
    return Int(n)

  def realvar(self, n):
    """Return the real variable named `n` (might be used for quantification)."""
    return Real(n)

  def land(self, l):
    """Return the conjunction of the formulas in list `l`.

    A one-element list yields its element unchanged.
    """
    if len(l) == 1:
      return l[0]
    return And(l)

  def lor(self, l):
    """Return the disjunction of the formulas in list `l`.

    A one-element list yields its element unchanged.
    """
    if len(l) == 1:
      return l[0]
    return Or(l)

  def neg(self, a):
    """Return the negation of `a`.

    Z3 true/false constants are flipped to the Python booleans returned by
    `true()`/`false()`, and a negated argument is unwrapped instead of being
    negated twice.
    """
    if self.is_false(a):
      return self.true()
    if self.is_true(a):
      return self.false()
    if self.is_not(a):
      return self.arg(a, 0)
    return Not(a)

  def implies(self, a, b):
    """Return the implication `a -> b`."""
    return Implies(a, b)

  def iff(self, a, b):
    """Return the bi-implication of `a` and `b`, encoded as `Not(Xor(a, b))`."""
    return Not(Xor(a, b))

  def eq(self, a, b):
    """Return the equality `a == b` of arithmetic terms."""
    return a == b

  def neq(self, a, b):
    """Return the disequality `a != b` of arithmetic terms."""
    return a != b

  def lt(self, a, b):
    """Return `a < b` on arithmetic terms."""
    return a < b

  def le(self, a, b):
    """Return `a <= b` on arithmetic terms."""
    return a <= b

  def ge(self, a, b):
    """Return `a >= b` on arithmetic terms."""
    return a >= b

  def gt(self, a, b):
    """Return `a > b` on arithmetic terms."""
    return a > b

  def inc(self, a):
    """Return the arithmetic term `a + 1`."""
    return a + 1

  def minus(self, a, b):
    """Return the difference `a - b`."""
    return a - b

  def plus(self, a, b):
    """Return the sum `a + b`."""
    return a + b

  def mult(self, a, b):
    """Return the product `a * b`."""
    return a * b

  def ite(self, cond, a, b):
    """Return the if-then-else term `If(cond, a, b)`."""
    return If(cond, a, b)

  # --- term inspection ---

  def num_args(self, e):
    """Return the number of arguments of term `e`."""
    return e.num_args()

  def arg(self, e, i):
    """Return the `i`-th argument of term `e`."""
    return e.arg(i)

  def is_true(self, e):
    """Return whether `e` is the Z3 constant true."""
    return is_true(e)

  def is_false(self, e):
    """Return whether `e` is the Z3 constant false."""
    return is_false(e)

  def is_int(self, e):
    """Return whether `e` is an integer numeral."""
    return is_int_value(e)

  def is_real(self, e):
    """Return whether `e` is a rational numeral."""
    return is_rational_value(e)

  def numeric_value(self, e):
    """Return the Python number denoted by numeral `e`.

    Integer numerals yield an `int`, rational numerals a `float`. Rationals
    are converted through `as_fraction()` because their string form may be a
    quotient such as "3/2", which `float()` cannot parse. Anything else falls
    back to parsing the string representation as a float.
    """
    if is_int_value(e):
      return e.as_long()
    if is_rational_value(e):
      return float(e.as_fraction())
    return float(str(e))

  def is_var(self, e):
    """Return the result of Z3's `is_const` test on `e`."""
    return is_const(e)

  def is_not(self, e):
    """Return whether `e` is a negation."""
    return is_not(e)

  def is_and(self, e):
    """Return whether `e` is a conjunction."""
    return is_and(e)

  def is_or(self, e):
    """Return whether `e` is a disjunction."""
    return is_or(e)

  def is_eq(self, e):
    """Return whether `e` is an equality."""
    return is_eq(e)

  def is_le(self, e):
    """Return whether `e` is a `<=` comparison."""
    return is_le(e)

  def is_lt(self, e):
    """Return whether `e` is a `<` comparison."""
    return is_lt(e)

  def is_ge(self, e):
    """Return whether `e` is a `>=` comparison."""
    return is_ge(e)

  def is_gt(self, e):
    """Return whether `e` is a `>` comparison."""
    return is_gt(e)

  def is_plus(self, e):
    """Return whether `e` is an addition."""
    return is_add(e)

  def is_minus(self, e):
    """Return whether `e` is a subtraction."""
    return is_sub(e)

  def is_mult(self, e):
    """Return whether `e` is a multiplication."""
    return is_mul(e)

  def is_forall(self, e):
    """Return whether `e` is a quantified formula.

    NOTE(review): this delegates to `is_quantifier`, which does not look at
    the quantifier kind; confirm whether callers need a forall-only test.
    """
    return is_quantifier(e)

  def exists(self, xs, e):
    """Return `e` existentially quantified over the variables `xs`."""
    return Exists(xs, e)

  def subst(self, vars, terms, e):
    """Return `e` with each element of `vars` replaced by the term at the same position in `terms`."""
    ps = list(zip(vars, terms))
    return substitute(e, ps)

  def simplify_more(self, e):
    """Simplify `e` with a tactic pipeline and return the result as a formula.

    The pipeline is simplify, nnf, propagate-ineqs, ctx-solver-simplify. The
    result is the disjunction, over all resulting subgoals, of the
    conjunction of each subgoal's formulas. Elapsed time is added to
    `_simp_time`.
    """
    start = time.perf_counter()
    t = Then('simplify', 'nnf', 'propagate-ineqs', 'ctx-solver-simplify')
    subgoals = t(e)
    self._simp_time += time.perf_counter() - start
    return Or([And(list(subgoal)) for subgoal in subgoals])

  def simplify(self, e):
    """Return Z3's `simplify(e)`; elapsed time is added to `_simp_time`."""
    start = time.perf_counter()
    ee = simplify(e)
    self._simp_time += time.perf_counter() - start
    return ee

  def qe_simp(self, e):
    """Eliminate quantifiers from `e`, simplify, and return the result.

    Applies the tactics 'qe' and 'simplify' in sequence. Several resulting
    subgoals are combined by disjunction; a single one is returned as is.
    Elapsed time is added to `_simp_time`.
    """
    start = time.perf_counter()
    t1 = Tactic('qe')
    t2 = Tactic('simplify')
    t = Then(t1, t2)
    subgoals = t(e)
    dsimps = [d.as_expr() for d in subgoals]
    self._simp_time += time.perf_counter() - start
    return Or(dsimps) if len(dsimps) > 1 else dsimps[0]

  def equivalent(self, a, b):
    """Return True if `a` and `b` are shown to be equivalent.

    Identical simplified forms count as equivalent without a solver call.
    Otherwise `Xor(a, b)` is checked in a temporary scope on top of the
    current assertions, and only an `unsat` answer counts as equivalence; an
    inconclusive answer is reported as not equivalent.
    """
    if self.simplify(a).eq(self.simplify(b)):
      return True
    self.push()
    try:
      self.require(Xor(a, b))
      return self.ctx.check() == z3.unsat
    finally:
      # Always restore the assertion stack, even if the check raises.
      self.pop()

  def push(self):
    """Open a new assertion scope."""
    self.ctx.push()

  def pop(self):
    """Discard the most recent assertion scope."""
    self.ctx.pop()

  def require(self, formulas):
    """Add the given assertion(s) to the wrapped Z3 object."""
    self.ctx.add(formulas)

  def check_sat(self, e, eval = None):
    """Check satisfiability of `e` together with the current assertions.

    `e` is asserted in a temporary scope that is removed again afterwards.
    Returns a `Z3Model` if the check result is sat, and None otherwise. The
    `eval` parameter is not used. Updates `_checks` and `_check_time`.
    """
    start = time.perf_counter()
    self.push()
    try:
      self.require(e)
      res = Z3Model(self.ctx) if self.ctx.check() == z3.sat else None
    finally:
      # Always restore the assertion stack, even if the check raises.
      self.pop()
    self._checks += 1
    self._check_time += time.perf_counter() - start
    return res

  def minimize(self, expr, max_val):
    """Register `expr` as a minimization objective and run a check.

    Calls `minimize` on the wrapped object, so this needs the `Optimize`
    context created in mode "opt". The duration of the check is stored in
    `t_solve`. `max_val` is not used. Returns a `Z3Model` if the result is
    sat, and None otherwise.
    """
    self.ctx.minimize(expr)
    t_start = time.perf_counter()
    result = self.ctx.check()
    self.t_solve = time.perf_counter() - t_start
    return Z3Model(self.ctx) if result == z3.sat else None

  def reset(self):
    """Replace the wrapped Z3 object by a fresh one and zero `t_solve`.

    The context is rebuilt rather than cleared (the Optimize solver does not
    have a reset function), using the mode given at construction so that the
    kind of solver does not change.
    """
    self.ctx = self._new_context(getattr(self, "_mode", "opt"))
    self.t_solve = 0

  def to_string(self, t):
    """Return the string representation of term `t`."""
    return str(t)


class Z3Model(Model):
  """Implementation of the project `Model` interface backed by a Z3 model.

  The model is fetched from the given solving object on construction, so the
  caller must have obtained a sat result from it beforehand.
  """

  def __init__(self, ctx):
    self.model = ctx.model()

  def eval_bool(self, v):
    """Return the model's evaluation of `v` (the Z3 term produced by `eval`)."""
    return self.model.eval(v)

  def eval_int(self, v):
    """Return the integer value of `v` in the model.

    Python integers are returned unchanged. A term that the model leaves
    uninterpreted (evaluation returns the term itself) yields 0. The term
    must be compared with its *evaluation*: comparing `v` with itself is
    always true and would make every term evaluate to 0.
    """
    if isinstance(v, int):
      return v
    value = self.model.eval(v)
    if is_int_value(value):
      return value.as_long()
    if value.eq(v):
      # The model has no interpretation for v.
      return 0
    # Partially evaluated term: let Z3 complete the model to get a numeral.
    return self.model.eval(v, model_completion=True).as_long()

  def eval_real(self, v):
    """Return the real value of `v` in the model.

    Python ints and floats are returned unchanged. A term that the model
    leaves uninterpreted (evaluation returns the term itself) yields 0;
    otherwise the numeral is converted to a float.
    """
    if isinstance(v, (int, float)):
      return v
    value = self.model.eval(v)
    if not (is_rational_value(value) or is_int_value(value)):
      if value.eq(v):
        # The model has no interpretation for v.
        return 0
      # Partially evaluated term: let Z3 complete the model to get a numeral.
      value = self.model.eval(v, model_completion=True)
    if is_int_value(value):
      return float(value.as_long())
    return float(value.as_fraction())

  def destroy(self):
    """Nothing to release."""
    pass
  