from sys import maxsize
from functools import reduce
import networkx as nx
from networkx.drawing.nx_agraph import to_agraph
from copy import deepcopy

from dds.util import expr_replace
from dds.read import VarType
from dds.expr import top, Bool, Cmp, BinOp, BinCon, UnCon, Var, Num
from dds.expr import VarFlipper, mk_var
from verification.abstraction_equivalence import *


def conjuncts(e):
  """Flatten nested "&&" conjunctions of ``e`` into a list of conjuncts.

  The result lists the non-"&&" sub-expressions from left to right. An
  expression that is not a BinCon with operator "&&" yields ``[e]``.
  """
  flat = []
  pending = [e]
  while pending:
    node = pending.pop()
    if isinstance(node, BinCon) and node.op == "&&":
      # push right first so the left operand is expanded first
      pending.append(node.right)
      pending.append(node.left)
    else:
      flat.append(node)
  return flat

def expr_constants(e):
  """Return the set of numeric literals (``Num.num``) occurring in ``e``.

  Both operands of BinCon, BinOp and Cmp nodes are searched, as is the
  argument of UnCon nodes. Any other node kind contributes nothing.
  """
  if isinstance(e, Num):
    return {e.num}
  if isinstance(e, (BinCon, BinOp, Cmp)):
    return expr_constants(e.left) | expr_constants(e.right)
  if isinstance(e, UnCon):
    return expr_constants(e.arg)
  return set()

def val_str(val):
  """Render a valuation as ``"x = 1, y = 2"``.

  ``val`` is an iterable of ``(variable name, value)`` pairs. Both parts are
  converted with ``str``. An empty valuation yields the empty string.
  (The previous implementation stripped only the comma of the leading ", "
  separator and so left a stray leading space.)
  """
  return ", ".join(str(v) + " = " + str(n) for (v, n) in val)

class DDS:
  """DDS built from its dictionary representation.

  ``dds_as_array`` is expected to provide the keys "name", "states",
  "variables" and "transitions".
  - States are dicts with (at least) "id", "name", "initial" and "final".
  - Variables are dicts with "name", "type" and "initial".
  - Transitions are dicts with "source", "target", "guard" (an expression
    object), "written" and "name".
  """

  def __init__(self, dds_as_array):
    self._name = dds_as_array["name"]
    # maps state id to the state dict
    self._states = {s["id"]: s for s in dds_as_array["states"]}
    self._variables = dds_as_array["variables"]
    self._transitions = dds_as_array["transitions"]
    self._has_lookback = False
    res, msg = self.sanity_check()
    if not res:
      print("error when processing input DDS:\n" + msg)
      # invalid input is an error: terminate with a non-zero exit status
      # (raising SystemExit directly does not depend on the `site` builtin)
      raise SystemExit(1)

  def name(self):
    return self._name

  def states(self):
    """Return the dict mapping state ids to state dicts."""
    return self._states

  def transitions(self):
    return self._transitions

  def variables(self):
    return self._variables

  def initial_states(self):
    """Return the state dicts whose "initial" flag is set."""
    return [s for s in self._states.values() if s["initial"]]

  def final_state_ids(self):
    """Return the ids of the states whose "final" flag is set."""
    return [s["id"] for s in self._states.values() if s["final"]]

  def out(self, sid):
    """Return the transitions whose source is state ``sid``."""
    return [t for t in self._transitions if t["source"] == sid]

  def ins(self, sid):
    """Return the transitions whose target is state ``sid``."""
    return [t for t in self._transitions if t["target"] == sid]

  def set_lookback_mode(self):
    """Switch to lookback mode.

    In lookback mode, guards refer to the previous value of a variable x as
    ``x-`` rather than to its next value as ``x'``.
    """
    self._has_lookback = True

  def has_lookback(self):
    return self._has_lookback

  def sanity_check(self):
    """Check that the DDS is well-formed.

    Checks that every transition's source and target are declared states and
    that every guard only uses declared variables (guards violating the
    latter are printed). Returns ``(ok, error message)``.
    """
    trans = self._transitions
    is_state = lambda i: i in self._states
    if not all(is_state(t["source"]) and is_state(t["target"]) for t in trans):
      return False, "source or target of transition undefined"
    var_names = [v["name"] for v in self._variables]
    bad = [t for t in trans if not t["guard"].basevars().issubset(var_names)]
    if bad:
      for t in bad:
        print(t["guard"], t["guard"].basevars(), var_names)
      return False, "variables in transition undefined"
    return True, ""

  def reachable_from(self, states):
    """Return the set of state ids reachable from ``states`` (inclusive)."""
    reach = set(states)
    worklist = list(reach)
    while worklist:
      sid = worklist.pop()
      for t in self.out(sid):
        target = t["target"]
        if target not in reach:
          reach.add(target)
          worklist.append(target)
    return reach

  def is_mc_system(self, cs):
    """Check whether the system satisfies the ``is_mc`` criterion.

    True iff all variables have type rat or real and every non-trivial
    conjunct of the transition guards and of the extra constraints ``cs``
    (a list of expressions) satisfies ``is_mc()``.
    """
    vt = all(v["type"] in [VarType.rat, VarType.real] for v in self._variables)
    csx = [t["guard"] for t in self._transitions] + cs
    csx = [c for constr in csx for c in conjuncts(constr) if c != top]
    mcs = all(c.is_mc() for c in csx)
    return vt and mcs

  def is_gc_system(self, cs):
    """Check whether every guard conjunct has one of the accepted shapes.

    The non-trivial conjuncts of the guards and of ``cs`` must each be a
    comparison of one of these forms:
    - ``x - y >= c``;
    - ``c <= x - y``;
    - a comparison between a variable and a variable or number.

    Returns ``(True, cutoff)`` with ``cutoff`` the largest numeric constant
    (truncated to int, at least 0), or ``(False, None)``.
    """
    csx = [t["guard"] for t in self._transitions] + cs
    csx = [c for constr in csx for c in conjuncts(constr) if c != top]

    def is_var(e):
      return isinstance(e, Var)

    def is_num(e):
      return isinstance(e, Num)

    def is_var_diff(e):
      return isinstance(e, BinOp) and e.op == "-" and is_var(e.left) and \
        is_var(e.right)

    def is_gc(c):
      if not isinstance(c, Cmp):
        return False
      l, r = c.left, c.right
      if (is_num(r) and is_var_diff(l) and c.op == ">=") or \
        (is_num(l) and is_var_diff(r) and c.op == "<="):
        return True
      return (is_var(l) and (is_var(r) or is_num(r))) or \
        (is_var(r) and (is_var(l) or is_num(l)))

    def cutoff(gs):
      consts = [int(n) for c in gs for n in expr_constants(c)]
      return max(consts + [0])

    if all(is_gc(g) for g in csx):
      return True, cutoff(csx)
    return False, None

  def init_values(self, solver):
    """Return the variables' initial values as solver terms.

    Values are listed in the order of ``self._variables``: Bool for bool
    variables, Num otherwise.
    """
    def val(v):
      v0 = v["initial"]
      return Bool(v0) if v["type"] == VarType.bool else Num(v0)
    return [val(v).toSMT(solver, []) for v in self._variables]

  def init_val_constraints(self, solver, smtvars):
    """Return equalities binding ``smtvars`` to the initial values.

    ``smtvars`` must be in the order of ``self._variables``.
    """
    varvals = zip(smtvars, self.init_values(solver))
    return [solver.eq(var, val) for (var, val) in varvals]

  def has_bounded_lookback(self, cs, solver, k):
    """Check that no variable's dependency depth ("lookback") exceeds ``k``.

    Explores executable paths of length up to ``n + k*n`` (``n`` = number of
    states). A transition is followed only if its guard, together with the
    path constraint and inertia of unwritten variables, is satisfiable.
    Each (state, dependency vector) pair is expanded once.

    Returns False as soon as some dependency exceeds ``k``, else True.
    """
    vs = self._variables
    vs_names = [v["name"] for v in vs]
    vdeps0 = {name: 0 for name in vs_names}
    # has_lookback is a method: it must be called (it was previously tested
    # as an attribute, which is always truthy)
    lookback = self.has_lookback()

    # new dependency depth of variable v after a transition with guard
    # conjuncts gs (plus cs) that writes the variables in vwritten
    def ext_path(ctx, v, vwritten):
      (gs, vdeps) = ctx
      # v is written without occurring in any constraint: no dependencies
      if v in vwritten and not any(v in [x.name for x in g.vars()] for g in gs):
        return 0
      if lookback:
        writes_v = lambda g: v in [x.name for x in g.vars() if not x.is_back]
      else:
        writes_v = lambda g: v in [x.name for x in g.vars() if x.is_prime]
      writers = [g for g in gs if writes_v(g)]
      if not writers:
        return vdeps[v]

      # otherwise v gets written: depth determined by the variables read
      def max_dep(g):
        is_eqlit = isinstance(g, Cmp) and g.op == "==" and \
          isinstance(g.left, Var) and isinstance(g.right, Var)
        if lookback:
          rvars = [x.name for x in g.vars() if x.is_back]
        else:
          rvars = [x.name for x in g.vars() if not x.is_prime]
        if not rvars:
          return 0
        if is_eqlit:  # shape v^w = y^r: plain copy, depth unchanged
          return vdeps[rvars[0]]
        return max(vdeps[x] for x in rvars) + 1

      return max(max_dep(g) for g in writers)

    n = len(self._states)
    # depth n covers the stem; in the worst case each increase of the
    # dependency depth by 1 requires visiting all states, hence k * n more
    bound = n + k * n

    # FIXME similar code as in witness search
    var = lambda v, i: mk_var(solver, Var.from_array(v), suffix=str(i))
    vars_stages = [{v["name"]: var(v, i) for v in vs} for i in range(bound + 1)]
    # path constraint initialised with the initial assignment
    vars0 = [vars_stages[0][x] for x in vs_names]
    phi = solver.land(self.init_val_constraints(solver, vars0))

    tiles = [(s["id"], vdeps0, phi) for s in self.initial_states()]
    visited = set()
    for i in range(1, bound + 1):
      if lookback:
        subst_back = [(x + "-", v) for (x, v) in vars_stages[i-1].items()]
        subst_guard = dict(list(vars_stages[i].items()) + subst_back)
      else:
        # lookahead or none
        subst_prime = [(x + "'", v) for (x, v) in vars_stages[i].items()]
        subst_guard = dict(list(vars_stages[i-1].items()) + subst_prime)

      next_tiles = []
      for (s, vdeps, phi) in tiles:
        key = (s, tuple(vdeps[v] for v in vs_names))
        if key in visited:
          continue
        visited.add(key)

        for t in self.out(s):
          g = t["guard"]
          guard = g.toSMT(solver, subst_guard)
          inertia = [solver.eq(vars_stages[i-1][v], vars_stages[i][v])
                     for v in vs_names if v not in t["written"]]
          phix = solver.land([guard, phi] + inertia)  # updated path constraint
          if solver.check_sat(phix) is None:
            continue  # skip: this transition cannot be executed

          ctx = (conjuncts(g) + cs, vdeps)
          vdepsx = {v: ext_path(ctx, v, t["written"]) for v in vdeps}
          if any(p > k for p in vdepsx.values()):
            return False
          next_tiles.append((t["target"], vdepsx, phix))
      tiles = next_tiles
    return True

  def project_constraint(self, constr, vs):
    """Return the conjunction of the conjuncts of ``constr`` mentioning ``vs``.

    Conjuncts that share no base variable with ``vs`` are dropped. The
    result is a left-nested "&&" chain starting with ``top``.
    """
    cs = conjuncts(constr)
    vcs = [c for c in cs if len(set(c.basevars()).intersection(vs)) > 0]
    return reduce(lambda e1, e2: BinCon(e1, "&&", e2), vcs, top)

  def project2vars(self, vs):
    """Project the DDS to the variables ``vs``.

    Guards are restricted with project_constraint. The lookback mode is
    inherited, and state dicts are copied so the projection does not alias
    this DDS's states.
    """
    def project_guard(t):
      t_projected = dict(t)
      t_projected["guard"] = self.project_constraint(t["guard"], vs)
      return t_projected

    ts = [project_guard(t) for t in self._transitions]
    dds = {
      "name": self._name + " vars projection " + str(vs),
      "variables": [v for v in self._variables if v["name"] in vs],
      "states": [dict(s) for s in self._states.values()],
      "transitions": ts
    }
    projected = DDS(dds)
    if self._has_lookback:
      projected.set_lookback_mode()
    return projected

  def project2states(self, s):
    """Project the DDS to the set of state ids ``s``.

    The projection keeps only the transitions between states in ``s`` and
    only the variables occurring in their guards. State dicts are copied and
    the lookback mode is inherited.
    """
    ts = [t for t in self._transitions if t["source"] in s and t["target"] in s]
    vs = set(v for t in ts for v in t["guard"].basevars())
    dds = {
      "name": self._name + " state projection " + str(s),
      "variables": [v for v in self._variables if v["name"] in vs],
      "states": [dict(self._states[sid]) for sid in s],
      "transitions": ts
    }
    projected = DDS(dds)
    if self._has_lookback:
      projected.set_lookback_mode()
    return projected

  def separated_var_subsets(self, cs):
    """Group the variable names into sets that never share a constraint.

    Two variables end up in the same set if they are connected through
    conjuncts (of guards or of ``cs``) in which they occur together. Returns
    the list of distinct sets.
    """
    vs = {v["name"]: set([v["name"]]) for v in self._variables}
    gconjs = [c for t in self._transitions for c in conjuncts(t["guard"])] + cs
    for c in gconjs:
      xs = reduce(lambda a, b: a.union(b),
                  [vs[v] for v in c.basevars() if v in vs], set())
      for v in xs:
        vs[v] = xs

    partitioning = []
    for group in vs.values():
      if group not in partitioning:
        partitioning.append(group)
    return partitioning

  def has_finite_summary(self, cs, solver, with_gc=True):
    """Try to recognise that the DDS admits a finite summary.

    Criteria, tried in order:
    1. MC system;
    2. bounded lookback with bound ``2 * #variables + 1``;
    3. GC system (only if ``with_gc``);
    4. sequential state splitting;
    5. variable splitting.

    Returns ``(found, equivalence, message)``; ``(False, None, "")`` if no
    criterion applies.
    """
    if self.is_mc_system(cs):
      return True, Equivalence(solver), "MC system"
    ff_bnd = 2 * len(self._variables) + 1  # bound for feedback freedom
    if self.has_bounded_lookback(cs, solver, ff_bnd):
      return True, Equivalence(solver), str(ff_bnd) + "-bounded lookback"
    if with_gc:
      (is_gc, cutoff) = self.is_gc_system(cs)
      if is_gc:
        return True, GCEquivalence(solver, cutoff), \
          "GC (cutoff " + str(cutoff) + " )"

    result = self._try_state_splitting(cs, solver)
    if result is not None:
      return result
    result = self._try_var_splitting(cs, solver, with_gc)
    if result is not None:
      return result
    return False, None, ""

  def _try_state_splitting(self, cs, solver):
    """Try a sequential split into two state sets sharing one cut state.

    Returns a has_finite_summary result tuple on success, else None.
    """
    ss = list(self._states.values())
    for i in range(len(ss) - 1, 1, -1):
      ss1 = [s["id"] for s in ss[0:i]]
      ss2 = [s["id"] for s in ss[i-1:]]
      cut = ss2[0]  # the state shared by both parts
      # do not cut on loop states, and do not allow edges back into ss1
      breach1 = self.reachable_from([cut]).difference({cut})
      breach2 = self.reachable_from(set(ss2).difference(ss1))
      backreach = breach1.union(breach2).intersection(ss1)
      # edges from ss1 to ss2 that bypass the cut state
      skipedges = [t for t in self._transitions
                   if t["source"] in ss1 and t["source"] != cut
                   and t["target"] in ss2 and t["target"] != cut]
      if backreach or skipedges:
        continue
      dds1 = self.project2states(ss1)
      dds2 = self.project2states(ss2)
      dds2._states[cut]["initial"] = True
      (fs1, equiv, msg1) = dds1.has_finite_summary(cs, solver, with_gc=False)
      (fs2, _, msg2) = dds2.has_finite_summary(cs, solver, with_gc=False)
      if fs1 and fs2:
        msg = " " + self._name + " admits state-splitting:\n" + msg1 + msg2
        return True, equiv, msg  # equivalence is logical equivalence anyway
    return None

  def _try_var_splitting(self, cs, solver, with_gc):
    """Try splitting the variables into independent groups.

    Returns a has_finite_summary result tuple on success, else None.
    """
    if len(self._variables) <= 1:
      return None
    parts = self.separated_var_subsets(cs)
    if len(parts) <= 1:
      return None
    results = []
    for part in parts:
      sub = self.project2vars(part)
      cs_proj = [self.project_constraint(c, part) for c in cs]
      results.append(sub.has_finite_summary(cs_proj, solver, with_gc=with_gc))
    if not all(r[0] for r in results):  # every part needs a finite summary
      return None
    msg = " " + self._name + " admits variable-splitting:\n"
    for r in results:
      msg += " " + r[2]
    equiv = DecompositionEquivalence(solver, parts, [r[1] for r in results])
    return True, equiv, msg

  def check_finite_summary(self, cs, solver, verbose, with_gc=True):
    """Run has_finite_summary and optionally print the outcome.

    Returns ``(found, equivalence)``.
    """
    (fin_summ, equiv, msg) = self.has_finite_summary(cs, solver, with_gc=with_gc)
    if verbose:
      if fin_summ:
        print("%s has finite summary:\n%s" % (self._name, msg))
      else:
        print("Finite summary not recognized.")
    return (fin_summ, equiv)

  def invert(self):
    """Return the DDS with all transitions reversed.

    - Initial and final flags of states are swapped.
    - Source and target of each transition are exchanged.
    - Guards of transitions that write variables are visited by a
      VarFlipper.
    - Transitions are deep-copied first, so this DDS is left unchanged.
    """
    var_inverter = VarFlipper()

    def flip_state(s):
      return {"id": s["id"], "name": s["name"], "initial": s["final"],
              "final": s["initial"]}

    def flip_transition(t):
      t = deepcopy(t)
      t["source"], t["target"] = t["target"], t["source"]
      if len(t["written"]) > 0:
        t["guard"].accept(var_inverter)
      return t

    dds = {
      "name": self._name,
      "states": [flip_state(s) for s in self._states.values()],
      "transitions": [flip_transition(t) for t in self._transitions],
      "variables": self._variables}
    return DDS(dds)

  def hackstates(self, k, source=5, target=6):
    """Debug helper: dictionary of a copy with a lengthened path.

    The transitions from state ``source`` to ``target`` are replaced by a
    chain of ``k`` dummy states connected by unguarded transitions. Guards
    are serialised with ``str``; variable types with ``VarType.to_str``.
    This DDS is not modified: states, transitions and variables are copied.
    """
    print("hack states", k)
    trans = [dict(t) for t in self._transitions
             if not (t["source"] == source and t["target"] == target)]

    states = {sid: dict(s) for (sid, s) in self._states.items()}
    lastid = source
    for i in range(0, k):
      new_id = len(states)
      assert new_id not in states
      states[new_id] = {"id": new_id, "name": "dummy" + str(new_id),
                        "initial": False, "final": False}
      trans.append({"id": len(trans), "source": lastid, "target": new_id,
                    "name": "tdummy" + str(i), "guard": top, "written": []})
      lastid = new_id

    trans.append({"id": len(trans), "source": lastid, "target": target,
                  "name": "tdummy" + str(k), "guard": top, "written": []})

    for t in trans:
      t["guard"] = str(t["guard"])

    variables = []
    for v in self._variables:
      vv = dict(v)
      vv["type"] = "bool" if vv["type"] == "bool" else VarType.to_str(vv["type"])
      variables.append(vv)

    dds = {
      "name": self._name + "_" + str(k),
      "states": list(states.values()),
      "transitions": trans,
      "variables": variables}
    return dds

  def obfuscate(self, k, g):
    """Debug helper: return ``g`` with one comparison lengthened.

    The comparison is the one reached by descending into BinCon left
    operands and UnCon arguments. ``l op r`` becomes
    ``l op v0' && v0' op' v1' && ... && vk' op' r``. Here ``vi'`` are
    variables named "v<type><i>'" of the type of ``l``.

    ``op'`` is chosen as follows:
    - "<" / ">" for rat comparisons;
    - ">=" / "<=" unchanged;
    - "==" otherwise.

    Bool leaves are returned unchanged. Other node kinds raise
    AssertionError.
    """
    def obfuscate_cmp(c):
      t = c.left._type
      tstr = VarType.to_str(t)
      var = lambda i: Var("v" + tstr + str(i) + "'", t)
      e = Cmp(c.op, c.left, var(0))
      if c.op in ("<", ">") and t == VarType.rat:
        op = c.op
      elif c.op in (">=", "<="):
        op = c.op
      else:
        op = "=="
      for i in range(1, k + 1):
        e = BinCon(e, "&&", Cmp(op, var(i - 1), var(i)))
      return BinCon(e, "&&", Cmp(op, var(k), c.right))

    if isinstance(g, BinCon):
      # NOTE(review): only the left operand is obfuscated; presumably because
      # obfuscating several comparisons would reuse the same fresh variables
      return BinCon(self.obfuscate(k, g.left), g.op, g.right)
    elif isinstance(g, UnCon):
      return UnCon(g.op, self.obfuscate(k, g.arg))
    elif isinstance(g, Cmp):
      return obfuscate_cmp(g)
    elif isinstance(g, Bool):
      return g
    print(g)
    # explicit raise (not `assert`) so the check also works under python -O
    raise AssertionError("cannot obfuscate expression: " + str(g))

  def hackvars(self, k):
    """Debug helper: dictionary of a copy with extra variables.

    For each variable type present, ``k + 1`` extra variables named
    "v<type><i>" are added. Every guard is obfuscated with
    ``obfuscate(k, .)`` and serialised with ``str``. This DDS is not
    modified.
    """
    trans = deepcopy(self._transitions)

    variables = []
    for v in self._variables:
      vv = deepcopy(v)
      vv["type"] = VarType.to_str(vv["type"])
      variables.append(vv)

    types = set([v["type"] for v in self._variables])
    print(types)
    for t in types:
      tstr = VarType.to_str(t)
      for i in range(0, k + 1):
        variables.append({"name": "v" + tstr + str(i), "type": tstr,
                          "initial": 0 if t != VarType.bool else False})

    for t in trans:
      t["guard"] = str(self.obfuscate(k, t["guard"]))

    dds = {
      "name": self._name + "_" + str(k),
      "states": list(self._states.values()),
      "transitions": trans,
      "variables": variables}
    return dds

  def show_run(self, start_state, run):
    """Print a run.

    The run is printed as the start state with the initial valuation,
    followed by one line per step. Each step in ``run`` is a tuple
    ``(transition label, target state dict, valuation)``.
    """
    val0 = [(v["name"], v["initial"]) for v in self._variables]
    print(start_state["name"] + " (" + val_str(val0) + ")")
    for (a, b, val) in run:
      print("  - [" + a + "] ->  " + b["name"] + " (" + val_str(val) + " ) ")

  def show(self, filename):
    """Render the control-flow graph of the DDS to ``filename`` via Graphviz.

    States are drawn as boxes labelled with their names; final states are
    filled gray. Edges are labelled with the transition name and, unless it
    is "True", the guard as rendered by ``expr_replace``. The layout engine
    is ``dot``.
    """
    graph = nx.MultiDiGraph()
    graph.add_nodes_from(self._states.keys())
    graph.add_edges_from([(t["source"], t["target"]) for t in self._transitions])

    A = to_agraph(graph)
    A.node_attr['fontsize'] = "11"
    A.edge_attr['fontsize'] = "10"
    A.node_attr['fontname'] = "Arial"
    A.edge_attr['fontname'] = "Arial"
    A.edge_attr['arrowsize'] = "0.6"

    # parallel edges get keys 0, 1, 2, ... in insertion order in the
    # MultiDiGraph; count them the same way to find each transition's edge
    edge_idx = {}
    for t in self._transitions:
      key = (t["source"], t["target"])
      edge_idx[key] = edge_idx.get(key, -1) + 1
      edge = A.get_edge(t["source"], t["target"], edge_idx[key])
      guard = expr_replace(str(t["guard"]))
      edge.attr['label'] = t["name"] + (": " + guard if guard != "True" else "")
      if t["source"] == t["target"]:
        edge.attr['tailport'] = "ne"
        edge.attr['headport'] = "se"
    for (sid, s) in self._states.items():
      node = A.get_node(sid)
      node.attr['shape'] = 'box'
      node.attr['margin'] = "0.1,0.005"
      node.attr['label'] = s["name"]
      node.attr['height'] = "0.3"
      if s["final"]:
        node.attr['style'] = 'filled'
        node.attr['fillcolor'] = 'gray'

    A.layout('dot')
    A.draw(filename)
