from smt.z3solver import *
from smt.cvc5solver import *
import warnings

from dds.expr import Cmp, PropVar, Bool, BinCon, UnCon, ConfigMapAtom, top,Var,\
  has_propvars, nnf, mk_const, mk_var, rename_and_quantify
from verification.product_construction import ProductConstruction
from verification.abstraction_equivalence import Equivalence
from verification.configuration_mapping import ConfigMap
from ltl.automaton import Automaton
from ltl.property import LTLProperty

class Checker:
  """CTL* checker for a DDS.

  State formulas are evaluated to ConfigMap objects by check_state. Path
  formulas below an "E"/"A" quantifier are handled by check_path, which
  builds an automaton for the formula and runs a ProductConstruction from
  every DDS state.
  """

  def __init__(self, dds, equivalence, output_suffix, verbose):
    """Set up the checker.

    dds:           the system to check; must provide variables(), states(),
                   out(), initial_states() and init_val_constraints().
    equivalence:   passed on to ProductConstruction; its _solver attribute
                   is the solver used for all formulas built here.
    output_suffix: stored as _output_suffix (not read by the methods below).
    verbose:       if true, toLTL prints the result of nested E/A formulas.
    """
    self._dds = dds
    self._equivalence = equivalence
    solver = equivalence._solver
    self._solver = solver
    self._output_suffix = output_suffix
    # one solver constant per DDS variable, keyed by variable name ...
    self._smt_vars = {
      v["name"]: mk_const(solver, v) for v in self._dds.variables() }
    # ... and a second copy with suffix "0", keyed by name + "0"
    self._smt_vars_copy = {
      v["name"] + "0": mk_const(solver, v, suffix="0")
      for v in self._dds.variables() }
    self._verbose = verbose
    # just for statistics:
    self._pc_nodes = 0
    self._pc_trans = 0
    self._pc_time = 0

  def check_state(self, formula):
    """Evaluate a state formula and return the resulting ConfigMap.

    Atoms (PropVar, Cmp, Bool) are wrapped directly; "&&", "||" and "!"
    are mapped to land/lor/negate on the sub-results; anything else must be
    an "E" or "A" quantified path formula, which is passed to check_path.
    """
    if isinstance(formula, (PropVar, Cmp, Bool)):
      return ConfigMap(self._dds, self._solver, self._smt_vars, formula)
    if isinstance(formula, BinCon) and formula.op == "&&":
      return self.check_state(formula.left).land(self.check_state(formula.right))
    if isinstance(formula, BinCon) and formula.op == "||":
      return self.check_state(formula.left).lor(self.check_state(formula.right))
    if isinstance(formula, UnCon) and formula.op == "!":
      return self.check_state(formula.arg).negate()
    assert(isinstance(formula, UnCon) and formula.op in ["E", "A"])
    return self.check_path(formula.arg, formula.op)

  def rename_and_quantify(self, formulas):
    """Disjoin formulas, rename variables and existentially quantify.

    The variables in _smt_vars are substituted by primed variables, the
    "0" copies are substituted by the variables in _smt_vars, and the primed
    variables are then existentially quantified; the result is passed
    through the solver's qe_simp.

    NOTE(review): nothing in this class calls this method (check_path uses
    the function of the same name imported from dds.expr). It also calls
    self._dds.mk_var although mk_var is imported from dds.expr -- confirm
    that the DDS object really offers mk_var before relying on this.
    """
    vs = list(self._smt_vars.values())
    vs0 = list(self._smt_vars_copy.values())
    solver = self._solver
    vs_primed = [ self._dds.mk_var(solver, v, suffix = "'") for v in self._dds.variables() ]
    expr = solver.lor(formulas)
    if len(vs) > 0:
      expr = solver.subst(vs, vs_primed, expr) # V to U
      expr = solver.subst(vs0, vs, expr) # V0 to V
      expr = solver.exists(vs_primed, expr)
    return solver.qe_simp(expr)

  def is_F_Conf(self,ltl):
    """Return True iff ltl has the shape F(<ConfigMapAtom>)."""
    return isinstance(ltl, UnCon) and ltl.op == "F" and \
      isinstance(ltl.arg, ConfigMapAtom)

  def recycle_result(self, map, ltl, b):
    """Try to derive the result for state b from already computed results.

    map: dict from state to the formula computed so far (see check_path).
    Returns the disjunction of the successors' entries in map if
      - b has at least one outgoing transition,
      - every successor already has an entry in map and every outgoing
        guard compares equal to top,
      - ltl has the shape F(<ConfigMapAtom>), and
      - the atom's configuration map yields one and the same value for b
        and all its successors;
    otherwise returns None.
    """
    succs = [ (t["target"], t["guard"] == top) for t in self._dds.out(b) ]
    if not succs:
      return None
    if not all(x in map and gtop for (x, gtop) in succs):
      return None
    if not self.is_F_Conf(ltl):
      return None
    confmap = ltl.arg._map
    results = { confmap.get(bx) for (bx, _) in succs }
    results.add(confmap.get(b))
    if len(results) == 1: # conf map returns same result for all succs + b
      return self._solver.lor([ map[bx] for (bx, _) in succs ])
    return None

  def check_path(self, path_formula, op):
    """Evaluate the quantified path formula `op path_formula`.

    op is "E" or "A". For any op other than "E" the path formula is negated
    first and the resulting ConfigMap is negated again before it is
    returned. Updates the _pc_time/_pc_nodes/_pc_trans statistics for every
    product construction that is built.
    """
    # local import: the module itself only gets `time` via its star imports
    import time

    formula_mod = path_formula if op == "E" else path_formula.negate()
    ltl = self.toLTL(nnf(formula_mod))
    nfa = Automaton(LTLProperty(ltl), solver = self._solver)
    anames = [ t["name"] for t in self._dds._transitions ]
    nfa._has_action_vars = has_propvars(path_formula, anames)
    smt_var_pair = (self._smt_vars, self._smt_vars_copy)
    # To debug, visualize_nfa(nfa, <suffix>) / visualize_product(prod, <suffix>)
    # can be called here; no file-name suffix is computed by default, since
    # concatenating onto _output_suffix fails when that suffix is None.
    state_map = {}
    # States are visited in reverse order of dds.states(); presumably so that
    # successors tend to be done first and recycle_result can reuse them.
    for b in reversed(list(self._dds.states())):
      recycled = self.recycle_result(state_map, ltl, b)
      # identity test: solver terms may overload the != operator
      if recycled is not None:
        state_map[b] = recycled
        continue
      start = time.time()
      prod = ProductConstruction(self._dds, nfa, self._equivalence, start = b, \
        vars=smt_var_pair)
      self._pc_time += time.time() - start
      self._pc_nodes += len(prod._nodes)
      self._pc_trans += len(prod._transitions)
      final_ids = prod.final_state_ids()
      if len(final_ids) == 0:
        state_map[b] = self._solver.false()
        continue
      phi_fin = [ prod._nodes[n]._expr for n in final_ids ]
      vs = [ Var.from_array(v) for v in self._dds.variables() ]
      # this is the module-level rename_and_quantify imported from dds.expr,
      # not the method of the same name on this class
      state_map[b] = rename_and_quantify(self._solver, phi_fin, vs, \
        list(self._smt_vars.values()), list(self._smt_vars_copy.values()))
    conf_map = ConfigMap(self._dds, self._solver, self._smt_vars, state_map)
    return conf_map if op == "E" else conf_map.negate()

  def toLTL(self, formula):
    """Translate a path formula into an LTL formula over ConfigMapAtoms.

    Atoms become ConfigMapAtoms. "&&"/"||" of two atoms and "!" of an atom
    are folded into a single atom; other connectives are rebuilt around the
    translated arguments. Nested "E"/"A" formulas are evaluated via
    check_path, simplified, and embedded as an atom.
    """
    if isinstance(formula, (PropVar, Cmp, Bool)):
      cm = ConfigMap(self._dds, self._solver, self._smt_vars, formula)
      return ConfigMapAtom(cm)
    if isinstance(formula, BinCon):
      left = self.toLTL(formula.left)
      right = self.toLTL(formula.right)
      op = formula.op
      both_atoms = isinstance(left, ConfigMapAtom) and \
        isinstance(right, ConfigMapAtom)
      if both_atoms and op in ["&&", "||"]:
        return left.land(right) if op == "&&" else left.lor(right)
      return BinCon(left, op, right)
    if isinstance(formula, UnCon) and formula.op in ["X", "Y", "F", "G", "!"]:
      arg = self.toLTL(formula.arg)
      if formula.op == "!" and isinstance(arg, ConfigMapAtom):
        return arg.negate()
      return UnCon(formula.op, arg)
    assert(isinstance(formula, UnCon) and formula.op in ["E", "A"])
    confmap = self.check_path(formula.arg, formula.op)
    confmap.simplify()
    if self._verbose:
      print("result of " + str(formula) + ": " + str(confmap))
    return ConfigMapAtom(confmap)

  def initial_state_sat(self, confmap):
    """Return True iff confmap is satisfiable in the initial configuration.

    Conjoins the DDS's initial value constraints with the formula that
    confmap assigns to the first state returned by dds.initial_states(),
    and reports whether the solver's check_sat returns something other
    than None for it.
    """
    solver = self._solver
    b0 = self._dds.initial_states()[0]
    smt_vars = [ self._smt_vars[v["name"]] for v in self._dds.variables() ]
    phi0 = self._dds.init_val_constraints(solver, smt_vars)
    phi = solver.land(phi0 + [confmap.get(b0["id"])])
    # identity test so the result is a plain bool even if the object returned
    # by check_sat overloads comparison operators
    return solver.check_sat(phi) is not None


def check_CTLstar(dds, ctls_property, equivalence, suffix, verbose):
  """Check the CTL* property ctls_property on dds.

  Marks every DDS state as final, renders the DDS via visualize_dds,
  evaluates the property with a Checker and returns whether the resulting
  configuration map is satisfiable at the initial state. With verbose set,
  the result and some solver / product-construction statistics are printed.
  """
  for state in dds._states.values():
    state["final"] = True

  checker = Checker(dds, equivalence, suffix, verbose)
  visualize_dds(dds, suffix)

  prop = ctls_property._property
  confmap = checker.check_state(prop)
  confmap.simplify()
  res = checker.initial_state_sat(confmap)
  if not verbose:
    return res

  solver = equivalence._solver
  print("result of " + str(prop) + ": " + str(confmap))
  print("satisfied by initial assignment: " + str(res))

  print("\nSMT checks:   %d" % solver._checks)
  print("solving time: %.2f" % solver._check_time)
  print("simp time:    %.2f" % solver._simp_time)
  print("PC time:      %.2f" % checker._pc_time)
  dds_size = (len(dds._states), len(dds._transitions))
  print("|DDS|:        (%d, %d)" % dds_size)
  product_size = (checker._pc_nodes, checker._pc_trans)
  print("|products|:   (%d, %d)" % product_size)
  return res

def visualize_dds(dds, suffix):
  """Render dds via its show() method to out/ddsa[_<suffix>].png.

  The "_<suffix>" part is left out when suffix is empty or None. Warnings
  matching "graph is too large ... Scaling" are suppressed beforehand.
  """
  warnings.filterwarnings("ignore", message = ".*graph is too large.*Scaling.*")
  tag = "_" + suffix if suffix else ""
  target = "out" + "/ddsa" + tag + ".png"
  dds.show(target)

def visualize_product(pc, suffix):
  """Render the product construction pc via show() to out/pc[_<suffix>].png.

  The "_<suffix>" part is left out when suffix is empty or None. Warnings
  matching "graph is too large ... Scaling" are suppressed beforehand.
  """
  warnings.filterwarnings("ignore", message = ".*graph is too large.*Scaling.*")
  tag = "_" + suffix if suffix else ""
  target = "out" + "/pc" + tag + ".png"
  pc.show(target)

def visualize_nfa(nfa, suffix):
  """Render the automaton nfa via show() to out/nfa[_<suffix>].png.

  The "_<suffix>" part is left out when suffix is empty or None. Warnings
  matching "graph is too large ... Scaling" are suppressed beforehand.
  """
  warnings.filterwarnings("ignore", message = ".*graph is too large.*Scaling.*")
  tag = "_" + suffix if suffix else ""
  target = "out" + "/nfa" + tag + ".png"
  nfa.show(target)

def printable(s):
  """Return s with blanks, parentheses and operator symbols spelled out.

  Replacements are applied one after another in this order: " " -> "_",
  "(" and ")" removed, ">" -> "gt", "<" -> "lt", "=" -> "eq",
  "&&" -> "and", "||" -> "or".
  """
  return (
    s.replace(" ", "_")
     .replace("(", "")
     .replace(")", "")
     .replace(">", "gt")
     .replace("<", "lt")
     .replace("=", "eq")
     .replace("&&", "and")
     .replace("||", "or")
  )