Correct NFA construction

There was a bug in the way that I was converting regular expressions
to NFAs. I'm still not entirely sure what was going on, but I
revisited the construction, made it follow the literature more
closely, and that fixed the problem.
This commit is contained in:
John Doty 2024-08-24 09:24:29 -07:00
parent 30f7798719
commit 0c952e4905
2 changed files with 91 additions and 41 deletions

View file

@ -2169,8 +2169,7 @@ class NFAState:
@dataclasses.dataclass @dataclasses.dataclass
class Re: class Re:
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
del start
raise NotImplementedError() raise NotImplementedError()
def __str__(self) -> str: def __str__(self) -> str:
@ -2258,11 +2257,12 @@ class ReSet(Re):
def __invert__(self) -> "ReSet": def __invert__(self) -> "ReSet":
return self.invert() return self.invert()
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
start = NFAState()
end = NFAState() end = NFAState()
for span in self.values: for span in self.values:
start.add_edge(span, end) start.add_edge(span, end)
return end return (start, [end])
def __str__(self) -> str: def __str__(self) -> str:
if len(self.values) == 1: if len(self.values) == 1:
@ -2285,10 +2285,14 @@ class ReSet(Re):
class RePlus(Re): class RePlus(Re):
child: Re child: Re
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
end = self.child.to_nfa(start) start, ends = self.child.to_nfa()
end = NFAState()
for e in ends:
e.epsilons.append(end)
end.epsilons.append(start) end.epsilons.append(start)
return end return (start, [end])
def __str__(self) -> str: def __str__(self) -> str:
return f"({self.child})+" return f"({self.child})+"
@ -2298,11 +2302,16 @@ class RePlus(Re):
class ReStar(Re): class ReStar(Re):
child: Re child: Re
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
end = self.child.to_nfa(start) start = NFAState()
child_start, ends = self.child.to_nfa()
start.epsilons.append(child_start)
for end in ends:
end.epsilons.append(start) end.epsilons.append(start)
start.epsilons.append(end)
return end # TODO: Do I need to make an explicit end state here?
return (start, [start])
def __str__(self) -> str: def __str__(self) -> str:
return f"({self.child})*" return f"({self.child})*"
@ -2312,10 +2321,14 @@ class ReStar(Re):
class ReQuestion(Re): class ReQuestion(Re):
child: Re child: Re
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
end = self.child.to_nfa(start) start = NFAState()
start.epsilons.append(end)
return end child_start, ends = self.child.to_nfa()
start.epsilons.append(child_start)
ends.append(start)
return (start, ends)
def __str__(self) -> str: def __str__(self) -> str:
return f"({self.child})?" return f"({self.child})?"
@ -2326,9 +2339,12 @@ class ReSeq(Re):
left: Re left: Re
right: Re right: Re
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
mid = self.left.to_nfa(start) left_start, left_ends = self.left.to_nfa()
return self.right.to_nfa(mid) right_start, right_ends = self.right.to_nfa()
for end in left_ends:
end.epsilons.append(right_start)
return (left_start, right_ends)
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.left}{self.right}" return f"{self.left}{self.right}"
@ -2339,20 +2355,15 @@ class ReAlt(Re):
left: Re left: Re
right: Re right: Re
def to_nfa(self, start: NFAState) -> NFAState: def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
left_start = NFAState() left_start, left_ends = self.left.to_nfa()
right_start, right_ends = self.right.to_nfa()
start = NFAState()
start.epsilons.append(left_start) start.epsilons.append(left_start)
left_end = self.left.to_nfa(left_start)
right_start = NFAState()
start.epsilons.append(right_start) start.epsilons.append(right_start)
right_end = self.right.to_nfa(right_start)
end = NFAState() return (start, left_ends + right_ends)
left_end.epsilons.append(end)
right_end.epsilons.append(end)
return end
def __str__(self) -> str: def __str__(self) -> str:
return f"(({self.left})||({self.right}))" return f"(({self.left})||({self.right}))"
@ -2438,22 +2449,23 @@ class NFASuperState:
return accept return accept
def compile_lexer(x: Grammar) -> LexerTable: def compile_terminals(terminals: typing.Iterable[Terminal]) -> LexerTable:
# Parse the terminals all together into a big NFA rooted at `NFA`. # Parse the terminals all together into a big NFA rooted at `NFA`.
NFA = NFAState() NFA = NFAState()
for terminal in x.terminals: for terminal in terminals:
start = NFAState()
NFA.epsilons.append(start)
pattern = terminal.pattern pattern = terminal.pattern
if isinstance(pattern, Re): if isinstance(pattern, Re):
ending = pattern.to_nfa(start) start, ends = pattern.to_nfa()
else: for end in ends:
ending = start end.accept.append(terminal)
for c in pattern: NFA.epsilons.append(start)
ending = ending.add_edge(Span.from_str(c), NFAState())
ending.accept.append(terminal) else:
start = end = NFAState()
for c in pattern:
end = end.add_edge(Span.from_str(c), NFAState())
end.accept.append(terminal)
NFA.epsilons.append(start)
NFA.dump_graph() NFA.dump_graph()
@ -2482,6 +2494,10 @@ def compile_lexer(x: Grammar) -> LexerTable:
] ]
def compile_lexer(grammar: Grammar) -> LexerTable:
return compile_terminals(grammar.terminals)
def dump_lexer_table(table: LexerTable): def dump_lexer_table(table: LexerTable):
with open("lexer.dot", "w", encoding="utf-8") as f: with open("lexer.dot", "w", encoding="utf-8") as f:
f.write("digraph G {\n") f.write("digraph G {\n")

View file

@ -381,3 +381,37 @@ def test_lexer_compile():
(LexTest.BLANKS, 5, 1), (LexTest.BLANKS, 5, 1),
(LexTest.IDENTIFIER, 6, 3), (LexTest.IDENTIFIER, 6, 3),
] ]
def test_lexer_numbers():
class LexTest(Grammar):
@rule
def number(self):
return self.NUMBER
start = number
NUMBER = Terminal(
Re.seq(
Re.set(("0", "9")).plus(),
Re.seq(
Re.literal("."),
Re.set(("0", "9")).plus(),
Re.seq(
Re.set("e", "E"),
Re.set("+", "-").question(),
Re.set(("0", "9")).plus(),
).question(),
).question(),
)
)
lexer = compile_lexer(LexTest())
dump_lexer_table(lexer)
number_string = "1234.12"
tokens = list(generic_tokenize(number_string, lexer))
assert tokens == [
(LexTest.NUMBER, 0, len(number_string)),
]