import networkx as nx
from networkx.drawing.nx_agraph import to_agraph
from dds.expr import mk_var, mk_conj
from smt.cvc5solver import *
from copy import deepcopy

from dds.expr import Expr, Bool, Num, Var, PropVar, Cmp, UnCon, BinCon, \
  ConfigMapAtom, top, bot
from dds.dds import DDS

# Module-level proposition named "last"; Automaton creates an equivalent
# marker of its own (self._last) for use in edge labels.
last = PropVar("last")

# Negation of the "last" proposition.
not_last = UnCon("!", last)

def simp_sym_set(s):
  """Return a fresh set with the elements of `s`, or None if `s` clashes.

  `s` clashes when it contains some element `f` together with
  `UnCon("!", f)`.
  """
  for lit in s:
    if UnCon("!", lit) in s:
      return None
  return set(s)

def simp_bincon_unique(e):
  """Remove duplicate operands from a nested `&&` or `||` expression.

  If `e` is a BinCon whose operator is "&&" or "||", the maximal chain of
  that operator is flattened, duplicates are dropped via a set, every
  remaining operand is simplified recursively, and a right-nested BinCon is
  rebuilt. Any other expression is returned unchanged.
  """
  if not isinstance(e, BinCon) or e.op not in ("&&", "||"):
    return e
  op = e.op

  # flatten the operator chain left-to-right without recursion
  operands = []
  pending = [e]
  while pending:
    node = pending.pop()
    if isinstance(node, BinCon) and node.op == op:
      pending.append(node.right)
      pending.append(node.left)
    else:
      operands.append(node)

  # the set drops duplicates; its iteration order fixes the operand order
  unique = [ simp_bincon_unique(a) for a in set(operands) ]

  # rebuild as a right-nested chain: a op (b op (c ...))
  result = unique[-1]
  for operand in reversed(unique[:-1]):
    result = BinCon(operand, op, result)
  return result

def dnf(e):
  """Apply one distribution step of `&&` over `||` at the root of `e`.

  For a conjunction with a disjunctive operand (the left operand is tried
  first), `(a || b) && c` becomes `(a && c) || (b && c)`. Every other
  expression is returned unchanged.
  """
  if not (isinstance(e, BinCon) and e.op == "&&"):
    return e
  lhs, rhs = e.left, e.right
  for (disj, other) in ((lhs, rhs), (rhs, lhs)):
    if isinstance(disj, BinCon) and disj.op == "||":
      first = BinCon(disj.left, "&&", other)
      second = BinCon(disj.right, "&&", other)
      return BinCon(first, "||", second)
  return e


def simp_bincon(x, op, y):
  """Return a simplified expression for `x op y`.

  Parameters:
    x, y: operand expressions.
    op:   "&&" or "||" (checked with an assertion).

  The rules are tried in order: unit/zero laws for `top` and `bot`,
  idempotence, absorption between a formula and `F`/`G` applied to it, and
  absorption of a disjunction by `G` of one of its disjuncts. If no rule
  applies, the BinCon is built, de-duplicated, distributed once (see `dnf`)
  and de-duplicated again.
  """
  assert (op == "&&" or op == "||")
  conj = op == "&&"
  if x == top:
    return y if conj else top
  elif y == top:
    return x if conj else top
  elif x == bot:
    return bot if conj else y
  elif y == bot:
    return bot if conj else x
  elif x == y:
    return x
  elif isinstance(x, UnCon) and x.op == "F" and x.arg == y:
    # F y && y == y ;  F y || y == F y
    return y if conj else x
  elif isinstance(x, UnCon) and x.op == "G" and x.arg == y:
    # G y && y == G y ;  G y || y == y
    return x if conj else y
  elif isinstance(y, UnCon) and y.op == "F" and y.arg == x:
    return x if conj else y
  elif isinstance(y, UnCon) and y.op == "G" and y.arg == x:
    return y if conj else x
  elif conj and isinstance(y, UnCon) and y.op == "G" and \
    isinstance(x, BinCon) and x.op == "||" and \
    (x.left == y.arg or x.right == y.arg):
    # (phi || psi) && G phi  and  (phi || psi) && G psi:
    # G of a disjunct implies that disjunct and hence the whole disjunction,
    # so the conjunction is equivalent to the G formula alone. (Keeping the
    # other disjunct as an extra conjunct would strengthen the formula.)
    return y
  e = simp_bincon_unique(BinCon(x, op, y))
  return simp_bincon_unique(dnf(e)) # simp after DNF is crucial for termination

def symstr(sset):
  """Render the elements of `sset` as one comma-separated string.

  Each element is converted with `str`, the pieces are joined by ", ", and
  every "==" in the result is shown as "=". An empty collection yields "".
  """
  return ", ".join(str(x) for x in sset).replace("==", "=")

def symsstr(ssets):
  """Render several symbol sets as "{...},{...}" using `symstr` per set."""
  return ",".join("{" + symstr(sset) + "}" for sset in ssets)

def extend_label(l, labels):
  """Add label `l` to the list `labels`, merging where possible.

  Parameters:
    l:      a set of literals (one conjunctive edge label).
    labels: list of such sets already attached to the edge.

  Returns a list of label sets:
    * `[l]` if `l` is empty, or if `labels` is empty;
    * `[m]` if some label `m` (among `labels` and `l`) is a subset of all
      others;
    * otherwise, if flipping one literal of `l` yields a label already in
      `labels`, that label is removed and `l` without the literal is
      merged recursively;
    * otherwise `labels + [l]`.

  Neither `l` nor `labels` is modified: labels may be shared by the caller
  (e.g. alphabet symbols reused for several edges), so merging works on
  fresh sets.
  """
  if l == set(): # minimal element
    return [l]

  def neg(t):
    # an UnCon literal is taken to be a negation and is unwrapped
    return t.arg if isinstance(t, UnCon) else UnCon("!", t)

  if len(l) == 0:
    return labels
  if len(labels) == 0:
    return [l]

  # a label contained in every other label makes the others redundant
  candidates = list(labels) + [l]
  for cand in candidates:
    if all(cand.issubset(other) for other in candidates if other != cand):
      return [cand]

  # {a, t} and {a, !t} together amount to {a}
  for term in l:
    without_term = set(l) - {term}
    flipped = without_term | {neg(term)}
    if any(flipped == other for other in labels):
      remaining = [other for other in labels if not (flipped == other)]
      return extend_label(without_term, remaining)

  return labels + [l]
  

class Automaton:
  def __init__(self, formula = None, solver=None, deterministic = False):
    """Create the automaton for `formula`, or an empty dummy automaton.

    Parameters:
      formula:       object exposing `_property` (the formula to translate);
                     when falsy, no automaton is built and all containers
                     stay empty.
      solver:        solver used to check satisfiability of edge labels.
      deterministic: use negated constraint !c rather than top in delta("c")
                     (as e.g. needed for monitoring, but not model checking).
    """
    self._formula = formula
    self.deterministic = deterministic
    self._last = PropVar("last")
    self._not_last = UnCon("!", self._last)
    self._edges = {}
    self._states = {}
    self._top_state = None
    # always present so that final_state_ids() also works on dummy objects
    self._final_state_ids = []
    self._has_action_vars = None # set later, useful for CTL* compatibility
    # to check satisfiability of edge labels
    self._solver = solver
    self._vars = {}

    if formula: # can also be empty to construct a dummy object
      self._vars = { str(v): mk_var(self._solver, v) \
        for v in formula._property.vars() }
      self.build()
  
  def contradictory(self, label):
    """Tell whether `label` contains both the last and the not-last marker."""
    markers = (self._last, self._not_last)
    return all(marker in label for marker in markers)

  def add_edge(self, e):
    """Record edge `e`, a dict with keys "source", "target" and "label".

    The first label for a (source, target) pair is stored as a one-element
    list; later labels are merged into the stored list via `extend_label`.
    """
    key = (e["source"], e["target"])
    new_label = e["label"]#.difference({self._last, self._not_last})
    if key not in self._edges:
      self._edges[key] = [new_label]
      return
    self._edges[key] = extend_label(new_label, self._edges[key])

  def build(self):
    """Construct states, edges and final states from `self._formula`.

    State 0 carries the formula itself. Successors come from `delta`; each
    distinct successor formula gets one state, and successors are expanded
    depth-first. Final states are all states carrying the `accept`
    proposition plus the state carrying `top`, if there is one.
    """
    f = self._formula._property
    self._states = { 0: f }
    accept = PropVar("accept")
    reject = PropVar("reject")

    def expand(i, psi):
      # combine labels for edges to True upfront to avoid unnecessary end nodes
      succs = self.delta(psi)
      tsuccs = []
      for s in [ s for (chi, s) in succs if chi == top ]:
        tsuccs = extend_label(s, tsuccs)
      succs = [ sc for sc in succs if sc[0] != top ] + [(top,s) for s in tsuccs]

      for (psix, s) in succs:
        # by NFA semantics, omit sink state bot; omit s if it has last and !last
        if self.contradictory(s) or (psix == bot and not self.deterministic):
          continue

        if self._last in s and psix == top: # redirect to accept state
          psix = accept
        if self._last in s and psix == bot and self.deterministic: # redirect to end state
          psix = reject

        # reuse the state that already carries this formula, if any
        j = next((j for (j, chi) in self._states.items() if chi == psix), None)
        if j is not None:
          self.add_edge({"source": i, "target": j, "label":s})
          continue

        k = len(self._states)
        self._states[k] = psix
        self.add_edge({"source": i, "target": k, "label":s})
        if psix not in [accept, reject]:
          expand(k, psix)

    expand(0,f)

    tops = [j for (j, chi) in self._states.items() if chi == top]
    self._top_state = tops[0] if tops else None
    fs = [j for (j, chi) in self._states.items() if chi == accept]
    # compare against None: the top state may well be state 0
    if self._top_state is not None:
      fs.append(self._top_state)
    self._final_state_ids = fs


  def delta(self, f):
    """Return the one-step successors of formula `f`.

    The result is a list of pairs `(g, label)`: `g` is the formula that
    remains to be satisfied and `label` is the set of literals under which
    this successor applies.

    Raises AssertionError for a formula shape that is not handled.
    """
    last = self._last
    not_last = self._not_last
    and2 = lambda x, y: self.combine("&&", x, y)
    or2 = lambda x, y: self.combine("||", x, y)

    if isinstance(f, Bool):
      return [(f, set([]))] # takes care of Top and Bot
    elif isinstance(f, PropVar):
      return [ (top, {f}), (bot, set()) ] # p in B or A
    elif isinstance(f, Cmp):  # constraint
      return [ (top, {f}), (bot, set()) ] if not self.deterministic else \
        [ (top, {f}), (bot, { f.negate() }) ]
    elif isinstance(f, UnCon) and f.op == "!" and isinstance(f.arg, PropVar):
      return [ (top, {f}), (bot, {f.arg}) ] # not b or b not a
    elif isinstance(f, ConfigMapAtom):
      return [ (top, {f}), (bot, set()) ] # base case for CTL*
    elif isinstance(f, UnCon) and f.op == "!" and isinstance(f.arg, ConfigMapAtom):
      return [ (top, {f}), (bot, {f.arg}) ] # negated atom for CTL*
    elif isinstance(f, BinCon) and f.op == "&&": # conjunction
      return and2(self.delta(f.left), self.delta(f.right))
    elif isinstance(f, BinCon) and f.op == "||": # disjunction
      return or2(self.delta(f.left), self.delta(f.right))
    elif isinstance(f, UnCon) and f.op == "X": # next (assume no specific action)
      return [ (f.arg, {not_last}), (bot, {last}) ]
    elif isinstance(f, UnCon) and f.op == "Xw": # weak next
      return [ (f.arg, {not_last}), (top, {last}) ]
    elif isinstance(f, UnCon) and f.op == "F": # diamond
      return or2(self.delta(f.arg), self.delta(UnCon("X", f)))
    elif isinstance(f, UnCon) and f.op == "G": # box
      delta_last = [ (top, {last}), (bot, {not_last}) ]
      return and2(self.delta(f.arg), \
        or2(self.delta(UnCon("X", f)), delta_last))
    elif isinstance(f, BinCon) and f.op == "U": # until
      return or2(self.delta(f.right), \
        and2(self.delta(f.left), self.delta(UnCon("X", f))))
    else:
      print("unknown formula " + str(f))
      # raise explicitly: a bare assert is stripped under -O and the method
      # would then fall through and return None
      raise AssertionError("unknown formula " + str(f))


  def sat(self, cs):
    """Tell whether the conjunction of the constraints `cs` is satisfiable.

    Returns True without consulting the solver as soon as `cs` contains a
    ConfigMapAtom. Otherwise the constraints are translated with `toSMT`,
    conjoined, and the result is True iff the solver's `check_sat` returns
    something other than None.
    """
    if any( isinstance(c, ConfigMapAtom) for c in cs):
      return True # do not check for now; does not harm soundness
    phi = self._solver.land([ c.toSMT(self._solver, self._vars) for c in cs ])
    # identity test: immune to solver result types that overload comparison
    return self._solver.check_sat(phi) is not None

  def combine(self, op, xs, ys):
    """Combine two successor lists pairwise under operator `op`.

    For every pair `(x, s)` from `xs` and `(y, t)` from `ys` whose joint
    label `s | t` is satisfiable, the formulas are merged with
    `simp_bincon` and the label is cleaned with `simp_sym_set`; pairs whose
    cleaned label is None are dropped.
    """
    merged = []
    for (x, s) in xs:
      for (y, t) in ys:
        if not self.sat(s.union(t)):
          continue
        formula = simp_bincon(x, op, y)
        label = simp_sym_set(s.union(t))
        if label != None:
          merged.append((formula, label))
    return merged

  def initial_state_id(self):
    """Return the id of the initial state, which is always 0."""
    return 0
    
  def final_state_ids(self):
    """Return the list of final state ids (assigned by build/determinize)."""
    return self._final_state_ids

  def out(self, sid):
    """Return `(target, labels)` for every edge leaving state `sid`."""
    outgoing = []
    for (edge, labels) in self._edges.items():
      (source, target) = edge
      if source == sid:
        outgoing.append((target, labels))
    return outgoing

  def get_alphabet_with_negations(self, only_current = False):
    """Return all symbols built from the formula's constraints.

    Every symbol is a set holding, for each considered constraint, either
    the constraint or its negation (`negate()`), so the result has one
    entry per combination. With `only_current`, constraints having a
    variable whose `is_back` is true are left out.
    """
    #TODO could be filtered for satisfiability
    pairs = []
    for c in self._formula.get_constraints():
      if only_current and any( v.is_back for v in c.vars() ):
        continue
      pairs.append((c, c.negate()))

    # extend from the last constraint to the first; the positive block
    # always precedes the negative one
    symbols = [set([])]
    for (c, cneg) in reversed(pairs):
      positive = [ s.union({c}) for s in symbols ]
      negative = [ s.union({cneg}) for s in symbols ]
      symbols = positive + negative
    return symbols

  def determinize(self):
    """Return a deterministic automaton obtained by subset construction.

    DFA states are sets of state ids of this automaton. The initial subset
    is expanded over the alphabet restricted to current-state constraints
    (`only_current=True`), all other subsets over the full alphabet. A
    subset other than the initial one is final if it contains a final
    state of this automaton.
    """
    alph = self.get_alphabet_with_negations()
    alph_curr = self.get_alphabet_with_negations(only_current=True)
    dfa = Automaton()
    dfa._formula = self._formula

    def state_name(state):
      # str() of a set depends on its internal order ({0, 8} vs {8, 0}), so
      # build the name from the sorted ids to get one name per subset
      if not state:
        return str(state)
      return "{" + ", ".join(str(q) for q in sorted(state)) + "}"

    new_init_state = { self.initial_state_id() }
    ids_for_state_names = { state_name(new_init_state) : 0}
    dfa._final_state_ids = [] # first state is never final
    expanded = set() # names of subsets whose outgoing edges were computed

    def progress(curr_state, only_current = False):
      curr_state_name = state_name(curr_state)
      if curr_state_name in expanded:
        return
      expanded.add(curr_state_name)
      curr_state_id = ids_for_state_names[curr_state_name]
      dfa._states[curr_state_id] = curr_state_name
      out = [ (t,lab) for (s,t) in self._edges for lab in self._edges[(s,t)] \
        if s in curr_state ]
      alphabet = alph_curr if only_current else alph
      for sym in alphabet:
        reached = { t for (t, l) in out if self.sat(l.union(sym)) }
        reached_name = state_name(reached)
        if reached_name not in ids_for_state_names:
          target_id = len(ids_for_state_names)
          ids_for_state_names[reached_name] = target_id
          if any( s in self.final_state_ids() for s in reached):
            dfa._final_state_ids.append(target_id)
        else:
          target_id = ids_for_state_names[reached_name]
        # hand over a copy: the alphabet symbols are reused for every state
        # and must not be altered by label merging
        dfa.add_edge({"source": curr_state_id, "target": target_id, \
          "label": set(sym)})
        progress(reached)

    progress(new_init_state, only_current=True)
    return dfa

  def to_dds(self):
    """Convert this automaton into a DDS object.

    Variables are the formula's variables (one entry per base name, initial
    value 0), states mirror `self._states`, and every label of every edge
    becomes one transition guarded by the conjunction of the label; each
    transition lists all variables as written.
    """
    variables = []
    for v in self._formula.vars():
      known_names = [ x["name"] for x in variables ]
      if v.basename() not in known_names:
        variables.append({"name": v.basename(), "type": v._type, "initial": 0})

    states = [
      { "id": sid, "name": n, \
        "initial": sid == self.initial_state_id(), \
        "final": sid in self.final_state_ids() }
      for (sid, n) in self._states.items()
    ]

    trans = []
    for ((s,t), ls) in self._edges.items():
      for l in ls:
        written = [ x["name"] for x in variables ]
        trans.append({"source": s, "target": t, "guard": mk_conj(l), \
          "name": "a" + str(len(trans)), "written": written })

    dds_array = {"states": states, "transitions": trans, \
      "variables": variables, "name": "DFA for " + str(self._formula) }
    return DDS(dds_array)


  def export_json(self, filename='nfa.json'):
    """Write the automaton as JSON to `filename` (default 'nfa.json').

    The document has the keys "nodes" (state descriptions, indexed by state
    id) and "arcs" (one entry per edge with "source", "target" and the
    rendered "label"). State ids are used as list indices, so they are
    expected to be 0 .. len(states)-1.
    """
    # imported here because the file has no explicit top-level json import
    import json

    nodes = ["" for _ in range(0, len(self._states))]
    for (i, s) in self._states.items():
      nodes[i] = str(s)

    arcs = [ {"source": s, "target": t, "label": symsstr(ls)} \
      for ((s,t), ls) in self._edges.items() ]

    graph = { "nodes": nodes, "arcs": arcs }
    with open(filename, 'w') as f:
      json.dump(graph, f, indent=2)

  def show(self, filename):
    """Draw the automaton with Graphviz ('dot' layout) into `filename`.

    Edges are labelled with their rendered label sets, states are boxes
    labelled "q<id>: <formula>", and final states are drawn filled.
    """
    g = nx.MultiDiGraph()
    # register every state explicitly: a state without any edge would
    # otherwise be missing from the graph and the node lookup below fails
    g.add_nodes_from(self._states)
    g.add_edges_from(list(self._edges))
    # NOTE(review): these networkx drawing calls look unrelated to the
    # Graphviz output produced below; kept as they were - confirm whether
    # they are still needed
    pos = nx.spring_layout(g)
    nx.draw_networkx_nodes(g, pos, node_size = 100)
    nx.draw_networkx_edges(g, pos, arrows=True)

    A = to_agraph(g)
    A.node_attr['fontname'] = "Arial"
    A.edge_attr['fontname'] = "Arial"
    A.node_attr['fontsize'] = "11"
    A.edge_attr['fontsize'] = "10"
    A.edge_attr['arrowsize'] = "0.6"

    for ((s,t), ls) in self._edges.items():
      edge = A.get_edge(s,t)
      edge.attr['label'] = " " + symsstr(ls)
      if s == t:
        # attach self loops to fixed ports on the node
        edge.attr['tailport'] = "ne"
        edge.attr['headport'] = "se"
    for (i,psi) in self._states.items():
      n = A.get_node(i)
      n.attr['shape']='box'
      n.attr['margin']="0.1,0.005"
      n.attr['height']="0.3"
      n.attr['label'] = "q" + str(i) + ": " + str(psi)
      if i in self.final_state_ids():
        n.attr['style']='filled'

    A.layout('dot')
    A.draw(filename)

  def latex(self):
    """Print TikZ code for the automaton: a node per state, a draw per edge."""
    node_fmt = "\\node[state %s] (%i) {$q_%i\\colon%s$};"
    edge_fmt = "\\draw (%i) -- node[action %s] {$%s$} (%i);"
    # textual substitutions, applied to edge labels in exactly this order
    substitutions = (
      ("last", "\\lambda"),
      ("!", "\\neg"),
      ("(", ""),
      (")", ""),
      ("{", "\\{"),
      ("}", "\\}"),
      ("<=", "\\,{\\leq}\\,"),
      ("<", "\\,{<}\\,"),
    )

    for (i,psi) in self._states.items():
      if i in self.final_state_ids():
        style = ", final"
      else:
        style = ""
      print(node_fmt % (style, i, i, psi))

    for ((s,t), ls) in self._edges.items():
      if s == t:
        style = ", loop"
      else:
        style = ""
      label = symsstr(ls)
      for (old, new) in substitutions:
        label = label.replace(old, new)
      print(edge_fmt % (s, style,label, t))