Correct NFA construction

There was a bug in the way that I was converting regular expressions
to NFAs. I'm still not entirely sure what was going on, but I
re-visited the construction and made it follow the literature more
closely and it fixed the problem.
This commit is contained in:
John Doty 2024-08-24 09:24:29 -07:00
parent 30f7798719
commit 0c952e4905
2 changed files with 91 additions and 41 deletions

View file

@ -2169,8 +2169,7 @@ class NFAState:
@dataclasses.dataclass
class Re:
def to_nfa(self, start: NFAState) -> NFAState:
del start
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
raise NotImplementedError()
def __str__(self) -> str:
@ -2258,11 +2257,12 @@ class ReSet(Re):
def __invert__(self) -> "ReSet":
return self.invert()
def to_nfa(self, start: NFAState) -> NFAState:
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
start = NFAState()
end = NFAState()
for span in self.values:
start.add_edge(span, end)
return end
return (start, [end])
def __str__(self) -> str:
if len(self.values) == 1:
@ -2285,10 +2285,14 @@ class ReSet(Re):
class RePlus(Re):
child: Re
def to_nfa(self, start: NFAState) -> NFAState:
end = self.child.to_nfa(start)
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
start, ends = self.child.to_nfa()
end = NFAState()
for e in ends:
e.epsilons.append(end)
end.epsilons.append(start)
return end
return (start, [end])
def __str__(self) -> str:
return f"({self.child})+"
@ -2298,11 +2302,16 @@ class RePlus(Re):
class ReStar(Re):
child: Re
def to_nfa(self, start: NFAState) -> NFAState:
end = self.child.to_nfa(start)
end.epsilons.append(start)
start.epsilons.append(end)
return end
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
start = NFAState()
child_start, ends = self.child.to_nfa()
start.epsilons.append(child_start)
for end in ends:
end.epsilons.append(start)
# TODO: Do I need to make an explicit end state here?
return (start, [start])
def __str__(self) -> str:
return f"({self.child})*"
@ -2312,10 +2321,14 @@ class ReStar(Re):
class ReQuestion(Re):
child: Re
def to_nfa(self, start: NFAState) -> NFAState:
end = self.child.to_nfa(start)
start.epsilons.append(end)
return end
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
start = NFAState()
child_start, ends = self.child.to_nfa()
start.epsilons.append(child_start)
ends.append(start)
return (start, ends)
def __str__(self) -> str:
return f"({self.child})?"
@ -2326,9 +2339,12 @@ class ReSeq(Re):
left: Re
right: Re
def to_nfa(self, start: NFAState) -> NFAState:
mid = self.left.to_nfa(start)
return self.right.to_nfa(mid)
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
left_start, left_ends = self.left.to_nfa()
right_start, right_ends = self.right.to_nfa()
for end in left_ends:
end.epsilons.append(right_start)
return (left_start, right_ends)
def __str__(self) -> str:
return f"{self.left}{self.right}"
@ -2339,20 +2355,15 @@ class ReAlt(Re):
left: Re
right: Re
def to_nfa(self, start: NFAState) -> NFAState:
left_start = NFAState()
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
left_start, left_ends = self.left.to_nfa()
right_start, right_ends = self.right.to_nfa()
start = NFAState()
start.epsilons.append(left_start)
left_end = self.left.to_nfa(left_start)
right_start = NFAState()
start.epsilons.append(right_start)
right_end = self.right.to_nfa(right_start)
end = NFAState()
left_end.epsilons.append(end)
right_end.epsilons.append(end)
return end
return (start, left_ends + right_ends)
def __str__(self) -> str:
return f"(({self.left})||({self.right}))"
@ -2438,22 +2449,23 @@ class NFASuperState:
return accept
def compile_lexer(x: Grammar) -> LexerTable:
def compile_terminals(terminals: typing.Iterable[Terminal]) -> LexerTable:
# Parse the terminals all together into a big NFA rooted at `NFA`.
NFA = NFAState()
for terminal in x.terminals:
start = NFAState()
NFA.epsilons.append(start)
for terminal in terminals:
pattern = terminal.pattern
if isinstance(pattern, Re):
ending = pattern.to_nfa(start)
else:
ending = start
for c in pattern:
ending = ending.add_edge(Span.from_str(c), NFAState())
start, ends = pattern.to_nfa()
for end in ends:
end.accept.append(terminal)
NFA.epsilons.append(start)
ending.accept.append(terminal)
else:
start = end = NFAState()
for c in pattern:
end = end.add_edge(Span.from_str(c), NFAState())
end.accept.append(terminal)
NFA.epsilons.append(start)
NFA.dump_graph()
@ -2482,6 +2494,10 @@ def compile_lexer(x: Grammar) -> LexerTable:
]
def compile_lexer(grammar: Grammar) -> LexerTable:
return compile_terminals(grammar.terminals)
def dump_lexer_table(table: LexerTable):
with open("lexer.dot", "w", encoding="utf-8") as f:
f.write("digraph G {\n")

View file

@ -381,3 +381,37 @@ def test_lexer_compile():
(LexTest.BLANKS, 5, 1),
(LexTest.IDENTIFIER, 6, 3),
]
def test_lexer_numbers():
class LexTest(Grammar):
@rule
def number(self):
return self.NUMBER
start = number
NUMBER = Terminal(
Re.seq(
Re.set(("0", "9")).plus(),
Re.seq(
Re.literal("."),
Re.set(("0", "9")).plus(),
Re.seq(
Re.set("e", "E"),
Re.set("+", "-").question(),
Re.set(("0", "9")).plus(),
).question(),
).question(),
)
)
lexer = compile_lexer(LexTest())
dump_lexer_table(lexer)
number_string = "1234.12"
tokens = list(generic_tokenize(number_string, lexer))
assert tokens == [
(LexTest.NUMBER, 0, len(number_string)),
]