import json
import re
import warnings
import sys
import pickle
from os.path import isfile

from ltl.automaton import Automaton
from dds.util import VarType
from dds.expr import mk_conj, mk_disj, mk_var, mk_const, Var,rename_and_quantify
from verification.constraint_graph import ConstraintGraph

class Trace:
  """A finite sequence of variable assignments, one dict (name -> value) per step.

  `trace` is either a string in the format accepted by `from_str`, or a dict
  with keys "trace" (the list of assignment dicts) and "variables" (a list of
  {"name": ..., "type": ...} entries whose types are parsed by
  VarType.from_str).
  """

  def __init__(self, trace):
    if isinstance(trace, str):
      (self._trace, self._variables) = self.from_str(trace)
    else:
      self._trace = trace["trace"]
      self._variables = {v["name"]: VarType.from_str(v["type"]) \
        for v in trace["variables"]}

  def from_str(self, s):
    """Parse a trace string such as "[(x=2,y=3),(x=0, y=3)]".

    Every parenthesized group is one step; inside it, comma-separated
    `name=value` entries are read. Values containing "." are parsed as
    floats, all other values as ints.

    Returns:
      (trace, variable_types): the list of assignment dicts, and a dict
      mapping each variable to VarType.real if any of its values contains
      a ".", otherwise to VarType.int.

    Raises:
      ValueError: if an entry has no "=" or a value is not a number.
    """
    trace = []
    variable_types = {}
    # finditer consumes the groups strictly left to right, so a stray ")"
    # outside a group cannot cause a group to be read twice.
    for m in re.finditer(r"\(([^)]*)\)", s):
      assignment = {}
      for entry in m.group(1).split(","):
        if not entry.strip():
          continue  # empty step "()" or trailing comma
        (var, sep, val) = entry.partition("=")
        if not sep:
          raise ValueError("malformed assignment in trace: " + repr(entry))
        var = var.strip()
        val = val.strip()
        if "." in val:
          # one real value makes the variable real, even if earlier values
          # of the same variable looked like ints
          variable_types[var] = VarType.real
          assignment[var] = float(val)
        else:
          variable_types.setdefault(var, VarType.int)
          assignment[var] = int(val)
      trace.append(assignment)
    return (trace, variable_types)

  def get_variables(self):
    """Return the dict mapping variable names to their VarType."""
    return self._variables

  def restrict_to(self, vset):
    """Drop, in place, every assignment entry whose variable is not in `vset`.

    `vset` is an iterable of variable declarations (dicts with a "name" key).
    """
    keep = {v["name"] for v in vset}  # set: O(1) membership per entry
    self._trace = [{k: val for (k, val) in assign.items() if k in keep} \
      for assign in self._trace]

  def __str__(self):
    return str(self._trace)

  def __iter__(self):
    return iter(self._trace)

  def __len__(self):
    return len(self._trace)

  def __getitem__(self, i):
    # Delegates to list indexing: ints (including negative ones) and slices
    # work; out-of-range ints raise IndexError.
    return self._trace[i]
    


class Monitor:
  """Runtime monitor for a property, built from a DFA of its lookback form.

  The DFA is converted to a DDS. For each DFA state reached by a trace, a
  "future formula" is derived from a constraint graph started in that state.
  The formula is cached per state and evaluated on the last (and
  second-to-last) assignment of the trace; together with the acceptance of
  the reached state this yields the verdict returned by `rv_state`.

  If `suffix` is non-empty, the DFA and the cached future formulas are stored
  in out/dfa_<suffix>.out and out/futures_<suffix>.out and reloaded on the
  next construction with the same suffix.
  """

  def __init__(self, property, solver, suffix, verbose = False):
    self._suffix = suffix
    self._solver = solver
    precomputed = self.load_monitor() if suffix else None
    if precomputed:
      # FIXME reloaded formulas seem to work with Z3. maybe not other solvers.
      (self._dfa, self._future_formulas) = precomputed
    else:
      property.shift_to_lookback()
      if verbose:
        print("property shifted to lookback: " + str(property))
      nfa = Automaton(property, solver=solver, deterministic = True)
      self._dfa = nfa.determinize()
      # DFA state id -> future formula, filled lazily by monitoring_state
      self._future_formulas = {}
      if verbose:
        self.visualize_base_graphs(nfa, self._dfa, suffix)
    self._dds = self._dfa.to_dds()
    self._dds.set_lookback_mode()

    (has_finsumm, equiv, m) = self._dds.has_finite_summary([], solver)
    self._equiv = equiv
    if verbose:
      # str(m) so that a non-string message cannot break the concatenation
      print("finite summary detected: " + str(has_finsumm) + " (" + str(m) + ")")
    self._cgs = {}  # DFA state id -> ConstraintGraph built for it
    self.mk_vars()

  def is_prev_var(self, name):
    """Return True if `name` denotes a previous-step variable (suffix "-")."""
    return name.endswith('-')

  def mk_prev_var(self, name):
    """Return the name of the previous-step copy of variable `name`."""
    return name + '-'

  def mk_vars(self):
    """Create the solver variables for every DDS variable.

    Sets self._smt_vars (constants keyed by name), self._smt_vars_copy
    (variables with suffix "0", keyed by name + "0") and self._smt_vars_prev
    (previous-step variables, keyed by name + "-").
    """
    solver = self._solver
    # materialize once: the same declarations are iterated three times
    decls = list(self._dds.variables())
    self._smt_vars = {v["name"]: mk_const(solver, Var.from_array(v)) \
      for v in decls}
    self._smt_vars_copy = {v["name"] + "0": \
      mk_var(solver, Var.from_array(v), suffix="0") for v in decls}

    self._smt_vars_prev = {}
    for v in decls:
      prev_name = self.mk_prev_var(v["name"])
      vv = {"name": prev_name, "type": v["type"]}
      self._smt_vars_prev[prev_name] = mk_var(solver, Var.from_array(vv))

  def mk_value(self, v, val):
    """Return the solver constant for value `val` of variable `v`.

    The constant's sort follows the DDS declaration of `v`. Variables that do
    not occur in the DDS get solver.num(0). For an unsupported type a message
    is printed and None is returned.
    """
    declared = {d["name"]: d for d in self._dds.variables()}
    if v not in declared:
      return self._solver.num(0) # just to avoid bailout if variable not used
    vtype = declared[v]["type"]
    if vtype == VarType.int:
      return self._solver.num(val)
    if vtype in [VarType.rat, VarType.real]:
      return self._solver.real(val)
    print("unsupported type for trace checking: " + v)
    return None

  def evaluate(self, formula, ass_prev, ass_curr):
    """Return True iff the SMT `formula` is satisfiable after substitution.

    Values of `ass_curr` replace the current-step variables; values of
    `ass_prev` (may be None) replace the previous-step variables.
    """
    substdict = {}
    if ass_prev:
      for (x, val) in ass_prev.items():
        substdict[self.mk_prev_var(x)] = self.mk_value(x, val)
    # current values are inserted last, as they were in the original merge
    for (x, val) in ass_curr.items():
      substdict[x] = self.mk_value(x, val)
    if not substdict:
      # nothing to substitute (e.g. empty trace)
      return self._solver.check_sat(formula) is not None
    keys = list(substdict)
    smt_vars = [self._smt_vars_prev[k] if self.is_prev_var(k) \
      else self._smt_vars[k] for k in keys]
    vals = [substdict[k] for k in keys]
    formula2 = self._solver.subst(smt_vars, vals, formula)
    return self._solver.check_sat(formula2) is not None

  def dfa_state_for_trace(self, trace):
    """Run the DFA on `trace` and return the id of the state reached.

    In each step the guards of the outgoing edges are checked with the
    current assignment and the previous one (as "-" variables). The first
    edge whose guard is satisfiable is taken. If none is, the state is kept.
    """
    solver = self._solver
    # group the edges by source state once instead of rescanning per step
    outgoing = {}
    for ((src, tgt), labels) in self._dfa._edges.items():
      outgoing.setdefault(src, []).append((tgt, labels))

    state = self._dfa.initial_state_id()
    prev_values = None
    for assign in trace:
      curr_values = {x: self.mk_value(x, val) for (x, val) in assign.items()}
      substdict = dict(curr_values)
      if prev_values:
        substdict.update((self.mk_prev_var(x), smtval) \
          for (x, smtval) in prev_values.items())
      prev_values = curr_values
      for (tgt, labels) in outgoing.get(state, []):
        guard = mk_disj([mk_conj(l) for l in labels])
        # check_sat returns None when unsatisfiable (same test as in evaluate)
        if solver.check_sat(guard.toSMT(solver, substdict)) is not None:
          state = tgt
          break
    return state

  def rv_state(self, current, future):
    """Map (current verdict, future flag) to "CS", "PS", "CV" or "PV".

    Presumably currently/permanently satisfied/violated.
    """
    return {
      (True, True): "CS",
      (True, False): "PS",
      (False, True): "CV",
      (False, False): "PV",
    }[(bool(current), bool(future))]

  def _build_future_formula(self, state, current_sat, verbose):
    # Build the future formula for DFA state `state` from a constraint graph.
    # It covers the nodes whose DDS state has the opposite acceptance status
    # from `current_sat`.
    cg = ConstraintGraph(self._dds, self._equiv, \
      start_state = state, vars = (self._smt_vars, self._smt_vars_copy))
    self._cgs[state] = cg
    if verbose:
      self.visualize_cg(cg, state, self._suffix)

    finals = self._dds.final_state_ids()
    target_ids = finals if not current_sat else \
      [s for s in self._dds._states if s not in finals]
    exprs = [n._expr for n in cg._nodes.values() if n._state in target_ids]

    vs = [Var.from_array(v) for v in self._dds.variables()]
    return rename_and_quantify(self._solver, exprs, vs, \
      list(self._smt_vars.values()), list(self._smt_vars_copy.values()))

  def monitoring_state(self, trace, verbose = True):
    """Return the monitoring verdict (see rv_state) for `trace`.

    Side effect: `trace` is restricted in place to the property's variables.
    """
    trace.restrict_to(self._dds.variables()) # restrict to vars in prop
    state = self.dfa_state_for_trace(trace)
    current_sat = state in self._dfa.final_state_ids()
    if verbose:
      print("trace satisfies property:", current_sat)

    # The formula depends only on the DFA state, so it is cached per state.
    # Its truth value depends on the trace's last values, so it is evaluated
    # on every call. Plain booleans in the cache come from caches written
    # when evaluation results were stored instead of formulas; rebuild them.
    formula = self._future_formulas.get(state)
    if formula is None or isinstance(formula, bool):
      formula = self._build_future_formula(state, current_sat, verbose)
      self._future_formulas[state] = formula
      if self._suffix:
        self.store_monitor()

    ass_curr = trace[-1] if len(trace) > 0 else {}
    ass_prev = trace[-2] if len(trace) > 1 else None
    future_sat = self.evaluate(formula, ass_prev, ass_curr)

    rvstate = self.rv_state(current_sat, future_sat)
    if verbose:
      print("monitoring state: ", rvstate)
    return rvstate

  def store_monitor(self):
    """Pickle the DFA and the future-formula cache into out/ (using the suffix)."""
    import os
    os.makedirs("out", exist_ok=True)  # open() below fails if out/ is missing
    with open("out/dfa_" + self._suffix + ".out", "wb") as f:
      pickle.dump(self._dfa, f)
    with open("out/futures_" + self._suffix + ".out", "wb") as f:
      pickle.dump(self._future_formulas, f)

  def _load_pickle(self, filename):
    # Unpickle `filename`, or return None when the file does not exist.
    if not isfile(filename):
      return None
    with open(filename, "rb") as f:
      return pickle.load(f)

  def load_monitor(self):
    """Return (dfa, future_formulas) stored by store_monitor, or None.

    NOTE(review): pickle.load can execute arbitrary code from a crafted file;
    only load cache files produced locally by store_monitor.
    """
    dfa = self._load_pickle("out/dfa_" + self._suffix + ".out")
    futures = self._load_pickle("out/futures_" + self._suffix + ".out")
    # `is not None`: an empty futures dict is a valid cache
    if dfa is not None and futures is not None:
      return (dfa, futures)
    return None

  def visualize_base_graphs(self, nfa, dfa, suffix):
    """Render the NFA and DFA as PNG files in out/."""
    # ignore warnings that graph output was scaled
    warnings.filterwarnings("ignore", message = ".*graph is too large.*Scaling.*")

    outdir = "out"
    filesuffix = ("_" + suffix if suffix else "") + ".png"
    nfa.show(outdir + "/nfa" + filesuffix)
    dfa.show(outdir + "/dfa" + filesuffix)

  def visualize_cg(self, cg, s, suffix):
    """Render the constraint graph built for DFA state `s` as a PNG in out/."""
    # ignore warnings that graph output was scaled
    warnings.filterwarnings("ignore", message = ".*graph is too large.*Scaling.*")

    outdir = "out"
    filesuffix = ("_" + suffix if suffix else "") + ".png"
    cg.show(outdir + "/cg" + "_" + str(s) + filesuffix)