Thursday, January 30, 2014

NFA to DFA

I've started reading a theory of computation textbook in my spare time. In the first few chapters, I re-learnt the algorithm for converting a non-deterministic finite automaton into a deterministic finite automaton that recognises the same language. I say "re-learnt" because I first learnt this algorithm 2 years ago at the National Computer Science Summer School.

2 years ago, implementing the algorithm was one of the most frustrating programming tasks (both to debug and to reason about) I'd ever done. If you're not familiar with the algorithm, it essentially involves a breadth-first search over a graph -- except it's more like a metagraph, because each node represents a subset of the states (i.e. nodes) of the NFA (which is itself a graph). In this metagraph, $A \xrightarrow{c} B$ iff $B$'s state subset is the union of all the states reachable by following a $c$-transition from any of the states in $A$'s state subset. There's some extra stuff to handle $\epsilon$-transitions too.

Debugging this algorithm was as hard as it sounds. Here's the code from that troublesome day.

# Edge label that marks an epsilon transition (a move that consumes no input).
# Because edge labels are compared against this value, '~' cannot be used as
# an ordinary input symbol.
EPS = '~'

class NFANode(object):
    """A single state of a non-deterministic finite automaton.

    Attributes:
        is_final: truthy when this state is accepting.
        edges: set of ``(token, target_node)`` pairs.  An edge whose token
            equals ``EPS`` is an epsilon transition.
    """

    def __init__(self, is_final):
        self.is_final = is_final
        self.edges = set()

    def add(self, toks, next):
        """Add one edge to ``next`` for every token in the iterable ``toks``.

        Passing a string adds one edge per character, so ``add('ab', n)``
        creates both an ``'a'`` edge and a ``'b'`` edge to ``n``.
        """
        # ``next`` shadows the builtin, but the parameter name is kept so
        # that callers passing it by keyword keep working.
        self.edges.update((tok, next) for tok in toks)

    def EC(self):
        """Return the epsilon closure of this node as a set of nodes.

        The closure contains this node plus every node reachable from it by
        following only ``EPS`` edges.
        """
        # An explicit worklist with a visited set is used instead of
        # recursion so that epsilon cycles (a -EPS-> b -EPS-> a) terminate.
        closure = {self}
        pending = [self]
        while pending:
            node = pending.pop()
            for tok, target in node.edges:
                if tok == EPS and target not in closure:
                    closure.add(target)
                    pending.append(target)
        return closure


class DFANode(object):
    """A DFA state, represented as a set of NFA nodes.

    Attributes:
        edges: list of ``(token, DFANode)`` pairs; starts empty and is filled
            in by whoever builds the automaton.
        states: set of NFA nodes this DFA state stands for.  Each NFA node is
            expected to expose ``is_final`` and ``edges`` (a collection of
            ``(token, target)`` pairs).
    """

    def __init__(self):
        self.edges = []
        self.states = set()

    def is_final(self):
        """Return True if any NFA node in ``states`` is accepting."""
        return any(k.is_final for k in self.states)

    def EC(self):
        """Return a new DFANode holding the epsilon closure of ``states``.

        The result contains every node in ``states`` plus every node
        reachable from them through ``EPS`` edges only.  ``self`` is left
        unmodified and the returned node has no edges.
        """
        # The closure is computed here with a worklist and a visited set
        # rather than by recursing per node, so epsilon cycles terminate.
        closure = set(self.states)
        pending = list(closure)
        while pending:
            nfa_node = pending.pop()
            for tok, target in nfa_node.edges:
                if tok == EPS and target not in closure:
                    closure.add(target)
                    pending.append(target)
        result = DFANode()
        result.states = closure
        return result

    def M(self, queryTok):
        """Return a new DFANode of the nodes reached by one ``queryTok`` edge.

        Only direct ``queryTok`` edges out of ``states`` are followed; the
        epsilon closure is not applied here, so call ``EC()`` on the result
        when it is needed.  The result may have an empty ``states`` set.
        """
        result = DFANode()
        result.states = {
            target
            for nfa_node in self.states
            for tok, target in nfa_node.edges
            if tok == queryTok
        }
        return result


class NFA(object):
    """A non-deterministic finite automaton identified by its start node.

    Nodes are expected to expose ``is_final`` and ``edges`` (a collection of
    ``(token, target_node)`` pairs, where a token equal to ``EPS`` marks an
    epsilon transition).
    """

    def __init__(self, start):
        self.start = start

    def get_toks(self):
        """Return the set of non-epsilon tokens used anywhere in the automaton.

        Every node reachable from ``start`` through any edge is visited, so
        tokens that only appear deep inside the graph are included.
        """
        toks = set()
        visited = {self.start}
        pending = [self.start]
        while pending:
            node = pending.pop()
            for tok, target in node.edges:
                toks.add(tok)
                if target not in visited:
                    visited.add(target)
                    pending.append(target)
        toks.discard(EPS)
        return toks

    def accepts(self, query, all_toks=None):
        """Return True if the automaton accepts the string ``query``.

        The search is breadth-first over ``(position, node)`` pairs.  If
        ``all_toks`` is a set, the token of every edge examined during the
        search (including ``EPS``) is added to it; the search stops at the
        first accepting configuration, so this need not be every token.

        >>> n1, n2, n3 = NFANode(True), NFANode(False), NFANode(False)
        >>> n1.add('b', n2)
        >>> n1.add(EPS, n3)
        >>> n2.add('a', n2)
        >>> n2.add('ab', n3) 
        >>> n3.add('a', n1)
        >>> nfa = NFA(n1)
        >>> assert nfa.accepts('aaa')
        >>> assert not nfa.accepts('bb')
        >>> assert nfa.accepts('abababababababababbababababa')
        >>> assert nfa.accepts('baa')
        >>> assert nfa.accepts('baaaaaa')
        >>> assert nfa.accepts('aa')
        >>> assert nfa.accepts('baba')
        >>> assert not nfa.accepts('baababaaaaaaaaaaaaaaaaaab')
        >>> assert nfa.accepts('baababaaaaaaaaaaaaaaaaaaba')
        """
        from collections import deque

        collect = isinstance(all_toks, set)
        end = len(query)
        queue = deque([(0, self.start)])
        seen = set()

        while queue:
            pos, node = queue.popleft()
            if (pos, node) in seen:
                continue
            seen.add((pos, node))

            at_end = pos == end
            if at_end and node.is_final:
                return True

            for tok, target in node.edges:
                if collect:
                    all_toks.add(tok)
                if tok == EPS:
                    # Epsilon moves consume no input, so they must still be
                    # followed after the whole query has been read.
                    queue.append((pos, target))
                elif not at_end and tok == query[pos]:
                    queue.append((pos + 1, target))
        return False


class DFA(NFA):
    """Deterministic automaton built from an NFA by the subset construction.

    Attributes:
        nfa: the NFA that was converted.
        all_toks: the tokens reported by ``nfa.get_toks()``; transitions are
            only built for these tokens.
        start: the DFANode for the epsilon closure of the NFA start node.
        cache: maps a frozenset of NFA nodes to the unique DFANode that
            represents that subset.
    """

    def __init__(self, nfaToConvert):
        self.nfa = nfaToConvert
        self.all_toks = self.nfa.get_toks()
        self.start = DFANode()
        self.start.states.add(self.nfa.start)
        # mapping of frozenset of NFA states to DFANode
        self.cache = {}
        self._convert()

    def _convert(self):
        """Populate ``cache`` and the ``edges`` of every reachable DFANode."""
        from collections import deque

        self.start = self.start.EC()
        # Registering the start state up front means a transition that leads
        # back to it reuses this node instead of re-processing it and
        # appending its edges a second time.
        self.cache[frozenset(self.start.states)] = self.start

        queue = deque([self.start])
        while queue:
            current = queue.popleft()
            for tok in self.all_toks:
                candidate = current.M(tok).EC()
                if not candidate.states:
                    # No NFA state is reachable on this token: no edge.
                    continue
                is_new = frozenset(candidate.states) not in self.cache
                target = self.retrieve(candidate)
                current.edges.append((tok, target))
                if is_new:
                    queue.append(target)

    def retrieve(self, node):
        """Return the cached DFANode with the same ``states`` as ``node``.

        If no such node is cached yet, ``node`` itself is cached and returned.
        """
        # frozenset gives an order-independent key; tuple(set) can differ
        # between two equal sets depending on their insertion history.
        return self.cache.setdefault(frozenset(node.states), node)

    def accepts(self, query):
        """Return True if the DFA accepts the string ``query``.

        Walks one edge per character from ``start``; the query is rejected
        as soon as the current state has no edge for the next character.

        >>> n1, n2, n3 = NFANode(True), NFANode(False), NFANode(False)
        >>> n1.add('b', n2)
        >>> n1.add(EPS, n3)
        >>> n2.add('a', n2)
        >>> n2.add('ab', n3) 
        >>> n3.add('a', n1)
        >>> nfa = NFA(n1)
        >>> dfa = DFA(nfa)
        >>> assert dfa.accepts('aaa')
        >>> assert not dfa.accepts('bb')
        >>> assert dfa.accepts('abababababababababbababababa')
        >>> assert dfa.accepts('baa')
        >>> assert dfa.accepts('baaaaaa')
        >>> assert dfa.accepts('aa')
        >>> assert dfa.accepts('baba')
        >>> assert not dfa.accepts('baababaaaaaaaaaaaaaaaaaab')
        >>> assert dfa.accepts('baababaaaaaaaaaaaaaaaaaaba')
        """
        node = self.start
        for symbol in query:
            for tok, target in node.edges:
                if tok == symbol:
                    node = target
                    break
            else:
                # No outgoing edge for this symbol: implicit dead state.
                return False
        return node.is_final()


# When executed as a script, run the doctest examples embedded in the
# docstrings of this module.
if __name__ == '__main__':
    import doctest
    doctest.testmod()