from sys import maxsize
from functools import reduce
import json

from dds.expr import top, Var, Cmp, Num, BinCon
from dds.dds import DDS
from dds.read import VarType

class DPN:

  def __init__(self, dpn_as_array):
    """Build a DPN from its dictionary representation.

    The ids of places and transitions in dpn_as_array are replaced in place
    by integers (see mk_ids_integer) before the parts of the net are stored.
    """
    net, _originals = self.mk_ids_integer(dpn_as_array)
    self.net = net
    self._places, self._transitions = net["places"], net["transitions"]
    self._variables, self._arcs = net["variables"], net["arcs"]
    # preprocessing steps that are currently switched off:
    #self.add_silent_finals(_originals)
    #self.compute_shortest_paths()
    # cached result of has_single_token(), None as long as it is not computed
    self.has1token = None
  
  def places(self):
    """Return the list of places of the net (the stored list, not a copy)."""
    return self._places
  
  def transitions(self):
    """Return the list of transitions of the net (the stored list, not a copy)."""
    return self._transitions
  
  def arcs(self):
    """Return the list of arcs of the net (the stored list, not a copy)."""
    return self._arcs
  
  def variables(self):
    """Return the variables of the net (the stored object, not a copy)."""
    return self._variables

  # replace ids of places and transitions by unique integers
  def mk_ids_integer(self, dpn):
    """Replace the ids of places and transitions by unique integers.

    Places are numbered first, then transitions, both in list order and
    starting at 0. Source and target of every arc are rewritten to the new
    ids. The argument is modified in place.

    Returns the modified dpn together with a dict that maps every new integer
    id to a copy of the node as it was before renumbering.
    """
    old2new = {}
    originals = {}
    for kind in ("places", "transitions"):
      for node in dpn[kind]:
        old_id = node["id"]
        # one entry is added per node, so the size is the next free integer
        fresh = len(originals)
        originals[fresh] = dict(node)
        node["id"] = fresh
        old2new[old_id] = fresh
    for arc in dpn["arcs"]:
      arc["source"] = old2new[arc["source"]]
      arc["target"] = old2new[arc["target"]]
    return dpn, originals

  def to_DDS1(self):
    """Translate the net into a DDS with one state per place.

    Every place becomes a state; it is initial/final if the place has an
    "initial"/"final" entry. Every path place -> transition -> place becomes
    one DDS transition that carries the label, the constraint (top if the
    transition has none) and the written variables of the net transition.

    Returns the DDS built from the resulting dictionary.
    """
    net_transitions = {t["id"]: t for t in self._transitions}

    # successors of every node in arc order, so the arcs are scanned only once
    successors = {}
    for arc in self._arcs:
      successors.setdefault(arc["source"], []).append(arc["target"])

    variables = {}
    for v in self._variables:
      variables[v["name"]] = {
        "name": v["name"],
        "initial": v["initialValue"],
        "type": v["type"]
      }

    states = []
    trans = []
    for p in self._places:
      states.append({
        "id": p["id"],
        "name": p["name"],
        "initial": "initial" in p,
        "final": "final" in p
      })

      for tid in successors.get(p["id"], []):
        t = net_transitions[tid]
        guard = t["constraint"] if "constraint" in t else top
        # NOTE(review): "written" is read unconditionally here, although
        # var_write_reach treats it as optional - confirm it is always set
        for px in successors.get(tid, []):
          # the edge gets its own name: rebinding t would lose label and
          # written for the second and later output places of a transition
          edge = {
            "id": len(trans),
            "source": p["id"],
            "target": px,
            "name": t["label"],
            "guard": guard,
            "written": t["written"]
          }
          trans.append(edge)

    dds_array = {
      "name": self.net["name"],
      "states": states,
      # a real list, like in to_DDSn, instead of a dictionary view
      "variables": list(variables.values()),
      "transitions": trans
    }
    return DDS(dds_array)
  

  def to_DDSn(self):
    """Translate the net into a DDS whose states are reachable markings.

    A marking is a set of place ids. Every place with a truthy "initial"
    entry gives rise to its own initial marking that contains just this
    place. A transition is enabled in a marking if it has at least one input
    place and all its input places are marked; firing it removes the input
    places and adds the output places.

    Markings are told apart by their set of places. The name of a state is
    the concatenation of the names of its places in order of place id, so
    two states can carry the same name if place names are ambiguous.

    Returns the DDS built from the resulting dictionary.
    """
    places = self._places
    trans = {t["id"]: t for t in self._transitions}

    # inputs and outputs of every node in arc order, computed only once
    pre = {}
    post = {}
    for a in self._arcs:
      post.setdefault(a["source"], []).append(a["target"])
      pre.setdefault(a["target"], []).append(a["source"])

    def get_mname(m):
      # place ids are used as positions in the list of places
      return "".join(places[p]["name"] for p in sorted(m))

    inits = [p["id"] for p in places if "initial" in p and p["initial"]]

    # state id of every marking found so far; the initial markings come
    # first, hence their ids are smaller than len(inits)
    mids = {}
    frontier = []
    for pid in inits:
      m = frozenset([pid])
      if m not in mids:
        mids[m] = len(mids)
        frontier.append(m)

    edges = []
    while frontier:
      discovered = []
      for m in frontier:
        mid = mids[m]
        for tid in trans:
          inputs = pre.get(tid, [])
          if not inputs or not m.issuperset(inputs):
            continue # this transition is not enabled
          m_new = m.difference(inputs).union(post.get(tid, []))
          if m_new not in mids:
            mids[m_new] = len(mids)
            discovered.append(m_new)
          edges.append((mid, mids[m_new], tid))
      frontier = discovered

    place_by_id = {p["id"]: p for p in places}
    states = []
    for (m, sid) in mids.items():
      ps = [place_by_id[pid] for pid in m if pid in place_by_id]
      is_final = all("final" in p for p in ps)
      s = {
        "id": sid,
        "name": get_mname(m),
        "initial": sid < len(inits),
        "final": is_final
      }
      # superfinal is a superset of a final marking (unsoundness!)
      if not is_final and any("final" in p for p in ps):
        s["superfinal"] = True
      states.append(s)

    variables = [
      {"name": v["name"], "initial": v["initialValue"], "type": v["type"]}
      for v in self._variables
    ]

    transitions = []
    for (src, tgt, tid) in edges:
      t = trans[tid]
      guard = t["constraint"] if "constraint" in t else top
      transitions.append({
        "id": len(transitions),
        "source": src,
        "target": tgt,
        "name": t["label"],
        "guard": guard,
        "written": t["written"]
      })

    dds_array = {
      "name": self.net["name"],
      "states": states,
      "variables": variables,
      "transitions": transitions
    }
    return DDS(dds_array)

  def to_DDS(self):
    """Translate the net into a DDS.

    Uses to_DDS1 if has_single_token() holds and to_DDSn otherwise.
    """
    if self.has_single_token():
      return self.to_DDS1()
    return self.to_DDSn()

  ### add silent transition to one final place (without label and constraint)
  def add_silent_finals(self, map):
    """Attach a silent self-loop to the first place that has a "final" entry.

    The new transition has no label and no constraint, is marked invisible
    and gets the id len(map) + 1. It is appended to the transitions, arcs
    from and to the place are appended to the arcs, and the transition is
    entered into map under its id. Nothing happens without a final place.
    """
    fresh = len(map) + 1
    final_place = next((p for p in self._places if "final" in p), None)
    if final_place is None:
      return
    silent = {"id": fresh, "invisible": True, "label":None }
    self._transitions.append(silent)
    self._arcs.extend([
      {"source": final_place["id"], "target": fresh},
      {"target": final_place["id"], "source": fresh},
    ])
    map[fresh] = silent
  
  def is_acyclic(self, pid):
    """Check that node pid cannot be reached again from itself.

    The arcs are followed two at a time (node -> intermediate node -> node).
    Returns False as soon as pid shows up among the reached nodes, and True
    once a round does not reach any node that was not seen before.
    """
    frontier = [pid]
    seen = {pid}
    while frontier:
      middle = [a["target"] for a in self._arcs if a["source"] in frontier]
      reached = {a["target"] for a in self._arcs if a["source"] in middle}
      if pid in reached:
        return False
      if reached <= seen:
        return True
      frontier = reached
      seen = seen | reached
    return True

  # Parameter goals is a list of places.
  # Returns length of shortest path to some place in goals, starting from either
  # the initial places (if forward=True) or the final places (forward=False).
  def shortest_to(self, goals, forward = True):
    """Return the length of a shortest path to some place in goals.

    The search starts from the places with a truthy "initial" entry if
    forward is True, and follows the arcs in their direction. Otherwise it
    starts from the places with a truthy "final" entry and follows the arcs
    backwards. The length counts the places passed, i.e. one step leads from
    a place via a transition to the next place.

    goals is a list of place ids. Returns 0 if a start place is a goal, and
    1001 if no goal is reachable (also if there is no start place at all) or
    if the distance exceeds 1000.
    """
    # result that stands for "no reachable goal": the hack to infinity
    unreachable = 1001
    (src, tgt) = ("source", "target") if forward else ("target", "source")
    flag = "initial" if forward else "final"

    # neighbours in search direction, so the arcs are scanned only once
    succ = {}
    for arc in self._arcs:
      succ.setdefault(arc[src], []).append(arc[tgt])

    targets = set(goals)
    frontier = {p["id"] for p in self._places if flag in p and p[flag]}
    seen = set(frontier)
    dist = 0
    # breadth first search, every place is expanded at most once
    while frontier and dist < unreachable:
      if not targets.isdisjoint(frontier):
        return dist
      reached = set()
      for pid in frontier:
        for tid in succ.get(pid, []):
          reached.update(succ.get(tid, []))
      frontier = reached - seen
      seen |= frontier
      dist += 1
    return unreachable

  def shortest_accepted(self):
    """Return the length of a shortest path from an initial to a final place.

    The length is only used if has_single_token() holds; otherwise the fixed
    value 6 is returned as a placeholder.
    """
    final_ids = [p["id"] for p in self._places if "final" in p and p["final"]]
    length = self.shortest_to(final_ids)
    if not self.has_single_token():
      return 6 # FIXME
    return length

  # for every transition, compute the minimal distance to an initial/final place
  def compute_shortest_paths(self):
    """Annotate every transition with its distances idist and fdist.

    idist is the result of shortest_to for the input places of the
    transition, searching forward from the initial places. fdist is the
    result of shortest_to for its output places, searching backward from the
    final places. Both values are stored in the transition dictionary.
    """
    arcs = self._arcs
    for trans in self._transitions:
      inputs = [a["source"] for a in arcs if a["target"] == trans["id"]]
      trans["idist"] = self.shortest_to(inputs)
      outputs = [a["target"] for a in arcs if a["source"] == trans["id"]]
      trans["fdist"] = self.shortest_to(outputs, forward=False)

  # assumes one-boundedness
  def simulate_markings(self, num_steps):
    """Collect, for each of num_steps steps, the transitions that may fire.

    One list of transitions per step is appended to self._reachable, which
    has to exist already (compute_reachable creates it). Every transition
    needs the entries idist and fdist (see compute_shortest_paths).

    As long as it is cheap enough, the markings are simulated step by step,
    and a transition is collected if it is enabled in some marking and its
    fdist is smaller than the number of remaining steps. After step 12, or
    with more than 22 states, the simulation stops and transitions are
    selected by their distances only.
    """
    transs = {t["id"]: t for t in self._transitions}
    idists = {t["id"]: t["idist"] for t in self.transitions()}
    fdists = {t["id"]: t["fdist"] for t in self.transitions()}

    # inputs and outputs of every node in arc order, computed only once
    pre = {}
    post = {}
    for a in self._arcs:
      post.setdefault(a["source"], []).append(a["target"])
      pre.setdefault(a["target"], []).append(a["source"])

    ps = [ p["id"] for p in self._places if "initial" in p ]
    states = [ (set(ps),[]) ] # pairs of current marking and transition history
    seen_acyclic = set()
    for i in range(0, num_steps):
      # number of remaining steps; it has to be refreshed in every round,
      # also in the cheap branch that no longer simulates markings
      rem = num_steps - i
      if i > 12 or len(states) > 22: # gets too expensive
        ts = [ t for (tid, t) in transs.items() if fdists[tid] < rem and i >= idists[tid] ]
        previous = { t["id"] for t in self._reachable[i-1] }
        # transitions that were found not to lie on a cycle and that were
        # not collected in the previous step are left out
        dropped = { tid for tid in seen_acyclic if tid not in previous }
        self._reachable.append([ t for t in ts if t["id"] not in dropped ])
      else:
        statesx = []
        self._reachable.append([])
        for (marking, steps) in states:
          for p in marking:
            for t in post.get(p, []):
              pre_t = pre.get(t, [])
              if not set(pre_t).issubset(marking):
                continue # this transition is not enabled, skip
              markingx = marking.difference(pre_t).union(post.get(t, []))
              statesx.append((markingx, steps + [t]))
              if not transs[t] in self._reachable[i] and fdists[t] < rem:
                self._reachable[i].append(transs[t])
                if self.is_acyclic(t):
                  seen_acyclic.add(t)
        # once the cheap branch is taken the states are kept as they are
        states = statesx


  def compute_reachable(self, num_steps):
    """Fill self._reachable with one list of transitions per step.

    If has_single_token() holds, the places that may carry the token are
    tracked step by step, starting from the places with an "initial" entry.
    The list for a step contains the transitions that leave one of these
    places and whose fdist is smaller than the number of remaining steps.
    Otherwise the work is delegated to simulate_markings.
    """
    self._reachable = []

    if not self.has_single_token():
      self.simulate_markings(num_steps)
      return

    fdist = {t["id"]: t["fdist"] for t in self.transitions()}
    by_id = {t["id"]: t for t in self._transitions}
    current = [p["id"] for p in self._places if "initial" in p]
    for step in range(num_steps):
      budget = num_steps - step
      fired = [a["target"] for a in self._arcs if a["source"] in current]
      self._reachable.append([by_id[t] for t in set(fired) if fdist[t] < budget])
      current = {a["target"] for a in self._arcs if a["source"] in fired}
      
  
  # set of transitions reachable within i steps
  def reachable(self, i):
    """Return the list of transitions collected for step i.

    The lists are created by compute_reachable, which has to run first.
    """
    per_step = self._reachable
    return per_step[i]
  
  # for every step i, collect the variables written by the transitions reachable in step i
  def var_write_reach(self):
    """Return, for every step, the variables written by reachable transitions.

    The result has one list per entry of self._reachable. Each list holds,
    without duplicates and in no particular order, the variables in the
    "written" entries of the transitions of that step; transitions without
    such an entry are ignored.
    """
    return [
      list({v for t in step if "written" in t for v in t["written"]})
      for step in self._reachable
    ]

  def has_single_token(self):
    """Check whether the net is treated as having a single token.

    This is the case if no place has an "initial" entry greater than 1 and
    no transition has more than one outgoing arc. The result is computed on
    the first call and cached in self.has1token; later calls return the
    cached value, no matter whether it is True or False.
    """
    # compare with None: a cached False is a valid answer as well
    if self.has1token is not None:
      return self.has1token

    many_initial = any("initial" in p and p["initial"] > 1 for p in self.places())
    # the transitions are only inspected if the places did not decide it
    self.has1token = not (many_initial or any(
      sum(1 for a in self._arcs if a["source"] == t["id"]) > 1
      for t in self._transitions))
    return self.has1token
