from smt.solver import *
from smt.cvc5solver import *

def is_lit(solver, lit):
  """Tell whether `lit` is a comparison atom, possibly under negations.

  Any number of leading `not`s is stripped, then the remaining node is tested
  with the solver's <=, >=, <, > and = recognizers. The value of the first
  recognizer that reports a match is returned; otherwise the result of the
  last one (`is_eq`).
  """
  node = lit
  while solver.is_not(node):
    node = solver.arg(node, 0)
  return (solver.is_le(node) or solver.is_ge(node) or solver.is_lt(node)
          or solver.is_gt(node) or solver.is_eq(node))


class AbstractionEquivalence:
  """Base class for equivalence checks between abstraction formulas.

  Subclasses implement `same`. This class also provides helpers that rewrite
  literals, literal conjunctions and DNFs into an equivalent, more readable
  form. All helpers go through the solver wrapper stored in `self._solver`.
  """

  def __init__(self, solver):
    # The subclasses in this module assign `_solver` themselves. Storing it
    # here too makes the simplification helpers usable on a plain instance.
    self._solver = solver

  def same(self, solver, e1, e2):
    """Placeholder for the equivalence test; does nothing and returns None.

    NOTE(review): subclasses override this as `same(self, e1, e2)`. The extra
    `solver` parameter looks vestigial but is kept for compatibility.
    """
    pass

  # some simplifications of literal conjunctions, mainly for nicer output
  def simp_lit(self, l):
    """Rewrite a literal (or a binary disjunction) into an equivalent,
    more readable form.

    Rewrites performed:
      * not (c <= d)        ->  c > d
      * not (c >= d)        ->  c < d
      * 0 > (c - d)         ->  d > c
      * c <= d              ->  c <= d, with terms  x + (-1 * y)  as  x - y
      * (c < d) or (c > d)  ->  c != d
    Other binary disjunctions are rebuilt from their simplified disjuncts.
    Anything else is returned unchanged.
    """
    s = self._solver

    def simp_term(t):
      # x + (-1 * y)  ->  x - y. Only binary sums are rewritten; for an
      # n-ary sum this rewrite would drop the remaining summands.
      if s.is_plus(t) and s.num_args(t) == 2:
        r = s.arg(t, 1)
        if s.is_mult(r) and str(s.arg(r, 0)) == "-1":
          return s.minus(s.arg(t, 0), s.arg(r, 1))
      return t

    if s.is_not(l):
      a = s.arg(l, 0)
      if s.is_le(a):  # not (c <= d)  to  (c > d)
        return s.gt(simp_term(s.arg(a, 0)), simp_term(s.arg(a, 1)))
      if s.is_ge(a):  # not (c >= d)  to  (c < d)
        return s.lt(simp_term(s.arg(a, 0)), simp_term(s.arg(a, 1)))
    # 0 > (c - d)  to  d > c
    elif s.is_gt(l) and s.is_minus(s.arg(l, 1)) and str(s.arg(l, 0)) == "0":
      return s.gt(s.arg(s.arg(l, 1), 1), s.arg(s.arg(l, 1), 0))
    elif s.is_le(l):
      return s.le(simp_term(s.arg(l, 0)), simp_term(s.arg(l, 1)))
    elif s.is_or(l) and s.num_args(l) == 2:
      a = self.simp_lit(s.arg(l, 0))
      b = self.simp_lit(s.arg(l, 1))
      if s.num_args(a) == 2 and s.num_args(b) == 2:
        compl_op = (s.is_lt(a) and s.is_gt(b)) or (s.is_gt(a) and s.is_lt(b))
        same_arg = s.are_equal_expr(s.arg(a, 0), s.arg(b, 0)) and \
          s.are_equal_expr(s.arg(a, 1), s.arg(b, 1))
        if compl_op and same_arg:  # (c < d) or (c > d)  to  c != d
          return s.neq(s.arg(a, 0), s.arg(a, 1))
      # Keep the simplified disjuncts even when they cannot be merged.
      return s.lor([a, b])

    return l

  # Z3 produces (c >= d && d >= c) instead of c == d, replace for readability
  def combine_ge_le(self, args):
    """Replace pairs of non-strict inequalities that state an equality.

    Two literals combine when they are  c <= d  and  c >= d  over the same
    arguments, or both <= (or both >=) over swapped arguments. Such a pair is
    removed and  c == d  is appended; this repeats until no pair is left.
    `args` itself is never mutated; if nothing combines it is returned as is.
    """
    s = self._solver

    def is_nonstrict(a):
      return s.is_le(a) or s.is_ge(a)

    def combineable_to_equality(a, b):
      compl_op = (s.is_le(a) and s.is_ge(b)) or (s.is_ge(a) and s.is_le(b))
      same_op = (s.is_le(a) and s.is_le(b)) or (s.is_ge(a) and s.is_ge(b))
      compl_arg = s.are_equal_expr(s.arg(a, 0), s.arg(b, 1)) and \
        s.are_equal_expr(s.arg(a, 1), s.arg(b, 0))
      same_arg = s.are_equal_expr(s.arg(a, 0), s.arg(b, 0)) and \
        s.are_equal_expr(s.arg(a, 1), s.arg(b, 1))
      return (compl_op and same_arg) or (same_op and compl_arg)

    for i, lhs in enumerate(args):
      if not is_nonstrict(lhs):
        continue
      # The relation is symmetric, and a literal must not pair with itself.
      for j in range(i + 1, len(args)):
        rhs = args[j]
        if is_nonstrict(rhs) and combineable_to_equality(lhs, rhs):
          rest = [x for k, x in enumerate(args) if k not in (i, j)]
          equality = s.eq(s.arg(lhs, 0), s.arg(lhs, 1))
          return self.combine_ge_le(rest + [equality])
    return args

  def simp_lit_conjunction(self, e):
    """Simplify a conjunction of literals.

    Each conjunct goes through `simp_lit`, then `combine_ge_le`. Returns the
    conjunction of the result, its only element, or `true` if it is empty.
    A non-conjunction is simplified as a single literal.
    """
    s = self._solver
    if not s.is_and(e):
      return self.simp_lit(e)
    args = [self.simp_lit(s.arg(e, i)) for i in range(s.num_args(e))]
    argss = self.combine_ge_le(args)
    if len(argss) > 1:
      return s.land(argss)
    return argss[0] if len(argss) == 1 else s.true()

  def simp_dnf(self, e):
    """Simplify a formula in disjunctive normal form.

    Each disjunct is simplified as a literal conjunction. A formula that is
    not a disjunction is treated as a single disjunct, so a bare conjunction
    is simplified too. An empty disjunction yields `false`.
    """
    s = self._solver
    if not s.is_or(e):
      return self.simp_lit_conjunction(e)
    args = [self.simp_lit_conjunction(s.arg(e, i)) for i in range(s.num_args(e))]
    if len(args) > 1:
      return s.lor(args)
    return args[0] if len(args) == 1 else s.false()


class Equivalence(AbstractionEquivalence):
  """Exact logical equivalence of two formulas, decided by the SMT solver."""

  def __init__(self, solver):
    self._solver = solver

  def same(self, e1, e2):
    """Return True iff `e1` and `e2` are logically equivalent.

    Syntactically equal formulas are accepted without a solver call.
    Otherwise the solver is asked for a model of (e1 & ~e2) | (e2 & ~e1).
    The formulas are equivalent exactly when no such model exists.
    """
    if e1 == e2:
      return True
    s = self._solver
    # Symmetric difference of e1 and e2, i.e. the negation of e1 <-> e2.
    not_eq_expr = s.lor([s.land([e1, s.neg(e2)]), s.land([e2, s.neg(e1)])])
    # check_sat yields None when there is no model.
    return s.check_sat(not_eq_expr) is None


class GCEquivalence(AbstractionEquivalence):
  """Equivalence of formulas after capping their numeric constants.

  Every numeric constant n is replaced by min(n, cutoff) if n > 0, and by
  max(n, -cutoff) otherwise. The capped formulas are then compared for
  logical equivalence with the solver.
  """

  def __init__(self, solver, cutoff):
    self._cutoff = cutoff
    self._solver = solver

  def same(self, e1, e2):
    """Return True iff `e1` and `e2` are equivalent after `cutoff`.

    Syntactically equal formulas are accepted without a solver call.
    """
    if e1 == e2:
      return True
    s = self._solver
    e1c = self.cutoff(e1)
    e2c = self.cutoff(e2)
    # Satisfiable exactly when the capped formulas differ somewhere.
    not_eq_expr = s.lor([s.land([e1c, s.neg(e2c)]), s.land([e2c, s.neg(e1c)])])
    return s.check_sat(not_eq_expr) is None

  def cutoff(self, e):
    """Return `e` with all numeric constants capped at +/- `self._cutoff`.

    Handles Boolean constants, negation, comparison literals, `forall`, and
    `and`/`or`. Disjunctions and conjunctions are rebuilt and simplified.

    Raises:
      AssertionError: if `e` has any other shape.
    """
    slv = self._solver
    if slv.is_true(e) or slv.is_false(e):
      return e

    def cutoff_term(t):
      # Cap numeric constants and rebuild -, + and * over capped operands;
      # variables and other operators are returned unchanged.
      if slv.is_numeric(t):
        n = slv.numeric_value(t)
        cutval = min(n, self._cutoff) if n > 0 else max(n, self._cutoff * (-1))
        return slv.num(cutval) if slv.is_int(t) else slv.real(cutval)
      elif slv.is_var(t):
        return t
      else:
        # Only binary compound terms are expected here.
        assert(slv.num_args(t) == 2)
        l = slv.arg(t, 0)
        r = slv.arg(t, 1)
        if slv.is_minus(t):
          return slv.minus(cutoff_term(l), cutoff_term(r))
        elif slv.is_plus(t):
          return slv.plus(cutoff_term(l), cutoff_term(r))
        elif slv.is_mult(t):
          return slv.mult(cutoff_term(l), cutoff_term(r))
      return t

    def cutoff_lit(lit):
      # Rebuild a comparison literal over capped operands.
      l = slv.arg(lit, 0)
      r = slv.arg(lit, 1)
      if slv.is_le(lit):
        return slv.le(cutoff_term(l), cutoff_term(r))
      elif slv.is_ge(lit):
        return slv.ge(cutoff_term(l), cutoff_term(r))
      elif slv.is_lt(lit):
        return slv.lt(cutoff_term(l), cutoff_term(r))
      elif slv.is_gt(lit):
        return slv.gt(cutoff_term(l), cutoff_term(r))
      elif slv.is_eq(lit):
        return slv.eq(cutoff_term(l), cutoff_term(r))
      return lit

    if slv.is_not(e):
      return slv.neg(self.cutoff(slv.arg(e, 0)))
    if is_lit(slv, e):
      return cutoff_lit(e)
    if slv.is_forall(e):
      # Presumably arg 0 holds the bound variables and arg 1 the body --
      # TODO confirm against the solver wrapper.
      return slv.forall(slv.arg(e, 0), self.cutoff(slv.arg(e, 1)))
    if slv.is_or(e) or slv.is_and(e):
      cutargs = [self.cutoff(slv.arg(e, i)) for i in range(slv.num_args(e))]
      combined = slv.lor(cutargs) if slv.is_or(e) else slv.land(cutargs)
      return slv.simplify(combined)

    # Raise explicitly: a bare `assert` is stripped under `python -O`, which
    # would silently return the formula without capping its constants.
    raise AssertionError("cutoff: unexpected formula shape: " + str(e))


class DecompositionEquivalence(AbstractionEquivalence):
  """Equivalence checked part-wise over a decomposition of the variables.

  `parts` and `equivs` are iterated in lockstep. Each part is a collection of
  variable names, paired with the equivalence checker used for it. Two
  formulas are the same iff, for every pair, their projections onto the
  part's variables are the same according to that checker.
  """

  def __init__(self, solver, parts, equivs):
    self._solver = solver
    self._equivs = equivs
    self._parts = parts

  def project(self, e, vs):
    """Project a conjunction onto the variables whose names are in `vs`.

    Keeps exactly the conjuncts that mention at least one variable `v` with
    `str(v) in vs`, and returns their simplified conjunction. `true` and
    `false` are returned unchanged, and a single literal is treated as a
    one-element conjunction.

    Raises:
      AssertionError: if `e` is neither a conjunction, a Boolean constant
        nor a literal.
    """
    slv = self._solver

    def term_has_vs(t):
      # True iff some variable occurring in `t` is named in `vs`.
      if slv.is_numeric(t):
        return False
      elif slv.is_var(t):
        return str(t) in vs
      else:
        return any(term_has_vs(slv.arg(t, i)) for i in range(slv.num_args(t)))

    if slv.is_true(e) or slv.is_false(e):
      return e
    if slv.is_and(e):
      conjuncts = [slv.arg(e, i) for i in range(slv.num_args(e))]
    elif is_lit(slv, e):
      # `simplify` may collapse a one-conjunct conjunction to the literal.
      conjuncts = [e]
    else:
      # Raise explicitly so the check also holds under `python -O`.
      raise AssertionError("project: expected a conjunction, got " + str(e))

    res = [c for c in conjuncts if term_has_vs(c)]
    return slv.simplify(slv.land(res))

  def same(self, e1, e2):
    """Return True iff every part's checker accepts the two projections.

    Syntactically equal formulas are accepted without projecting.
    """
    if e1 == e2:
      return True
    return all(eq.same(self.project(e1, vs), self.project(e2, vs))
               for (vs, eq) in zip(self._parts, self._equivs))
