from smt.solver import *
import networkx as nx
from networkx.drawing.nx_agraph import to_agraph
import json
import sys
import time

from ltl.automaton import symsstr 
from dds.util import VarType, expr_replace
from dds.dds import val_str
import dds.expr as expr
from dds.expr import Bool, Expr, Num, Var, Charstr, Cmp, BinOp, BinCon, top

# various helper functions
def eval_var(m, x, val):
  """Evaluate SMT term `val` in model `m` according to the type of variable `x`.

  `x` is a variable dict with a "type" entry: ints are read with
  `m.eval_int`, bools with `m.eval_bool`, anything else with `m.eval_real`.
  """
  kind = x["type"]
  if kind == VarType.int:
    evaluator = m.eval_int
  elif kind == VarType.bool:
    evaluator = m.eval_bool
  else:
    evaluator = m.eval_real
  return evaluator(val)

class Graph:
  """Base class for directed graphs over integer-like node ids.

  The graph state lives in three attributes: `_nodes` (id -> node object),
  `_transitions` (edge objects with `_source` and `_target` node ids) and
  `_init_node_ids` (list of initial node ids). Subclasses may set them
  directly instead of calling `__init__`.
  """

  def __init__(self):
    self.init()

  def init(self):
    """Reset the graph to have no nodes, no edges and no initial nodes."""
    self._init_node_ids = []
    self._nodes = {}
    self._transitions = []

  def initial_node_ids(self):
    """Return the list of initial node ids."""
    return self._init_node_ids

  def out(self, sid):
    """Return the edges whose source is node `sid`."""
    return [t for t in self._transitions if t._source == sid]

  def ins(self, tid):
    """Return the edges whose target is node `tid`."""
    return [t for t in self._transitions if t._target == tid]

  def get_path_to_node_pred(self, pred):
    """Search depth-first from the initial nodes for a node satisfying `pred`.

    `pred` is applied to node objects (values of `_nodes`). Returns a pair
    (start id, path), where path is a list of (edge, target id) pairs leading
    from the start node to the found node (empty if the start node itself
    satisfies `pred`), or None if no such node is found. Paths are simple and
    never pass through an initial node other than their start.
    """
    init_ids = frozenset(self.initial_node_ids())

    def dfs(n, path, seen):
      if pred(self._nodes[n]):
        return path
      for t in self.out(n):
        b_next = t._target
        if b_next in seen:
          continue
        p = dfs(b_next, path + [(t, b_next)], seen | {b_next})
        if p is not None:
          return p
      return None

    for n0 in self.initial_node_ids():
      p = dfs(n0, [], init_ids)
      if p is not None:
        return (n0, p)
    return None

  def get_path_to_node(self, n):
    """Like get_path_to_node_pred, searching for the node object `n` itself."""
    return self.get_path_to_node_pred(lambda n2: n2 is n)

  def reachable(self, start_ids):
    """Return the set of node ids reachable from any id in `start_ids`,
    including the start ids themselves."""
    seen = set()
    # explicit stack: no recursion limit, and each node is expanded once
    stack = list(start_ids)
    while stack:
      n = stack.pop()
      if n in seen:
        continue
      seen.add(n)
      stack.extend(t._target for t in self.out(n) if t._target not in seen)
    return seen

# class to represent node in CG
class Node:
  """Node of a constraint graph: a DDS state paired with an SMT formula.

  Two nodes are equal when they have the same DDS state and their formulas
  are related by the shared equivalence relation (`equiv.same`).
  """

  def __init__(self, stateid, statename, smtexpr, equiv):
    self._state = stateid
    self._statename = statename
    self._expr = smtexpr
    self._equivalence = equiv
    self._expr_string = self.expr_str()
    # Hash only the state: equal nodes may carry different but equivalent
    # formulas, so including the formula would violate the hash/eq contract.
    self._hash = hash(self._state)

  def __eq__(self, other):
    if not isinstance(other, Node):
      return NotImplemented
    return self._state == other._state and \
      self._equivalence.same(self._expr, other._expr)

  def __hash__(self):
    return self._hash

  def expr_str(self):
    """Return a readable string for the formula (literals simplified,
    top-level "And(" shortened to "(", then passed through expr_replace)."""
    expr1 = self._equivalence.simp_lit_conjunction(self._expr)
    s = self._equivalence._solver.to_string(expr1).replace("And(", "(")
    return expr_replace(s)

  def __str__(self):
    return self._statename + "|" + self._expr_string

  def to_dict(self):
    """Return a JSON-friendly dict with the state id and formula string."""
    return { "state": self._state, "expr": self._expr_string }


class Edge:
  """Edge of a constraint graph between two node ids.

  The label maps action names to symbol sets; an edge starts with the single
  action `name`.
  """

  def __init__(self, src, tgt, name, syms = None, constraint_free = False):
    self._source = src # just id
    self._target = tgt # just id
    self._name = name
    # fresh dict per edge: a shared mutable default would alias labels
    self._label = {name: {} if syms is None else syms}
    self._constraint_free = constraint_free # guard is True and no vars written

  def add_to_label(self, lab):
    """Add the (name, syms) pair `lab` to the label, merging symbol sets for
    an action name that is already present."""
    (name, syms) = lab
    if name in self._label:
      # NOTE(review): extend_label is not imported explicitly in this module;
      # presumably it comes via `from smt.solver import *` — confirm.
      self._label[name] = extend_label(syms, self._label[name])
    else:
      self._label[name] = syms

  def label_str(self):
    """Return the label as space-separated parts: `name` for actions without
    symbols, `name:symbols` otherwise."""
    parts = [ (n + ":" + symsstr(syms)) if len(syms) > 0 else n \
      for (n, syms) in self._label.items() ]
    return " ".join(parts)

  def __str__(self):
    return str(self._source) + " - " + self.label_str() + " - " + str(self._target)

  def to_dict(self):
    """Return a JSON-friendly dict with source, target and label entries."""
    s = [ {"action": n, "symbol": symsstr(syms)} for (n, syms) in self._label.items() ]
    return {"source": self._source, "target": self._target, "label": s}

# DDS state id -> ConstraintGraph built from that state; consulted by
# ConstraintGraph.compute to reuse previously computed graphs.
cgx_cache = dict()

# class to represent CG
class ConstraintGraph(Graph):
  """Constraint graph (CG) of a DDS.

  Each node pairs a DDS control state with an SMT formula over the DDS
  variables; each edge is a DDS transition whose successor formula is
  satisfiable from the source node's formula. Nodes are merged modulo the
  equivalence relation passed to the constructor. The graph is computed
  eagerly by the constructor; the class also provides the data-aware
  soundness checks.
  """

  def __init__(self, dds, equiv, start_state = None, cg0 = None, \
      vars = ([], []), init_constr = None):
    """Build the constraint graph of `dds` modulo `equiv`.

    dds: the DDS to explore.
    equiv: node equivalence; its `_solver` is used for all SMT operations.
    start_state: optional DDS state id to start from instead of the DDS
      initial states; the initial formula then equates every variable with
      the corresponding variable in the copy dict (see `vars`).
    cg0: not used by this class.
    vars: pair (smt_vars, smt_vars_copy) of name -> SMT variable dicts to
      reuse; if left at the default they are created from the DDS variables
      (copies get the suffix "0").
    init_constr: optional constraint conjoined to the initial formula.
    """
    self._dds = dds
    self._solver = equiv._solver
    # equivalence relation used to merge nodes of the graph
    self._equivalence = equiv
    # if start_state set, take as initial and construct wrt renaming
    self._init_dds_state = start_state
    self.init_constr = init_constr # only used with start state in Paolo mode
    self._init_node_ids = [] # set in compute()
    self._nodes = {} # node id -> Node, set in compute()
    self._transitions = [] # list of Edge, set in compute()
    if vars != ([],[]):
      self._smt_vars = vars[0]
      self._smt_vars_copy = vars[1]
    else:
      self._smt_vars = { v["name"]: self.mk_const(v) \
        for v in self._dds.variables() }
      self._smt_vars_copy = { v["name"] + "0": self.mk_var(v, suffix="0") \
        for v in self._dds.variables() }

    self.compute()

  def mk_var(self, v, suffix = ""):
    """Create the SMT variable for DDS variable dict `v`, name suffixed."""
    return expr.mk_var(self._solver, Var.from_array(v), suffix)

  def mk_vars_stage(self, i):
    """Return name -> SMT constant for every DDS variable, suffixed by `i`."""
    vars = self._dds.variables()
    return dict([ (v["name"], self.mk_const(v, suffix=str(i))) for v in vars ])

  def mk_const(self, v, suffix = ""):
    """Create the SMT constant for DDS variable dict `v`, name suffixed."""
    return expr.mk_const(self._solver, Var.from_array(v), suffix)

  def expand(self, i, node):
    """Expand the CG depth-first from `node`, which has id `i`.

    For every outgoing DDS transition of the node's state, the successor
    formula is computed with `update` and `qe_simp`; unsatisfiable successors
    are dropped. A successor equivalent to an existing node only adds an
    edge; otherwise it gets the next free id and is expanded before the
    remaining transitions of its parent. An explicit stack replaces
    recursion so long chains of new nodes cannot exceed Python's recursion
    limit; node ids and edge order are those of a recursive expansion.
    """
    solver = self._solver
    sys.stdout.flush()
    # frames: (node id, node, iterator over its not yet processed transitions)
    stack = [(i, node, iter(self._dds.out(node._state)))]
    while stack:
      (src_id, src_node, pending) = stack[-1]
      t = next(pending, None)
      if t is None: # all transitions of this node processed
        stack.pop()
        continue
      next_state = t["target"] # DDS state id
      next_expr = solver.qe_simp(self.update(src_node._expr, t))

      if solver.check_sat(next_expr) is None:
        continue # not satisfiable, transition is not reachable

      next_expr = solver.simplify(next_expr)
      next_name = self._dds._states[next_state]["name"]
      next_node = Node(next_state, next_name, next_expr, self._equivalence)
      cf = t["guard"] == Bool(True) and len(t["written"]) == 0
      existing = next((j for (j, n) in self._nodes.items() if next_node == n), None)
      if existing is not None:
        self._transitions.append(Edge(src_id, existing, t["name"], constraint_free = cf))
      else: # no equivalent node exists: add it and expand it next
        k = len(self._nodes)
        self._nodes[k] = next_node
        self._transitions.append(Edge(src_id, k, t["name"], constraint_free = cf))
        sys.stdout.flush()
        stack.append((k, next_node, iter(self._dds.out(next_node._state))))

  def compute(self):
    """Compute nodes, edges and initial node ids of the constraint graph."""
    solver = self._solver
    dds = self._dds
    use_v0 = self._init_dds_state is not None # whether we start from intermediate

    # initial expression: each variable equals its initial value, or its copy
    # variable when starting from an intermediate state
    rhss = self._smt_vars_copy.values() if use_v0 else dds.init_values(solver)
    varpairs = zip(self._smt_vars.values(), rhss)
    expr0 = solver.land([solver.eq(x, x0) for (x, x0) in varpairs])
    if self.init_constr is not None:
      expr0 = solver.simplify(solver.land([expr0, self.init_constr]))

    # create initial nodes
    if use_v0:
      init_id = self._init_dds_state
      inits = [(init_id, dds._states[init_id]["name"])]
    else: # normal case: start from initial state of DDS
      inits = [ (s["id"], s["name"]) for s in dds.initial_states()]
    init_nodes = [ Node(sid, n, expr0, self._equivalence) for (sid, n) in inits]

    # TESTHACK: copy/expand another CG found in the module-level cgx_cache
    if use_v0 and len(init_nodes) == 1:
      init_node = init_nodes[0]
      # case 1: the initial state has a single successor with guard True in the
      #         DDS, and for this guy we already have a CG: copy + add edge
      nexts = dds.out(init_node._state)
      if len(nexts) == 1:
        t = nexts[0]
        if t["guard"] == top and t["written"] == [] and t["target"] in cgx_cache:
          cg = cgx_cache[t["target"]]
          self._transitions = list(cg._transitions)
          self._nodes = dict(cg._nodes)
          init_id = len(self._nodes)
          self._nodes[init_id] = init_node
          self._init_node_ids.append(init_id)
          self._transitions.append(Edge(init_id, cg._init_node_ids[0], \
            t["name"], constraint_free = True))
          return
      # case 2: the initial state has a predecessor with guard True in the DDS,
      #         and for this guy we already have a CG: copy + restrict
      dds_id = init_node._state
      for trans in dds.ins(dds_id):
        if trans["guard"] == top and trans["written"] == [] and \
            trans["source"] in cgx_cache:
          cg = cgx_cache[trans["source"]]
          cginit = cg._init_node_ids[0]
          self._transitions = [e for e in cg._transitions if e._source != cginit]
          # the cached CG's edge for `trans` leading into this DDS state
          cg_start_trans = [e for e in cg._transitions if e._source == cginit \
            and e._name == trans["name"] and dds_id == cg._nodes[e._target]._state]
          assert(len(cg_start_trans) == 1)
          init_node_id = cg_start_trans[0]._target
          reach = cg.reachable([init_node_id])
          self._nodes = dict([n for n in cg._nodes.items() if n[0] in reach])
          self._init_node_ids = [init_node_id]
          return

    # initialize node and transition sets
    for init_node in init_nodes:
      nid = next((j for (j, n) in self._nodes.items() if n == init_node), None)
      if nid is None: # no equivalent node yet
        nid = len(self._nodes)
        self._nodes[nid] = init_node
      self._init_node_ids.append(nid)
      self.expand(nid, init_node)

  def update(self, smt_expr, t):
    """Return the formula describing states reached from `smt_expr` via `t`.

    Without written variables this is `smt_expr /\\ guard`. Otherwise the
    written variables in `smt_expr` are renamed to "*"-suffixed auxiliaries
    holding the old values, the guard is instantiated (with lookback, x- is
    the old and x the new value; without, x is the old and x' the new
    value), and the auxiliaries are existentially quantified unless the
    conjunction is trivially true.
    """
    vs = self._dds.variables()
    solver = self._solver
    smtvs = self._smt_vars
    if len(t["written"]) == 0: # shortcut
      guardx = t["guard"].toSMT(solver, smtvs)
      return solver.land([smt_expr, guardx])

    vs_written = [ x for (v, x) in smtvs.items() if v in t["written"] ]
    vs_bef = dict([ (v["name"], self.mk_var(v, suffix="*")) \
      for v in vs if v["name"] in t["written"] ]) # auxiliary
    smt_expr_subst = solver.subst(vs_written, vs_bef.values(), smt_expr)
    subst_keep = [ xv for xv in smtvs.items() if xv[0] not in t["written"] ]
    if self._dds.has_lookback():
      # an unwritten variable keeps its value, so its "before" copy is itself
      subst_bef = [ (x + "-", vs_bef.get(x, v)) for (x, v) in smtvs.items() ]
      guard_subst = dict(subst_bef + subst_keep + list(smtvs.items()))
    else:
      subst_after = [ (x + "'", v) for (x, v) in smtvs.items() ]
      guard_subst = dict(list(vs_bef.items()) + subst_keep + subst_after)
    guardx = t["guard"].toSMT(solver, guard_subst)
    exprx = solver.land([smt_expr_subst, guardx])
    if not solver.is_true(exprx):
      exprx = solver.exists(list(vs_bef.values()), exprx)
    return exprx

  def val_constraints(self, val, vars):
    """Return constraints `initial value == variable` for each dict in `vars`.

    val: not used (kept for interface compatibility).
    vars: variable dicts with "name", "type" and "initial" entries.
    """
    def initial_value(v):
      v0 = v["initial"]
      return Bool(v0) if v["type"] == VarType.bool else Num(v0)
    return [Cmp("==", initial_value(v), Var(v["name"], v["type"])) for v in vars]

  def compute_run_to_target_valuation(self, b0, path, val, show=True):
    """Find a concrete DDS run along `path` whose last valuation is `val`.

    b0: id of the CG node the path starts from.
    path: list of (Edge, target node id) pairs, as from get_path_to_node.
    val: (variable name, value) pairs the last stage must match.
    show: if set and a model exists, print the run via the DDS's show_run.
    Prints an error and the formula if no such run exists.
    """
    vars = self._dds.variables()
    solver = self._solver
    n = len(path) + 1 # number of stages
    vars_stages = [ self.mk_vars_stage(i) for i in range(0, n)]

    # stage 0 satisfies the DDS initial value constraints
    vars0 = [vars_stages[0][x["name"]] for x in vars]
    phi = solver.land(self._dds.init_val_constraints(solver, vars0))
    # per step: action name -> encoding of that action from stage i-1 to i
    action_encodings = [{} for _ in range(1, n)]

    for i in range(1, n):
      (e, _) = path[i-1]
      subst_prime = [ (x + "'", v) for (x, v) in vars_stages[i].items() ]
      subst_guard = dict(list(vars_stages[i-1].items()) + subst_prime)

      edge = []
      for action in e._label:
        t = next(t for t in self._dds._transitions if t["name"] == action)
        guard = t["guard"].toSMT(solver, subst_guard) # assumption: no lookback
        # unwritten variables keep their value between the two stages
        inertia = [solver.eq(vars_stages[i-1][v], vars_stages[i][v]) \
          for v in vars_stages[0] if v not in t["written"] ]
        edge_enc = solver.land([guard] + inertia)
        action_encodings[i-1][action] = edge_enc
        edge.append(edge_enc)

      # some action of the edge is taken in step i
      phi = solver.land([phi, solver.lor(edge)])

    def smtval(v, a):
      if v["type"] == VarType.bool:
        return solver.true() if a else solver.neg(solver.true())
      # CVC5 returns string if value irrelevant, set to 0
      if v["type"] == VarType.int:
        return solver.num(a if isinstance(a,int) else 0)
      return solver.real(a if isinstance(a,float) else 0)

    # the last stage must match the target valuation
    lvars = vars_stages[n-1]
    target = dict(val)
    valc = [solver.eq(lvars[v["name"]], smtval(v, target[v["name"]])) for v in vars]
    phi = solver.land([phi] + valc)
    m = solver.check_sat(phi, eval=True)
    if m is None:
      print("error: no model found during witness extraction")
      print(phi)
      return
    if show:
      run = []
      for i in range(1, n):
        (_, b) = path[i-1]
        acts = [ a for (a, enc) in action_encodings[i-1].items() if m.eval_bool(enc) ]
        stage_val = [ (name, eval_var(m, var, sv)) \
          for (var, (name, sv)) in zip(vars, vars_stages[i].items()) ]
        assert(len(acts) > 0)
        run.append((acts[0], self._dds._states[self._nodes[b]._state], stage_val))
      # b0 is a CG node id: map it to its DDS state like the other stages
      self._dds.show_run(self._dds._states[self._nodes[b0]._state], run)
    m.destroy()

  def has_dead_transitions(self, verbose):
    """Data-aware soundness, part 1: return True if some DDS transition name
    labels no CG edge (printing them if `verbose`)."""
    cg_trans = { t.label_str() for t in self._transitions }
    dds_trans = { t["name"] for t in self._dds._transitions }
    dead = [ t for t in dds_trans if t not in cg_trans ]
    if len(dead) > 0:
      if verbose:
        print("there are dead transitions: ", dead)
      return True
    return False

  def has_dirty_termination(self, verbose):
    """Data-aware soundness check: return True if a CG node lies in a DDS
    state marked "superfinal" (printing those states if `verbose`)."""
    dsts = dict([ (i, p["name"]) for (i,p) in self._dds._states.items() \
      if "superfinal" in p ])
    dirty = [ dsts[n._state] for n in self._nodes.values() if n._state in dsts ]
    if len(dirty) > 0:
      if verbose:
        s = set(dirty) if len(dirty) > 1 else dirty[0]
        print("dirty termination states are reachable: ", s)
      return True
    return False

  def print_deadlock(self, n, model):
    """Print deadlocked node `n` with the valuation from `model`, followed by
    a witness run reaching it."""
    val = [ (vname, eval_var(model, var, sv)) for (var, (vname, sv)) in \
      zip(self._dds.variables(), list(self._smt_vars.items()))]
    name = self._dds._states[n._state]["name"]
    print("deadlock in %s with valuation%s" % (name, val_str(val)))
    (b0, path) = self.get_path_to_node(n)
    self.compute_run_to_target_valuation(b0, path, val, show=True)

  def has_deadlocks(self, verbose):
    """Data-aware soundness, part 2: return True on a possible deadlock.

    Returns True if no CG node is in a final DDS state. Otherwise, for each
    non-final DDS state s a CG is built from s (and rendered to
    out/cg_<name>.png); a CG node in state s is a deadlock if its formula is
    satisfiable together with "no final node of that CG is reachable".
    Nodes whose incoming edges are all constraint-free and come from
    already-checked nodes are skipped.
    """
    vs = self._dds.variables()
    solver = self._solver
    equiv = self._equivalence
    # NOTE(review): this local shadows the module-level cgx_cache, so the
    # cache consulted by compute() is never filled here — confirm intended.
    cgx_cache = {}

    vs_smt = list(self._smt_vars.values())
    vs_primed = [ self.mk_var(v, suffix = "_") for v in vs ]
    final_ids = self._dds.final_state_ids()
    checked_nodes = []

    if not any(n._state in final_ids for n in self._nodes.values()):
      if verbose:
        print("CG has no final state ids")
      return True

    # compute extended CG for every control state in DDS
    for s in [s for s in self._dds._states.values() if not s["id"] in final_ids]:
      cgx = ConstraintGraph(self._dds, equiv, start_state = s["id"],
        vars = (self._smt_vars, self._smt_vars_copy))
      cgx_cache[s["id"]] = cgx
      cgx.show("out/cg" + "_" + s["name"] + ".png")
      cgx_final = [ i for (i,n) in cgx._nodes.items() if n._state in final_ids ]
      to_end_exprs = []
      for fid in cgx_final:
        fin = cgx._nodes[fid]
        reach_fin = solver.subst(cgx._smt_vars.values(), vs_primed, fin._expr) #V to U
        vs_smt0 = list(cgx._smt_vars_copy.values())
        reach_fin = solver.subst(vs_smt0, vs_smt, reach_fin) # V0 to V
        reach_fin = solver.exists(vs_primed, reach_fin) if len(vs_primed) > 0 else reach_fin
        to_end_exprs.append(solver.neg(reach_fin))
      solver.push()
      try: # always pop, also when a check or the deadlock printing raises
        solver.require(solver.land(to_end_exprs))
        s_nodes = [n for n in self._nodes.items() if n[1]._state == s["id"]]
        for (nid, n) in s_nodes:
          ins = [t for t in self._transitions if t._target == nid]
          checked_nodes.append(nid)
          if len(ins) > 0 and all (t._source in checked_nodes and \
            t._constraint_free for t in ins):
            continue # incoming edges have True, and soundness already checked
          m = solver.check_sat(n._expr, eval=True)
          if m:
            if verbose:
              self.print_deadlock(n, m)
            m.destroy()
            return True
      finally:
        solver.pop()
    return False

  def has_deadlocks3(self, verbose): # version 3, a la Paolo
    """Per-node deadlock check: return True on the first possible deadlock.

    For every CG node in a non-final DDS state, a CG is built from that state
    with the node's formula as initial constraint; the node is a deadlock if
    its formula is satisfiable together with "no final node of that CG is
    reachable". Prints a witness if `verbose`.
    """
    solver = self._solver
    vs_smt = list(self._smt_vars.values())
    vs_primed = [ self.mk_var(v, suffix = "_") for v in self._dds.variables() ]
    final_ids = self._dds.final_state_ids()
    vars_tuple = (self._smt_vars, self._smt_vars_copy)

    for (nid, n) in self._nodes.items():
      s = self._dds._states[n._state]
      if s["id"] in final_ids:
        continue
      cgx = ConstraintGraph(self._dds, self._equivalence, start_state = s["id"],
        vars = vars_tuple, init_constr = n._expr)
      #cgx.show("out/cg" + "_" + s["name"] + "_" + str(nid) + ".png")
      cgx_final = [fn for fn in cgx._nodes.values() if fn._state in final_ids]
      to_end_exprs = []
      for fin in cgx_final:
        reach_fin = solver.subst(cgx._smt_vars.values(), vs_primed, fin._expr) # V2U
        reach_fin = solver.subst(cgx._smt_vars_copy.values(), vs_smt, reach_fin) # X2V
        reach_fin = solver.exists(vs_primed, reach_fin) if len(vs_primed) > 0 else reach_fin
        to_end_exprs.append(solver.neg(reach_fin))
      m = solver.check_sat(solver.land([n._expr] + to_end_exprs), eval=True)
      if m:
        if verbose:
          self.print_deadlock(n, m)
        m.destroy()
        return True
    return False

  def is_sound(self, verb):
    """Return True iff there are no dead transitions, no reachable dirty
    termination and no deadlocks (checked in that order, short-circuiting)."""
    return not (self.has_dead_transitions(verb) or \
      self.has_dirty_termination(verb) or self.has_deadlocks(verb))

  def show(self, filename):
    """Render the CG with Graphviz (dot layout) into `filename`.

    Node labels are "statename|formula"; nodes in final DDS states are filled
    gray. Parallel edges between the same two nodes are drawn as one edge
    labelled with all their labels.
    """
    g = nx.DiGraph()
    g.add_nodes_from(self._nodes.keys())
    g.add_edges_from([(t._source, t._target) for t in self._transitions])

    A = to_agraph(g)
    A.node_attr['fontname'] = "Arial"
    A.edge_attr['fontname'] = "Arial"
    A.node_attr['fontsize'] = "11"
    A.edge_attr['fontsize'] = "10"
    A.edge_attr['arrowsize'] = "0.6"

    # a DiGraph keeps one edge per node pair, so collect all labels per pair
    labels = {}
    for t in self._transitions:
      labels.setdefault((t._source, t._target), []).append(t.label_str())
    for ((src, tgt), labs) in labels.items():
      edge = A.get_edge(src, tgt)
      edge.attr['label'] = " " + ", ".join(labs)
      if src == tgt:
        edge.attr['tailport'] = "ne"
        edge.attr['headport'] = "se"
    final_ids = self._dds.final_state_ids()
    for (i, node) in self._nodes.items():
      n = A.get_node(i)
      n.attr['shape'] = 'record'
      n.attr['margin'] = "0.1,0.005"
      n.attr['height'] = "0.3"
      if node._state in final_ids:
        n.attr['style']='filled'
        n.attr['fillcolor']='gray'
      n.attr['node_size'] = len(str(node)) * 100
      n.attr['label'] = str(node)

    A.layout('dot')
    A.draw(filename)

  def export_json(self):
    """Write the CG to cg.json as {"nodes": [...], "arcs": [...]}.

    nodes[i] is the string of node i ("" for unused ids); each arc holds the
    "source" and "target" node ids and the edge "label" string.
    """
    size = max(self._nodes, default=-1) + 1
    nodes = ["" for _ in range(size)]
    for (i, s) in self._nodes.items():
      nodes[i] = str(s)

    edges = [ {"source": e._source, "target": e._target, "label": e.label_str()} \
      for e in self._transitions ]

    graph = { "nodes": nodes, "arcs": edges }
    with open('cg.json', 'w') as f:
      json.dump(graph, f, indent=2)
