import networkx as nx
from networkx.drawing.nx_agraph import to_agraph, write_dot
import matplotlib.pyplot as plt
import json
import sys

from dds.util import VarType, expr_replace
from dds.expr import Bool, PropVar, Cmp, Num, Var, UnCon, ConfigMapAtom, top
from verification.constraint_graph import ConstraintGraph, Edge, eval_var
from ltl.automaton import extend_label, symstr, symsstr

def state_contradiction(syms, b, dds):
  """Tell whether `syms` mentions a DDS state other than `b`.

  syms: iterable of label atoms; only PropVar atoms are inspected.
  b:    name of the DDS state that the label is checked against.
  dds:  object whose `_states` mapping holds dicts with a "name" entry.

  Returns True as soon as some PropVar in `syms` carries the name of a
  state of `dds` different from `b`, and False otherwise.
  """
  other_state_names = [st["name"] for st in dds._states.values() if st["name"] != b]
  for atom in syms:
    if isinstance(atom, PropVar) and atom.name in other_state_names:
      return True
  return False

def action_contradiction(syms, a, b, dds):
  """Tell whether the label `syms` is incompatible with executing action `a`.

  syms: iterable of label atoms (PropVar, UnCon, ConfigMapAtom, ...).
  a:    name of the action (DDS transition) being taken.
  b:    key used to look up the formula in each ConfigMapAtom's `_map`
        (the caller passes the id of the target DDS state).
  dds:  object whose `_transitions` list holds dicts with a "name" entry.

  Without ConfigMapAtoms the check is syntactic: the label contradicts `a`
  if it contains a PropVar naming a different action, or the negation
  (UnCon around a PropVar) of `a` itself.

  With ConfigMapAtoms the check is delegated to the solver: the formulas
  stored for `b` are conjoined with `a` and the negation of every other
  action; the label is contradictory iff `check_sat` yields None.
  """
  confmaps = [cm._map for cm in syms if isinstance(cm, ConfigMapAtom)]
  diffacts = [t["name"] for t in dds._transitions if t["name"] != a]

  if not confmaps:
    return any(
      (isinstance(x, PropVar) and x.name in diffacts) or
      (isinstance(x, UnCon) and isinstance(x.arg, PropVar) and x.arg.name == a)
      for x in syms)

  solver = confmaps[0]._solver
  # the loop variable must not reuse the name `a`: the parameter is needed
  # again below to encode the action that IS taken
  not_act = [UnCon("!", PropVar(other)).toSMT(solver, {}) for other in diffacts]
  sym_formulas = [m.get(b) for m in confmaps]
  a_act = PropVar(a).toSMT(solver, {})
  requirements = solver.land(sym_formulas + not_act + [a_act])
  # identity test: the result object may overload comparison operators
  return solver.check_sat(requirements) is None


def incompatible(s, a, b_id, b, dds, has_action_vars):
  """Tell whether label `s` cannot accompany action `a` leading to state `b`.

  s:               label (collection of atoms) to check.
  a:               name of the action taken.
  b_id, b:         id and name of the target DDS state.
  dds:             the DDS the names refer to.
  has_action_vars: when true, action names are checked in addition to
                   state names.

  The state check always runs; the action check is consulted only if the
  state check found no contradiction and `has_action_vars` is set.
  """
  state_clash = state_contradiction(s, b, dds)
  if state_clash or not has_action_vars:
    return state_clash
  return action_contradiction(s, a, b_id, dds)

def filter_sym(syms, tgt):
  """Project one label onto the atoms that are kept on product edges.

  Returns a new list holding first every Cmp or PropVar atom of `syms`
  (in iteration order), followed by the entry stored under `tgt` in the
  `_map` of every ConfigMapAtom of `syms`. All other atoms are dropped.
  """
  kept = []
  # first pass: comparison and propositional atoms are kept unchanged
  for atom in syms:
    if isinstance(atom, (Cmp, PropVar)):
      kept.append(atom)
  # second pass: configuration maps are replaced by their entry for `tgt`
  for atom in syms:
    if isinstance(atom, ConfigMapAtom):
      kept.append(atom._map.get(tgt))
  return kept

def filter_syms(syms, tgt):
  """Apply `filter_sym` with target `tgt` to every label in `syms`.

  Returns a list with one filtered label per element of `syms`, in order.
  """
  projected = []
  for label in syms:
    projected.append(filter_sym(label, tgt))
  return projected

class Node:
  """Node of the product construction.

  A node pairs a DDS state (id and name) with an NFA state id and an SMT
  expression, plus a flag telling whether the node is accepting. Two nodes
  are equal when states and flag coincide and the equivalence object
  judges their expressions to be the same.
  """

  # number of Node objects created so far; supplies each node's `_id`
  counter = 0

  def __init__(self, dds_state, name, nfa_state, smtexpr, equiv, acc=False):
    """Store the components and precompute the printable expression.

    dds_state: id of the DDS state.
    name:      name of the DDS state (used for printing only).
    nfa_state: id of the NFA state.
    smtexpr:   SMT expression attached to the node.
    equiv:     object offering `same`, `simp_dnf` and a `_solver`.
    acc:       whether the node is accepting.
    """
    self._dds_state = dds_state # just id
    self._dds_state_name = name
    self._nfa_state = nfa_state # just id
    self._expr = smtexpr
    self._equivalence = equiv
    self._is_accepting = acc
    self._id = Node.counter
    Node.counter += 1
    self._expr_string = self.expr_str()

  def __eq__(self, other):
    """Compare states and flag exactly, expressions via `equiv.same`."""
    if not isinstance(other, Node):
      # let Python try the reflected comparison instead of crashing on a
      # missing attribute
      return NotImplemented
    return self._dds_state == other._dds_state and \
      self._nfa_state == other._nfa_state and \
      self._is_accepting == other._is_accepting and \
      self._equivalence.same(self._expr, other._expr)

  def __hash__(self):
    """Hash only the fields that `__eq__` compares exactly.

    The expression is compared semantically, so two equal nodes may hold
    different expression objects; including `_expr` here would let equal
    nodes hash differently.
    """
    return hash( (self._dds_state, self._nfa_state, self._is_accepting) )

  def expr_str(self):
    """Return the node's expression as a string.

    The expression is simplified by the solver, passed through the
    equivalence object's `simp_dnf`, converted with the solver's
    `to_string`, and finally post-processed by `expr_replace`.
    """
    expr0 = self._equivalence._solver.simplify_more(self._expr)
    expr1 = self._equivalence.simp_dnf(expr0)
    s = self._equivalence._solver.to_string(expr1)
    return expr_replace(s)

  def __str__(self):
    """Return "<dds state name> | q<nfa state> | <expression>"."""
    return self._dds_state_name + " | q" + str(self._nfa_state) + \
      " | " + self._expr_string

  def to_dict(self):
    """Return a dict of plain values describing the node (used for JSON)."""
    return {
      "dds_state": str(self._dds_state),
      "nfa_state": str(self._nfa_state),
      "expr": self._expr_string,
      "accepting": self._is_accepting,
    }


class ProductConstruction(ConstraintGraph):
  """Product of a DDS and an NFA, explored as a graph of `Node` objects.

  Nodes are stored in `_nodes` (id -> Node, node 0 being a dummy initial
  node) and edges in `_transitions` (list of Edge). The graph is computed
  eagerly by the constructor.
  """

  def __init__(self, dds, nfa, equiv, start = None, vars = None):
    """Set up the construction and compute the product graph.

    dds:   the DDS to explore.
    nfa:   the automaton run in lockstep with the DDS.
    equiv: equivalence object; its `_solver` is used for all SMT work.
    start: optional id of the DDS state to start from; when None, all
           initial states of the DDS are used.
    vars:  optional pair (smt_vars, smt_vars_copy) of name -> SMT variable
           mappings; when None, fresh constants are created with
           `mk_const` and no copy is kept.
    """
    self._dds = dds
    self._nfa = nfa
    self._solver = equiv._solver
    if vars is not None:
      self._smt_vars = vars[0]
      self._smt_vars_copy = vars[1]
    else:
      self._smt_vars = { v["name"]: self.mk_const(v) \
        for v in self._dds.variables() }
      self._smt_vars_copy = None
    self._nodes = {}
    self._transitions = []
    self._equivalence = equiv
    self._start_state = start
    self.compute()

  def add_transition(self, s, t, n, l):
    """Record an edge from node `s` to node `t` for action `n`, label `l`.

    If an edge between the two nodes already exists, `(n, l)` is added to
    its label; otherwise a new Edge is appended.
    """
    for e in self._transitions:
      if e._source == s and e._target == t:
        e.add_to_label((n, l))
        return
    self._transitions.append(Edge(s, t, n, l))

  def init_expr(self):
    """Return the SMT expression attached to the initial node.

    Without a variable copy this is the conjunction of the DDS's initial
    value constraints; with a copy it equates every variable with its
    copy, pairing the two mappings by iteration order.
    """
    solver = self._solver
    if self._smt_vars_copy is None:
      # vars must be ordered correctly!
      ordered = [ self._smt_vars[v["name"]] for v in self._dds.variables() ]
      return solver.land(self._dds.init_val_constraints(solver, ordered))
    # NOTE(review): relies on both mappings listing variables in the same
    # order -- confirm against the caller that supplies `vars`
    varpairs = zip(self._smt_vars.values(), self._smt_vars_copy.values())
    return solver.land([solver.eq(x, x0) for (x,x0) in varpairs])

  def compute(self):
    """Build `_nodes` and `_transitions` by depth-first exploration.

    Exploration starts at a dummy node (DDS state -1, initial NFA state)
    connected by dummy transitions to the start state(s). For every DDS
    transition and NFA transition the compatible label symbols are kept,
    the successor constraint is computed and checked for satisfiability,
    and the successor is either merged with an equal existing node or
    added and explored recursively.

    NOTE(review): exploration is recursive, so very long discovery chains
    may exceed Python's recursion limit.
    """
    dds = self._dds
    nfa = self._nfa
    solver = self._solver
    equiv = self._equivalence

    # create initial expression
    dds_state0 = -1
    nfa_state0 = nfa.initial_state_id()
    init_node = Node(dds_state0, "dummy", nfa_state0, self.init_expr(), equiv)

    # identity test: a start state with id 0 is falsy but must be honoured
    if self._start_state is not None:
      inits = [self._start_state]
    else:
      inits = [ s["id"] for s in dds.initial_states() ]
    trans_dummys = [
      {"source": -1, "target": sid, "guard":top, "written":[], "name":"dummy"}
      for sid in inits ]

    self._nodes = {0: init_node}
    self._transitions = []

    def expand(i, node):
      q = node._nfa_state
      # the dummy node has no DDS state; it uses the dummy transitions
      nexts = dds.out(node._dds_state) if node._dds_state >= 0 else trans_dummys
      for t in nexts:
        b_next = t["target"] # DDS state id
        bx_name = dds._states[b_next]["name"]
        for (q_next, sym) in nfa.out(q):
          # filter sym for compatible labels:
          compatible = [s for s in sym if not incompatible(s, t["name"], \
            b_next, bx_name, dds, nfa._has_action_vars)]
          if not compatible:
            continue # compatibility fails

          next_expr = solver.qe_simp(self.update_ext(node._expr, t, compatible))
          if solver.check_sat(next_expr) is None:
            continue # not satisfiable, transition is irrelevant

          acc = q_next in nfa.final_state_ids() and b_next in dds.final_state_ids()
          next_node = Node(b_next, bx_name, q_next, next_expr, equiv, acc=acc)

          label = filter_syms(compatible, b_next)
          try:
            j = next(j for (j, n) in self._nodes.items() if next_node == n)
          except StopIteration: # no equivalent node exists yet
            k = len(self._nodes)
            self._nodes[k] = next_node
            self.add_transition(i, k, t["name"], label)
            expand(k, next_node)
          else:
            self.add_transition(i, j, t["name"], label)

    expand(0, init_node)


  def update_ext(self, smt_expr, trans, labels):
    """Update `smt_expr` by transition `trans` and conjoin the NFA label.

    The expression is first updated by `self.update` (transition guard);
    the result is conjoined with the disjunction, over all symbols in
    `labels`, of the conjunction of each symbol's constraints.
    """
    solver = self._solver
    tgt = trans["target"] # target control state

    def sym2smt(sym):
      # FIXME does LTL and CTL* case at once: actually only one of constrs and
      # confmaps will be non-empty
      constrs_smt = [c.toSMT(solver, self._smt_vars) \
        for c in sym if isinstance(c, Cmp)]
      # need to filter out action name propositions
      constrsx = [c._map.get(tgt) for c in sym if isinstance(c, ConfigMapAtom)]
      return solver.land(constrs_smt + constrsx)

    e = self.update(smt_expr, trans) # update by transition guard
    nfa_sym_smt = solver.lor([ sym2smt(sym) for sym in labels ])
    return solver.land([e, nfa_sym_smt])

  def has_final_state(self):
    """Return True if some node is accepting."""
    return any(n._is_accepting for n in self._nodes.values())

  def final_state_ids(self):
    """Return the ids of all accepting nodes."""
    return [ i for (i, n) in self._nodes.items() if n._is_accepting ]

  def initial_state_id(self):
    """Return the id of the initial (dummy) node, which is always 0."""
    return 0

  def out(self, sid):
    """Return the edges whose source is node `sid`."""
    return [t for t in self._transitions if t._source == sid]


  def get_path_to_state_ids(self, target_ids):
    """Return a path from the initial node to some node in `target_ids`.

    The path is a list of (edge, node id) pairs, one per step; it is empty
    if the initial node itself is a target. Returns None if no target is
    reachable. Each node is visited at most once, so the search is linear
    in the size of the graph.
    """
    targets = set(target_ids)
    n0 = self.initial_state_id()
    visited = {n0}

    def dfs(n, path):
      if n in targets:
        return path
      for t in self.out(n):
        nxt = t._target
        if nxt in visited:
          continue
        # a node that failed once cannot succeed later: anything it can
        # reach is either explored or still being explored further up
        visited.add(nxt)
        found = dfs(nxt, path + [(t, nxt)])
        if found is not None:
          return found
      return None

    return dfs(n0, [])


  def get_is_accepting_path(self):
    """Return a path to some accepting node, or None if there is none."""
    return self.get_path_to_state_ids(self.final_state_ids())


  def compute_witness(self, show=True):
    """Extract a concrete run along an accepting path and optionally show it.

    One copy of the DDS variables is created per path step. The formula
    combines the initial value constraints with, for every step after the
    first, the disjunction over the edge's actions of guard, NFA label and
    inertia (unwritten variables keep their value). If the solver finds a
    model and `show` is true, the run is passed to the DDS's `show_run`.
    Problems are reported by printing; nothing is returned.
    """
    path = self.get_is_accepting_path()
    if not path:
      print("error: no accepting path available for witness extraction")
      return
    dds_vars = self._dds.variables()
    var_names = [v["name"] for v in dds_vars]
    solver = self._solver
    vars_stages = [ self.mk_vars_stage(i) for i in range(0, len(path))]

    # build formula phi combining all constraints
    # initialize with initial assignment
    vars0 = [vars_stages[0][x] for x in var_names]
    phi = solver.land(self._dds.init_val_constraints(solver, vars0))
    action_encodings = [{} for i in range(1, len(path))]

    # path[0] is the dummy step into the start state, hence the start at 1
    for i in range(1, len(path)):
      e = path[i][0]
      subst_prime = [ (x + "'", v) for (x, v) in vars_stages[i].items() ]
      subst_guard = dict(list(vars_stages[i-1].items()) + subst_prime)

      edge = []
      for (action, symsets) in e._label.items():
        nfa_sym_enc = solver.lor([ solver.land([sym.toSMT(solver, vars_stages[i]) \
          for sym in symset ]) for symset in symsets ])
        t = next(t for t in self._dds._transitions if t["name"] == action)
        guard = t["guard"].toSMT(solver, subst_guard)
        inertia = [solver.eq(vars_stages[i-1][v], vars_stages[i][v]) \
          for v in var_names if v not in t["written"] ]
        edge_enc = solver.land([guard, nfa_sym_enc] + inertia)
        action_encodings[i-1][action] = edge_enc
        edge.append(edge_enc)

      phi = solver.land([phi, solver.lor(edge)])

    m = solver.check_sat(phi, eval=True)
    if m is None:
      print("error: no model found during witness extraction")
      print(phi)
      return

    try:
      if show:
        run = []
        for i in range(1, len(path)):
          (e, b) = path[i]
          acts = [ n for (n, enc) in action_encodings[i-1].items() \
            if m.eval_bool(enc) ]
          val = [ (n, eval_var(m,x,v)) \
            for (x,(n,v)) in zip(dds_vars, vars_stages[i].items())]
          run.append((acts[0], self._dds._states[self._nodes[b]._dds_state], val))
        start_state = self._dds._states[self._nodes[path[0][1]]._dds_state]
        self._dds.show_run(start_state, run)
    finally:
      # release the model whether or not it was displayed
      m.destroy()

  def show(self, filename):
    """Render the product graph to `filename` using Graphviz's dot layout.

    As side effects the graph is also drawn with networkx onto the current
    matplotlib figure and written to the file 'pc.dot'.
    """
    g = nx.MultiDiGraph()
    g.add_nodes_from(self._nodes.keys())
    g.add_edges_from([(t._source,t._target) for t in self._transitions])
    pos = nx.spring_layout(g)
    node_labels = { i: str(n) for (i, n) in self._nodes.items() }
    edge_labels = { (t._source, t._target): symsstr(t._label) \
      for t in self._transitions }
    # NOTE(review): the matplotlib drawing below is never saved or shown
    # in this method -- confirm whether a caller relies on it
    nx.draw_networkx_nodes(g, pos, node_size = 50)
    nx.draw_networkx_edges(g, pos, arrows=True)
    nx.draw_networkx_labels(g, pos, labels=node_labels, font_size=10, font_family='sans-serif')
    nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, font_size=10)
    write_dot(g, 'pc.dot')

    A = to_agraph(g)
    A.node_attr['fontname'] = "Arial"
    A.edge_attr['fontname'] = "Arial"
    A.node_attr['fontsize'] = "11"
    A.edge_attr['fontsize'] = "10"
    A.edge_attr['arrowsize'] = "0.6"

    for t in self._transitions:
      edge = A.get_edge(t._source,t._target)
      edge.attr['label'] = " " + t.label_str()

    for (i,node) in self._nodes.items():
      n = A.get_node(i)
      n.attr['shape']='record'
      n.attr['margin']="0.05,0.005"
      n.attr['height'] = "0.3"
      # long labels get a slightly smaller font
      n.attr['fontsize'] = "11" if len(str(node)) < 60 else "10"
      if node._is_accepting:
        n.attr['style']='filled'
        n.attr['fillcolor']='gray'
      n.attr['label'] = str(node)

    A.layout('dot')
    A.draw(filename)

  def export_json(self):
    """Write nodes and edges as JSON to the file 'pc.json'.

    Nodes are listed by id under "nodes", edges under "arcs".
    """
    nodes = ["" for i in range(0, len(self._nodes))]
    for (i, s) in self._nodes.items():
      nodes[i] = s.to_dict()

    edges = [ e.to_dict() for e in self._transitions ]

    graph = { "nodes": nodes, "arcs": edges }
    with open('pc.json', 'w') as f:
      json.dump(graph, f, indent=2)
