from abc import ABCMeta, abstractmethod
from functools import reduce

from dds.util import VarType

class Expr:
  """Root of the expression/formula class hierarchy.

  Also owns `strmap`, a table shared by the whole process that assigns each
  distinct string literal a small integer code (see `numval` / `strval`).
  """
  # NOTE: `__metaclass__` is the Python 2 spelling. Under Python 3 it is an
  # ordinary class attribute, so @abstractmethod below is not enforced.
  __metaclass__ = ABCMeta

  # Shared, mutable map: string literal -> integer code, assigned in order of
  # first use by `numval`.
  strmap = {}

  @abstractmethod
  def toSMT(self, solver, subst):
    """Translate this expression into a solver term.

    `subst` maps variable keys (str of a Var) to solver terms.
    """
    pass

  def vars(self):
    """Return the set of Var objects occurring in the expression (none by default)."""
    return set()

  def basevars(self):
    """Return the set of base variable names occurring in the expression (none by default)."""
    return set()

  @staticmethod
  def numval(s):
    """Map a literal to a number.

    Anything `float()` accepts becomes a float. The strings "true"/"false"
    (any case) become 1/0. Any other string gets a stable integer code from
    `Expr.strmap`, allocated on first sight.

    Raises:
      TypeError: if `s` is neither convertible by float() nor a string.
    """
    try:
      return float(s)
    except (TypeError, ValueError):
      pass
    # Check the type before calling str methods: non-strings used to fail
    # with AttributeError or UnboundLocalError instead of a clear error.
    if not isinstance(s, str):
      raise TypeError("cannot convert %r to a numeric value" % (s,))
    lowered = s.lower()
    if lowered == "true" or lowered == "false":
      return 1 if lowered == "true" else 0
    if s not in Expr.strmap:
      Expr.strmap[s] = len(Expr.strmap)
    return Expr.strmap[s]

  @staticmethod
  def strval(numval):
    """Return the string literal encoded as `numval`, or "?" if none is registered."""
    strs = [s for (s, n) in Expr.strmap.items() if n == numval]
    return strs[0] if len(strs) > 0 else "?"

  def is_mc(self):
    """Default: the expression is not an "mc" constraint (subclasses may override)."""
    return False

class Term(Expr):
  """Common base class of the term and formula node classes.

  Provides a constructor that takes no arguments and stores nothing, plus a
  default `toSMT` that produces no solver term (returns None).
  """

  def __init__(self):
    # Nothing to initialise at this level.
    return None

  def toSMT(self, solver, subst):
    """Default translation: no SMT term (None)."""
    return None

class NotFound(Exception):
  """Raised when a variable has no binding in the substitution being evaluated."""

class Var(Term):
  """Typed data variable, optionally primed (`is_prime`) or marked `is_back`.

  `str(v)` is the name followed by "'" when primed, otherwise "-" when
  `is_back` is set. That string is the key used to look the variable up in
  a substitution (`toSMT`, `value`).
  """

  def __init__(self, c, vtype, prime=None):
    self.name = c
    # Passing any non-None `prime` marks the variable as primed.
    self.is_prime = prime is not None
    self.is_back = False
    self._type = vtype

  def __eq__(self, obj):
    return isinstance(obj, Var) and obj.name == self.name and \
      obj.is_prime == self.is_prime and obj.is_back == self.is_back

  def __hash__(self):
    return hash((self.__class__.__name__, self.name, self.is_prime,
      self.is_back))

  def __str__(self):
    suffix = "'" if self.is_prime else "-" if self.is_back else ""
    return self.name + suffix

  def accept(self, visitor):
    visitor.visit_var(self)

  def set_type(self, t):
    # A variable's type is fixed at construction; only check consistency.
    assert (t == self._type)

  def basename(self):
    return self.name

  def toSMT(self, solver, subst):
    """Return the solver term bound to str(self) in `subst`.

    Raises:
      KeyError: if `subst` has no entry for this variable.
    """
    key = str(self)
    # Raise explicitly: an `assert` disappears under `python -O`. The message
    # carries the diagnostics that were previously printed to stdout.
    if key not in subst:
      raise KeyError("%s not found in subst: %s" % (key, subst))
    return subst[key]

  def vars(self):
    return {self}

  def basevars(self):
    return {self.basename()}

  def value(self, subst):
    """Return the value bound to str(self) in `subst`.

    Raises:
      NotFound: if the variable is unbound.
    """
    key = str(self)
    if key not in subst:
      raise NotFound(key)
    return subst[key]

  @staticmethod
  def from_array(v):
    """Build an unprimed Var from a mapping with "name" and "type" entries."""
    return Var(v["name"], v["type"])

class PropVar(Term):
  """Propositional (boolean) atom, identified solely by its name."""

  def __init__(self, c):
    self.name, self._type = c, VarType.bool

  def __eq__(self, obj):
    if not isinstance(obj, PropVar):
      return False
    return obj.name == self.name

  def __hash__(self):
    return hash((type(self).__name__, self.name))

  def __str__(self):
    return self.name

  def basename(self):
    return self.name

  def accept(self, visitor):
    visitor.visit_propvar(self)

  def set_type(self, t):
    # Propositional atoms are always boolean.
    assert (t == VarType.bool)

  def negate(self):
    """Return the negative literal `! self`."""
    return UnCon("!", self)

  def toSMT(self, solver, subst):
    """Use the term bound to this atom in `subst`; otherwise create a boolean constant."""
    key = str(self)
    if key in subst:
      return subst[key]
    return solver.boolconst(key)


class Num(Term):
  """Numeric literal whose sort (int vs. other) is fixed later via `set_type`."""

  def __init__(self, c):
    self.num = c
    self._type = None  # unknown until set_type is called

  def __eq__(self, obj):
    return isinstance(obj, Num) and obj.num == self.num

  def __hash__(self):
    return hash((type(self).__name__, self.num))

  def __str__(self):
    return str(self.num)

  def accept(self, visitor):
    visitor.visit_num(self)

  def set_type(self, t):
    # The type may be set once, or re-set to the same value.
    assert self._type is None or t == self._type
    self._type = t

  def toSMT(self, slv, subst):
    """Emit an integer numeral for int-typed literals, otherwise a real numeral."""
    as_float = float(self.num)  # the literal may be stored as a string
    if self._type == VarType.int:
      return slv.num(int(as_float))
    return slv.real(as_float)

  def value(self, subst):
    # NOTE(review): returns the literal exactly as stored, which may be a
    # string (see toSMT); confirm callers never compare it with numbers.
    return self.num


class Charstr(Term):
  """String literal, represented numerically through `Expr.numval`.

  Printing, SMT translation and evaluation all encode the string with
  Expr.numval, which registers it in Expr.strmap on first use.
  """

  def __init__(self, c):
    self.chr = c
    self._type = VarType.int

  def __eq__(self, obj):
    return isinstance(obj, Charstr) and obj.chr == self.chr

  def __hash__(self):
    return hash((type(self).__name__, self.chr))

  def __str__(self):
    code = Expr.numval(self.chr)
    return str(code)

  def accept(self, visitor):
    visitor.visit_char(self)

  def set_type(self, t):
    # String literals are encoded as integers.
    assert (t == VarType.int)

  def toSMT(self, solver, subst):
    code = Expr.numval(self.chr)
    return solver.num(code)

  def value(self, subst):
    code = Expr.numval(self.chr)
    return code


class Bool(Term):
  """Boolean constant (the module also defines the shared instances `top` and `bot`)."""
  # Kept for backward compatibility. It is unused here (Expr.numval reads
  # Expr.strmap), but removing it would make Bool.strmap alias Expr.strmap.
  strmap = {}

  def __init__(self, c):
    """Accept a bool or a case-insensitive "true"/"false" string."""
    assert(isinstance(c,bool) or c.lower() in ["true", "false"])
    self.val = (c.lower() == "true") if isinstance(c, str) else c
    self._type = VarType.bool

  def __eq__(self, obj):
    return isinstance(obj, Bool) and obj.val == self.val

  def __hash__(self):
    return hash((self.__class__.__name__, self.val))

  def __str__(self):
    return 'true' if self.val else 'false'

  def accept(self, visitor):
    visitor.visit_bool(self)

  def negate(self):
    return bot if self == top else top

  def set_type(self, t):
    assert (t == VarType.bool)

  def toSMT(self, solver, subst):
    return solver.true() if self.val else solver.neg(solver.true())

  def value(self, subst):
    """Return 1 for true and 0 for false, matching Expr.numval("true"/"false")."""
    # Bool stores its constant in `val`; it has no `chr` attribute, so
    # decoding self.chr always raised AttributeError.
    return 1 if self.val else 0

  def is_mc(self):
    return True


class BinOp(Term):
  """Arithmetic term `left op right` with op in {"+", "-", "*"}.

  At construction, the first operand that already has a type (left first)
  fixes the type of the whole term, and `set_type` pushes it to both operands.
  """

  def __init__(self, a, op, b):
    self.op = op
    assert(op in ["+", "-", "*"])
    self.left = a
    self.right = b
    if self.left._type is not None:
      self.set_type(self.left._type)
    elif self.right._type is not None:
      self.set_type(self.right._type)
    else:
      self._type = None

  def __eq__(self, obj):
    return isinstance(obj, BinOp) and obj.op == self.op and \
      obj.left == self.left and obj.right == self.right

  def __hash__(self):
    return hash((self.__class__.__name__, self.op, hash(self.left),
      hash(self.right)))

  def __str__(self):
    return "(" + str(self.left) + " " + self.op + " " + \
      str(self.right) + ")"

  def accept(self, visitor):
    stop = visitor.visit_binop(self)
    if not (stop == visitor.STOP_RECURSION):
      self.left.accept(visitor)
      self.right.accept(visitor)

  def set_type(self, t):
    self._type = t
    self.right.set_type(t)
    self.left.set_type(t)

  def toSMT(self, solver, subst):
    op_funs = {
      "+" : lambda a, b: solver.plus(a, b),
      "-" : lambda a, b: solver.minus(a, b),
      "*" : lambda a, b: solver.mult(a, b)
    }
    (l, r) = (self.left.toSMT(solver, subst), self.right.toSMT(solver, subst))
    return op_funs[self.op](l, r)

  def vars(self):
    return self.left.vars().union(self.right.vars())

  def basevars(self):
    return self.left.basevars().union(self.right.basevars())

  def value(self, subst):
    """Evaluate the term under `subst`.

    Operand evaluation may propagate NotFound for unbound variables.
    """
    l = self.left.value(subst)
    r = self.right.value(subst)
    if self.op == "+":
      return l + r
    if self.op == "*":
      # "*" used to fall through to the subtraction branch.
      return l * r
    return l - r


class Cmp(Term):
  """Comparison `left op right` with op in ==, !=, <, <=, >, >=; boolean-typed.

  At construction, the type of whichever operand already has one (left
  first) is pushed onto the other operand.
  """

  def __init__(self, op, a, b):
    self.op = op
    assert(op in ["==", ">=", "<=", "<", ">", "!="])
    self.left, self.right = a, b
    self._type = VarType.bool
    if a._type is not None:
      b.set_type(a._type)
    elif b._type is not None:
      a.set_type(b._type)

  def __eq__(self, obj):
    if not isinstance(obj, Cmp):
      return False
    return obj.op == self.op and obj.left == self.left and \
      obj.right == self.right

  def __hash__(self):
    return hash((type(self).__name__, self.op, hash(self.left),
      hash(self.right)))

  def __str__(self):
    return "({} {} {})".format(self.left, self.op, self.right)

  def accept(self, visitor):
    if visitor.visit_cmp(self) == visitor.STOP_RECURSION:
      return
    self.left.accept(visitor)
    self.right.accept(visitor)

  def is_mc(self):
    """True iff both operands are plain variables or numeric literals."""
    simple = (Var, Num)
    return isinstance(self.left, simple) and isinstance(self.right, simple)

  def toSMT(self, solver, subst):
    """Translate to a solver formula.

    A boolean (in)equality against a Bool literal is folded into the other
    operand, negated as needed.
    """
    if self.op in ["==", "!="] and self.left._type == VarType.bool:
      literal, other = None, None
      if isinstance(self.left, Bool):
        literal, other = self.left, self.right
      elif isinstance(self.right, Bool):
        literal, other = self.right, self.left
      if literal is not None:
        folded = other.toSMT(solver, subst)
        if not literal.val:
          folded = solver.neg(folded)
        return solver.neg(folded) if self.op == "!=" else folded

    lhs = self.left.toSMT(solver, subst)
    rhs = self.right.toSMT(solver, subst)
    if self.op == "==":
      return solver.eq(lhs, rhs)
    if self.op == ">=":
      return solver.ge(lhs, rhs)
    if self.op == "<=":
      return solver.ge(rhs, lhs)
    if self.op == ">":
      return solver.lt(rhs, lhs)
    if self.op == "<":
      return solver.lt(lhs, rhs)
    if self.op == "!=":
      return solver.neg(solver.eq(lhs, rhs))
    raise KeyError(self.op)

  def negate(self):
    """Return the comparison with the complementary operator."""
    complement = {"<": ">=", ">": "<=", ">=": "<", "<=": ">",
                  "==": "!=", "!=": "=="}[self.op]
    return Cmp(complement, self.left, self.right)

  def vars(self):
    return self.left.vars().union(self.right.vars())

  def basevars(self):
    return self.left.basevars().union(self.right.basevars())

  def comparisons(self):
    return {self}

  def valid(self, subst):
    """Evaluate the comparison under `subst`; False if an operand is unbound."""
    try:
      lhs = self.left.value(subst)
      rhs = self.right.value(subst)
    except NotFound:
      return False
    op = self.op
    if op == "==":
      return lhs == rhs
    if op == ">=":
      return lhs >= rhs
    if op == "<=":
      return lhs <= rhs
    if op == ">":
      return lhs > rhs
    if op == "<":
      return lhs < rhs
    return lhs != rhs


class UnCon(Term):
  """Unary connective `op arg`.

  `op` is "!" (negation) or one of F, G, X, Xw, E, A. `negate` swaps each of
  these with its dual (F/G, X/Xw, E/A). Only "!" has an SMT translation.
  """

  def __init__(self, op, a):
    self.op = op
    self.arg = a
    # A connective is a formula, so the node is boolean-typed, as BinCon is.
    # Cmp/BinOp constructors read `_type` on their operands.
    self._type = VarType.bool

  def __eq__(self, obj):
    return isinstance(obj,UnCon) and obj.op == self.op and obj.arg == self.arg

  def __hash__(self):
    return hash((self.__class__.__name__, self.op, hash(self.arg)))

  def __str__(self):
    return self.op + " " + str(self.arg)

  def accept(self, visitor):
    stop = visitor.visit_uncon(self)
    if not (stop == visitor.STOP_RECURSION):
      self.arg.accept(visitor)

  def negate(self):
    """Return the negation, pushed through the operator via its dual."""
    if self.op == "!":
      return self.arg
    elif self.op == "F":
      return UnCon("G", self.arg.negate())
    elif self.op == "G":
      return UnCon("F", self.arg.negate())
    elif self.op == "E":
      return UnCon("A", self.arg.negate())
    elif self.op == "A":
      return UnCon("E", self.arg.negate())
    elif self.op == "X":
      return UnCon("Xw", self.arg.negate())
    else:
      assert(self.op == "Xw")
      return UnCon("X", self.arg.negate())

  def set_type(self, t):
    # Must take `self`: it is invoked as an instance method.
    self._type = t

  def toSMT(self, solver, subst):
    # Only negation has an SMT counterpart. Translating a temporal/path
    # operator as a negation would be silently wrong (BinCon.toSMT likewise
    # asserts its supported operators).
    assert(self.op == "!")
    return solver.neg(self.arg.toSMT(solver, subst))

  def vars(self):
    return self.arg.vars()

  def basevars(self):
    return self.arg.basevars()

  def comparisons(self):
    # UnCon has a single operand, `arg`; there is no left/right.
    return self.arg.comparisons()


class BinCon(Term):
  """Binary connective `left op right` with op "&&", "||" or "U" (until); boolean-typed."""

  def __init__(self, a, op, b):
    self.op = op
    self.left = a
    self.right = b
    self._type = VarType.bool

  def __eq__(self, obj):
    return isinstance(obj, BinCon) and obj.op == self.op and \
      obj.left == self.left and obj.right == self.right

  def __hash__(self):
    return hash((self.__class__.__name__, self.op, hash(self.left),
      hash(self.right)))

  def __str__(self):
    return "(" + str(self.left) + " " + self.op + " " + \
      str(self.right) + ")"

  def accept(self, visitor):
    stop = visitor.visit_bincon(self)
    if not (stop == visitor.STOP_RECURSION):
      self.left.accept(visitor)
      self.right.accept(visitor)

  def negate(self):
    """Return the negation, pushing it into both operands.

    The De Morgan results are simplified when an operand negates to top/bot.
    """
    left = self.left.negate()
    right = self.right.negate()
    if self.op == "&&": # but with negation is ||
      if left == bot:
        return right
      elif right == bot:
        return left
      elif left == top or right == top:
        return top
      return BinCon(left, "||", right)
    elif self.op == "||":
      if left == top:
        return right
      elif right == top:
        return left
      elif left == bot or right == bot:
        return bot
      return BinCon(left, "&&", right)
    else:
      assert(self.op == "U")
      # !(a U b) == (!b U (!a && !b)) || G !b.  The shorter (!b U !a) is not
      # equivalent: with !a && b at the first step, both a U b and (!b U !a)
      # hold.
      neg_until = BinCon(right, "U", BinCon(left, "&&", right))
      return BinCon(neg_until, "||", UnCon("G", right))

  def toSMT(self, solver, subst):
    assert(self.op == "&&" or self.op == "||")
    op_funs = {
      "&&" : lambda a, b: solver.land([a, b]),
      "||" : lambda a, b: solver.lor([a, b]),
    }
    (l, r) = (self.left.toSMT(solver, subst), self.right.toSMT(solver, subst))
    return op_funs[self.op](l, r)

  def vars(self):
    return self.left.vars().union(self.right.vars())

  def basevars(self):
    return self.left.basevars().union(self.right.basevars())

  def comparisons(self):
    return self.left.comparisons().union(self.right.comparisons())

  def valid(self, subst):
    """Evaluate under `subst`: "&&" as conjunction, every other op as disjunction."""
    try:
      if self.op == "&&":
        return self.left.valid(subst) and self.right.valid(subst)
      else:
        return self.left.valid(subst) or self.right.valid(subst)
    except NotFound: # variable not found
      return False

  def is_mc(self):
    if self.op == "&&":
      return self.left.is_mc() and self.right.is_mc()
    return False

# Additional base case of LTL formulas, used for CTL*.
class ConfigMapAtom(Term):
  """Boolean atom wrapping a configuration map.

  Negation, disjunction and conjunction are delegated to the wrapped map's
  own negate/lor/land methods, and each result is wrapped again. toSMT is
  not overridden, so the inherited Term.toSMT applies.
  """

  def __init__(self, m):
    self._map, self._type = m, VarType.bool

  def __str__(self):
    return str(self._map)

  def __eq__(self, obj):
    return isinstance(obj, ConfigMapAtom) and obj._map == self._map

  def __hash__(self):
    return hash((type(self).__name__, self._map))

  def accept(self, visitor):
    visitor.visit_configmap(self)

  def set_type(self, t):
    assert (t == VarType.bool)

  def negate(self):
    flipped = self._map.negate()
    return ConfigMapAtom(flipped)

  def lor(self, other):
    assert(isinstance(other, ConfigMapAtom))
    joined = self._map.lor(other._map)
    return ConfigMapAtom(joined)

  def land(self, other):
    assert(isinstance(other, ConfigMapAtom))
    met = self._map.land(other._map)
    return ConfigMapAtom(met)



# Shared boolean constants; Bool.negate and the connectives compare against them.
top, bot = Bool(True), Bool(False)


# visitor
class Visitor:
  """Base class for visitors over expression trees.

  Each node's accept() calls the matching visit_* hook. The composite nodes
  (BinOp, Cmp, UnCon, BinCon) skip their children when the hook returns
  STOP_RECURSION. Every hook defaults to doing nothing and returning None.

  NOTE: this class is not built with ABCMeta, so @abstractmethod is not
  enforced; subclasses may override only the hooks they need.
  """
  STOP_RECURSION = True

  @abstractmethod
  def visit_var(self, element):
    pass

  @abstractmethod
  def visit_num(self, element):
    pass

  @abstractmethod
  def visit_propvar(self, element):
    pass

  @abstractmethod
  def visit_bool(self, element):
    pass

  @abstractmethod
  def visit_char(self, element):
    pass

  # NOTE(review): no accept() in this module calls visit_unop; confirm it is still needed.
  @abstractmethod
  def visit_unop(self, element):
    pass

  @abstractmethod
  def visit_binop(self, element):
    pass

  @abstractmethod
  def visit_cmp(self, element):
    pass

  @abstractmethod
  def visit_bincon(self, element):
    pass

  @abstractmethod
  def visit_uncon(self, element):
    pass

  @abstractmethod
  def visit_configmap(self, element):
    """Hook for ConfigMapAtom. Without this default, subclasses lacking it crash on such atoms."""
    pass
  

class VarFlipper(Visitor):
  """Visitor that toggles, in place, the primed flag of every Var it reaches.

  Var's hash includes `is_prime`, so a Var already stored in a set or dict
  key hashes differently after being flipped.
  """

  def __init__(self):
    # Stateless visitor; nothing to set up.
    pass

  def visit_var(self, v):
    v.is_prime = False if v.is_prime else True


def has_propvars(formula, ps):
  """Return True if a PropVar whose name is in `ps` occurs in `formula`.

  Only UnCon and BinCon nodes are searched into; any other node type counts
  as not containing such a PropVar.
  """
  if isinstance(formula, PropVar):
    return formula.name in ps
  if isinstance(formula, UnCon):
    return has_propvars(formula.arg, ps)
  if isinstance(formula, BinCon):
    return any(has_propvars(side, ps) for side in (formula.left, formula.right))
  return False

def nnf(f):
  """Return `f` in negation normal form (negation only directly on atoms).

  Atoms (PropVar, Cmp, Bool, ConfigMapAtom) are returned unchanged. BinCon
  nodes and UnCon nodes other than "!" are rebuilt with normalised operands.
  A negation is pushed inward using the operands' `negate` methods. Inputs
  of any other type yield None.
  """
  atoms = (PropVar, Cmp, Bool, ConfigMapAtom)
  if isinstance(f, atoms):
    return f
  if isinstance(f, BinCon):
    return BinCon(nnf(f.left), f.op, nnf(f.right))
  if isinstance(f, UnCon):
    if f.op != "!":
      return UnCon(f.op, nnf(f.arg))
    if isinstance(f.arg, atoms):
      # Negated atom: Cmp/Bool/ConfigMapAtom absorb the negation, and a
      # PropVar becomes the literal UnCon("!", p). Handled here so that the
      # recursion below never loops on such literals.
      return f.arg.negate()
    # Push the negation down one level, then normalise the result.
    # negate() of a "!" node simply strips it, which can expose another
    # non-literal negation (e.g. ! ! ! (p && q)).
    return nnf(f.arg.negate())
  return None

def mk_var(solver, v, suffix = ""):
  """Create a solver variable for `v` named str(v) + suffix.

  `v` is a Var or a mapping accepted by Var.from_array. The solver method is
  chosen by the variable's type: boolvar for bool, intvar for int, realvar
  for anything else.
  """
  var = v if isinstance(v, Var) else Var.from_array(v)
  ident = str(var) + suffix
  if var._type == VarType.bool:
    make = solver.boolvar
  elif var._type == VarType.int:
    make = solver.intvar
  else:
    make = solver.realvar
  return make(ident)

# in CVC5, stuff that gets evaluated must be a const instead of a var
def mk_const(solver, v, suffix = ""):
  """Create a solver constant for `v` named str(v) + suffix.

  `v` is a Var or a mapping accepted by Var.from_array. The solver method is
  chosen by the variable's type: boolconst for bool, intconst for int,
  realconst for anything else.
  """
  var = v if isinstance(v, Var) else Var.from_array(v)
  ident = str(var) + suffix
  if var._type == VarType.bool:
    make = solver.boolconst
  elif var._type == VarType.int:
    make = solver.intconst
  else:
    make = solver.realconst
  return make(ident)

def mk_conj(cs):
  """Return the left-nested conjunction of the formulas in `cs`.

  An empty `cs` yields `top`; a single formula is returned unchanged.
  """
  # Fold from the first formula rather than from a `top` seed. Formula
  # objects define no __bool__, so an `if c` guard never skips the seed.
  # Every result would then start with `true && ...`, and Bool has no
  # valid()/comparisons() for BinCon to call.
  items = list(cs)
  if not items:
    return top
  return reduce(lambda acc, c: BinCon(acc, "&&", c), items)

def mk_disj(cs):
  """Return the left-nested disjunction of the formulas in `cs`.

  An empty `cs` yields `bot`; a single formula is returned unchanged.
  """
  # Fold from the first formula rather than from a `bot` seed. Formula
  # objects define no __bool__, so an `if c` guard never skips the seed.
  # Every result would then start with `false || ...`, and Bool has no
  # valid()/comparisons() for BinCon to call.
  items = list(cs)
  if not items:
    return bot
  return reduce(lambda acc, c: BinCon(acc, "||", c), items)

def rename_and_quantify(solver, formulas, vs, smtvs, smtvs0):
  """Disjoin `formulas`, shift them over the variables, and simplify.

  `vs` are variable descriptions (Var objects or mappings accepted by
  Var.from_array); a primed solver variable is created for each one with
  mk_var. `smtvs` and `smtvs0` are solver variables, presumably ordered like
  `vs` (TODO confirm against callers).

  When `smtvs` is non-empty, the disjunction is rewritten with solver.subst,
  following the original comments: first V -> V' (primed copies), then
  V0 -> V. The primed copies are then existentially quantified. The result
  is always passed through solver.qe_simp.
  """
  primed = [mk_var(solver, v, suffix = "'") for v in vs]
  disjunction = solver.lor(formulas)
  if len(smtvs) == 0:
    return solver.qe_simp(disjunction)
  shifted = solver.subst(smtvs, primed, disjunction)  # V  -> V'
  shifted = solver.subst(smtvs0, smtvs, shifted)      # V0 -> V
  return solver.qe_simp(solver.exists(primed, shifted))