
from dds.expr import Cmp, PropVar, Bool
from verification.abstraction_equivalence import *

class ConfigMap:
  """Associates every state key of a DDS with a solver formula.

  The map is stored as a plain dict whose keys are the keys of
  ``dds._states`` and whose values are whatever the solver object returns
  (``solver.true()``, ``solver.false()``, ``param.toSMT(...)``, ...).
  Logical combinators (``land``, ``lor``, ``negate``) are applied key by key
  and return new ``ConfigMap`` objects; ``simplify`` rewrites the stored
  formulas in place.
  """

  # Class-level defaults; __init__ rebinds each of them on the instance.
  _map = {}
  _dds = None
  _solver = None
  _smt_vars = None

  def __init__(self, dds, solver, smt_vs, param):
    """Build the map from an explicit dict or from a base-case formula.

    Args:
      dds: object exposing ``_states``, a dict from state key to a dict
        that has a ``"name"`` entry.
      solver: object providing ``true()``, ``false()``, ``land``, ``lor``,
        ``neg`` and ``simplify_more``.
      smt_vs: passed through unchanged to ``param.toSMT``.
      param: either a ready-made dict (used as the map without copying),
        or a ``PropVar`` / ``Cmp`` / ``Bool`` formula.

    Raises:
      AssertionError: if ``param`` is neither a dict nor one of the
        supported formula types.
    """
    self._dds = dds
    self._solver = solver
    self._smt_vars = smt_vs

    if isinstance(param, dict):
      # NOTE: the dict is stored by reference, not copied.
      self._map = param
      return

    # A propositional variable that names a state: true in exactly the
    # states carrying that name, false everywhere else.
    if isinstance(param, PropVar) and any(
        s["name"] == param.name for s in dds._states.values()):
      bot = solver.false()
      top = solver.true()
      self._map = {
        state_id: (top if s["name"] == param.name else bot)
        for (state_id, s) in dds._states.items()
      }
      return

    assert isinstance(param, (Cmp, Bool, PropVar))
    # Constant map: the same translated formula is assigned to every state.
    smt_formula = param.toSMT(solver, self._smt_vars)
    self._map = {state_id: smt_formula for state_id in dds._states}

  def new_from_map(self, m):
    """Return a ConfigMap over dict ``m`` sharing this object's dds, solver and variables."""
    return ConfigMap(self._dds, self._solver, self._smt_vars, m)

  def get(self, k):
    """Return the formula stored for key ``k`` (raises KeyError if absent)."""
    return self._map[k]

  def land(self, other):
    """Return a new map holding ``solver.land([self[k], other[k]])`` for each key of this map."""
    solver = self._solver
    return self.new_from_map(
      {k: solver.land([self.get(k), other.get(k)]) for k in self._map})

  def lor(self, other):
    """Return a new map holding ``solver.lor([self[k], other[k]])`` for each key of this map."""
    solver = self._solver
    return self.new_from_map(
      {k: solver.lor([self.get(k), other.get(k)]) for k in self._map})

  def negate(self):
    """Return a new map holding ``solver.neg(self[k])`` for each key of this map."""
    solver = self._solver
    return self.new_from_map({k: solver.neg(self.get(k)) for k in self._map})

  def simplify(self):
    """Replace every non-bool formula in place by its simplified form.

    Each such value is passed through ``solver.simplify_more`` and then
    through ``Equivalence(solver).simp_dnf``. Plain Python ``bool`` values
    are left untouched. Returns None.
    """
    eq = Equivalence(self._solver)
    # Iterate over a snapshot so reassigning values never interferes with
    # the iteration.
    for (k, v) in list(self._map.items()):
      if isinstance(v, bool):
        continue
      self._map[k] = eq.simp_dnf(self._solver.simplify_more(v))

  def __str__(self):
    """Render as ``{ k1: v1, k2: v2 }`` with keys in sorted order (``{ }`` when empty)."""
    entries = [" " + str(k) + ": " + str(self._map[k])
               for k in sorted(self._map.keys())]
    return "{" + ",".join(entries) + " }"

  def __hash__(self):
    """Hash the (key, hash(formula)) pairs independently of insertion order.

    A frozenset is used so that two maps with the same content hash alike
    even when their underlying dicts were filled in a different order, and
    so that keys need not be mutually comparable.
    """
    return hash(frozenset((k, hash(v)) for (k, v) in self._map.items()))