import json
import re
from xml.dom import minidom

import pyparsing as pyp

from dds.util import VarType
from dds.expr import Expr, Num, Var, Charstr, Cmp, BinOp, UnCon, BinCon, \
  PropVar, Bool, top

def unescape(s):
  """Decode the XML entities ``&lt;``, ``&gt;`` and ``&amp;`` in *s*.

  Replacements are applied in table order. ``&amp;`` must come last so that
  an escaped entity such as ``&amp;lt;`` decodes to the literal text ``&lt;``
  rather than being decoded a second time into ``<``.
  """
  for entity, char in (("&lt;", "<"), ("&gt;", ">"), ("&amp;", "&")):
    s = s.replace(entity, char)
  return s

# Turn on pyparsing's packrat memoization (recommended for the backtracking
# infixNotation grammars built below). This is called on the ParserElement
# class, so it is a process-wide pyparsing setting, not one scoped to this module.
pyp.ParserElement.enablePackrat()

def mkVar(vars, toks):
  """Build a ``Var`` node from the tokens of a parsed variable reference.

  Args:
    vars: mapping from declared variable name to its type.
    toks: token list from the variable rule; ``toks[0]`` is the name and an
      optional ``toks[1]`` is the prime marker.

  Returns:
    ``Var(name, type, prime)``, where ``prime`` is None when no marker was parsed.

  Raises:
    NameError: if ``toks[0]`` is not a key of ``vars``.
  """
  name = toks[0]
  if name not in vars:
    raise NameError("variable " + name + " is not declared in DDS")
  prime = None
  if len(toks) > 1:
    prime = toks[1]
  return Var(name, vars[name], prime)

### Guard-expression parsing (pyparsing grammar)
def _fold_binop(ts):
  """Left-fold an infixNotation operator group into nested ``BinOp`` nodes.

  With ``opAssoc.LEFT``, pyparsing puts every operator of one precedence
  level into a single flat group ``[a, op, b, op, c, ...]``. For example,
  ``x + y - z`` arrives as ``[x, '+', y, '-', z]`` and must become
  ``BinOp(BinOp(x, '+', y), '-', z)``. Reading only the first three tokens
  would silently drop the rest of the expression.
  """
  group = ts[0]
  expr = group[0]
  for i in range(1, len(group), 2):
    expr = BinOp(expr, group[i], group[i + 1])
  return expr


def parse_prop_atom(vtypes):
  """Build the pyparsing grammar for a single guard atom.

  An atom is either a parenthesised comparison ``(term cmpop term)`` or a
  boolean literal. A term is an arithmetic expression whose operands are
  numbers, declared variables (optionally followed by ``'``), boolean
  literals, double-quoted strings, or parenthesised terms. Operands are
  combined with ``*`` (binding tighter) and ``+``/``-``, all left-associative.

  Args:
    vtypes: mapping from variable name to type; only these names are
      recognised as variables.

  Returns:
    A pyparsing element whose parse result is a ``Cmp`` or ``Bool`` node.
  """
  vnames = list(vtypes.keys())
  LPAR = pyp.Literal('(').suppress()
  RPAR = pyp.Literal(')').suppress()
  # optional whitespace, discarded from the token list
  sps = pyp.ZeroOrMore(pyp.White()).suppress()
  nums = pyp.Word(pyp.srange("[0-9]"))
  # unsigned integer or decimal literal, e.g. 42 or 3.5
  num = (nums + pyp.Optional(pyp.Literal('.') + nums))\
    .setParseAction(lambda toks: Num(''.join(toks)))
  # only declared names match; an optional trailing ' is passed on to Var
  var = (pyp.oneOf(vnames) + pyp.Optional(pyp.Literal("'"))).\
    setParseAction(lambda toks: mkVar(vtypes, toks))
  chars = (pyp.QuotedString('"')).setParseAction(lambda toks: Charstr(toks[0]))
  boolean = (pyp.oneOf("True False true false")).setParseAction(lambda toks: Bool(toks[0]))
  term = pyp.Forward()
  pterm = (LPAR + sps + term + sps + RPAR).setParseAction(lambda toks: toks[0])
  # boolean is an operand so that comparisons like vip == True are accepted
  term << pyp.infixNotation(num | var | pterm | boolean | chars, [
    (pyp.Literal('*'), 2, pyp.opAssoc.LEFT, _fold_binop),
    (pyp.oneOf("+ -"), 2, pyp.opAssoc.LEFT, _fold_binop),
  ])

  cmpop = pyp.oneOf("== < > <= >= !=")
  # after suppression the tokens are [lhs, op, rhs]; Cmp receives (op, lhs, rhs)
  atom = (sps + term + sps + cmpop + sps + term + sps).\
    setParseAction(lambda toks: Cmp(toks[1], toks[0], toks[2]))
  patom = (LPAR + sps + atom + sps + RPAR).setParseAction(lambda toks: toks[0])
  return patom | boolean

def mk_bin(ts):
  """Left-fold an ``&&``/``||`` operator group into nested ``BinCon`` nodes.

  ``ts[0]`` is the flat infixNotation group ``[f0, op1, f1, op2, f2, ...]``.
  The result is ``BinCon(BinCon(f0, op1, f1), op2, f2)`` and so on. A
  one-element group returns its single operand unchanged.
  """
  group = ts[0]
  result = group[0]
  for idx in range(1, len(group), 2):
    result = BinCon(result, group[idx], group[idx + 1])
  return result

def parse_expr(s, vars):
  """Parse a guard string into an expression tree.

  The guard is a combination of atoms joined by ``&&`` / ``||``. Each atom
  is a parenthesised comparison such as ``(x > 1)`` or a boolean literal.
  ``&&`` and ``||`` share one precedence level and associate to the left.

  Args:
    s: the guard text.
    vars: iterable of variable dicts with at least ``"name"`` and ``"type"``
      keys; only these names are recognised as variables.

  Returns:
    The root expression node, or None if the parse produced no tokens.

  Raises:
    pyparsing.ParseException: if ``s`` is not, in its entirety, a
      well-formed guard.
  """
  vtypes = {v["name"]: v["type"] for v in vars}
  patom = parse_prop_atom(vtypes)
  formula = pyp.infixNotation(patom, [
    (pyp.oneOf("&& ||"), 2, pyp.opAssoc.LEFT, mk_bin),
  ])
  # parseAll=True: without it, trailing text the grammar cannot consume
  # (e.g. a single '&' in "(a == 1) & (b == 2)") is silently dropped and
  # only a prefix of the guard is returned.
  res = formula.parseString(s, parseAll=True)
  return res[0] if len(res) > 0 else None


def read_dpn_json(jsonfile):
  """Load a DPN from a JSON file and parse its transition conditions.

  Every transition in ``["model"]["transitions"]`` that has a
  ``"condition"`` string gets a ``"constraint"`` entry holding the result of
  ``parse_expr``, which uses the top-level ``"variables"`` list.

  Args:
    jsonfile: path of a UTF-8 encoded JSON file.

  Returns:
    The decoded JSON object, with the constraints added in place.
  """
  # JSON text is UTF-8 (RFC 8259); do not rely on the locale encoding.
  # The with-block also ensures the file handle is closed.
  with open(jsonfile, "r", encoding="utf-8") as f:
    dpn = json.load(f)
  for t in dpn["model"]["transitions"]:
    if "condition" in t:
      t["constraint"] = parse_expr(t["condition"], dpn["variables"])
  return dpn

def base_var(name):
  """Return *name* with one trailing prime (``'``) removed, if present.

  Names without a trailing prime, including the empty string, are returned
  unchanged.
  """
  # endswith() instead of name[-1], which raises IndexError on "".
  return name[:-1] if name.endswith("'") else name

def _pnml_text(node):
  """Return the character data of the first ``<text>`` descendant of *node*."""
  return node.getElementsByTagName('text')[0].firstChild.nodeValue


def _primed_vars(names, guard):
  """Return the names from *names* that occur primed (``name'``) in *guard*.

  A match must not be preceded by a word character. This stops a primed
  ``total_x'`` from also reporting ``x``, which a plain substring test would.
  """
  return [n for n in names
          if re.search(r"(?<!\w)" + re.escape(n) + "'", guard)]


def read_dpn_pnml(pnmlfile):
  """Read a data Petri net from a PNML document into a plain dict.

  Args:
    pnmlfile: a path or file object accepted by ``xml.dom.minidom.parse``.

  Returns:
    A dict with the keys:
      ``"name"``: text of the first ``<name>`` inside the first ``<net>``.
      ``"arcs"``: ``{"source", "target", "id"}`` for every ``<arc>``.
      ``"variables"``: ``{"name", "initialValue", "type"}`` for every
        ``<variable>``. ``initialValue`` is the raw ``initial`` attribute
        string; if that is missing or empty it is False for bool variables
        and 0 otherwise.
      ``"transitions"``: ``{"id", "label", "written", "invisible"}`` for
        every ``<transition>``, plus a parsed ``"constraint"`` when the
        ``guard`` attribute is non-empty. ``written`` combines the
        ``<writeVariable>`` entries with every variable occurring primed in
        the guard.
      ``"places"``: ``{"id"}`` plus optional ``"name"``, ``"final"`` and
        ``"initial"`` (int token counts) for every ``<place>`` of the first
        ``<page>``.
  """
  dom = minidom.parse(pnmlfile)
  dpn = {
    "variables": [],
    "places": [],
    "transitions": [],
    "arcs": [],
    "name": ""
  }

  # arcs
  for arc in dom.getElementsByTagName('arc'):
    dpn["arcs"].append({
      "source": arc.getAttribute('source'),
      "target": arc.getAttribute('target'),
      "id": arc.getAttribute('id'),
    })

  net = dom.getElementsByTagName('net')[0]
  dpn["name"] = _pnml_text(net.getElementsByTagName('name')[0])

  # variables: every declared variable is kept, whether or not a guard uses it
  for var in dom.getElementsByTagName('variable'):
    name = var.getElementsByTagName('name')[0].firstChild.nodeValue
    vtype = VarType.from_java(var.getAttribute('type'))
    vinit = var.getAttribute('initial')
    if not vinit:
      vinit = False if vtype == VarType.bool else 0
    dpn["variables"].append({"name": name, "initialValue": vinit, "type": vtype})

  # transitions
  vs = [v["name"] for v in dpn["variables"]]
  for tr in dom.getElementsByTagName('transition'):
    guard = unescape(tr.getAttribute('guard'))
    written = [w.firstChild.nodeValue
               for w in tr.getElementsByTagName('writeVariable')]
    if guard:
      # Variables occurring primed in the guard are written too.
      # dict.fromkeys removes duplicates in a deterministic order, unlike set().
      written = list(dict.fromkeys(_primed_vars(vs, guard) + written))
    t = {
      "id": tr.getAttribute('id'),
      "label": _pnml_text(tr.getElementsByTagName('name')[0]),
      "written": written,
      "invisible": tr.getAttribute('invisible') == 'true',
    }
    if guard:
      t["constraint"] = parse_expr(guard, dpn["variables"])
    dpn["transitions"].append(t)

  # places (only those on the first <page>)
  for pl in dom.getElementsByTagName('page')[0].getElementsByTagName('place'):
    p = {"id": pl.getAttribute('id')}
    names = pl.getElementsByTagName('name')
    if names:
      p["name"] = _pnml_text(names[0])
    final = pl.getElementsByTagName('finalMarking')
    if final:
      p["final"] = int(_pnml_text(final[0]))
    init = pl.getElementsByTagName('initialMarking')
    if init:
      p["initial"] = int(_pnml_text(init[0]))
    dpn["places"].append(p)

  return dpn

def read_properties_pnml(pnmlfile):
  """Collect the soundness flag and test names from ``<property>`` elements.

  Args:
    pnmlfile: a path or file object accepted by ``xml.dom.minidom.parse``.

  Returns:
    A dict with:
      ``"tests"``: the non-empty ``test`` attributes, in document order.
      ``"sound"``: True iff the last ``<property>`` that has a ``sound``
        attribute sets it to ``"1"`` or ``"true"``. It is False when
        ``<property>`` elements exist but none has a ``sound`` attribute,
        and the key is absent when there are no ``<property>`` elements.
  """
  dom = minidom.parse(pnmlfile)
  p = {"tests": []}
  props = dom.getElementsByTagName('property')
  for a in props:
    # getAttribute() returns "" (never None) for a missing attribute, so use
    # hasAttribute(). Otherwise a later property without `sound` would
    # overwrite an earlier sound="true" with False.
    if a.hasAttribute('sound'):
      s = a.getAttribute('sound')
      p["sound"] = s in ("1", "true")
    test = a.getAttribute('test')
    if test:
      p["tests"].append(test)
  if props and "sound" not in p:
    p["sound"] = False
  return p
    

def read_dds(jsonfile):
  """Load a DDS description from a JSON file and normalise it in place.

  - states: a missing ``"initial"`` or ``"final"`` flag defaults to False.
  - variables: ``"type"`` strings are converted with ``VarType.from_str``.
  - transitions: a ``"guard"`` string is parsed with ``parse_expr``, and a
    transition without a guard gets ``top``. When a guard is present,
    ``"written"`` is extended with every variable occurring primed
    (``x'``) in the guard text. Each transition gets a sequential integer
    ``"id"``, starting at 0, in file order.

  Args:
    jsonfile: path of a UTF-8 encoded JSON file.

  Returns:
    The decoded and normalised JSON object.
  """
  # with-block closes the handle; JSON text is UTF-8 (RFC 8259).
  with open(jsonfile, "r", encoding="utf-8") as f:
    dds = json.load(f)

  for state in dds["states"]:
    state.setdefault("initial", False)
    state.setdefault("final", False)

  for v in dds["variables"]:
    v["type"] = VarType.from_str(v["type"])

  names = [v["name"] for v in dds["variables"]]
  for index, t in enumerate(dds["transitions"]):
    written = t.get("written", [])
    if "guard" in t:
      guard = t["guard"]
      # Require a non-word character (or start of text) before the name so
      # that a primed `total_x'` does not also mark `x` as written.
      primed = [n for n in names
                if re.search(r"(?<!\w)" + re.escape(n) + "'", guard)]
      # dict.fromkeys removes duplicates in a deterministic order, unlike set().
      t["written"] = list(dict.fromkeys(primed + written))
      t["guard"] = parse_expr(guard, dds["variables"])
    else:
      t["written"] = written
      t["guard"] = top
    t["id"] = index
  return dds
