Correct NFA construction
There was a bug in the way I was converting regular expressions to NFAs. I'm still not entirely sure what was going wrong, but I revisited the construction, made it follow the literature (Thompson's construction) more closely, and that fixed the problem.
This commit is contained in:
parent
30f7798719
commit
0c952e4905
2 changed files with 91 additions and 41 deletions
|
|
@ -2169,8 +2169,7 @@ class NFAState:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class Re:
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
del start
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
|
@ -2258,11 +2257,12 @@ class ReSet(Re):
|
|||
def __invert__(self) -> "ReSet":
|
||||
return self.invert()
|
||||
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
start = NFAState()
|
||||
end = NFAState()
|
||||
for span in self.values:
|
||||
start.add_edge(span, end)
|
||||
return end
|
||||
return (start, [end])
|
||||
|
||||
def __str__(self) -> str:
|
||||
if len(self.values) == 1:
|
||||
|
|
@ -2285,10 +2285,14 @@ class ReSet(Re):
|
|||
class RePlus(Re):
|
||||
child: Re
|
||||
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
end = self.child.to_nfa(start)
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
start, ends = self.child.to_nfa()
|
||||
|
||||
end = NFAState()
|
||||
for e in ends:
|
||||
e.epsilons.append(end)
|
||||
end.epsilons.append(start)
|
||||
return end
|
||||
return (start, [end])
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({self.child})+"
|
||||
|
|
@ -2298,11 +2302,16 @@ class RePlus(Re):
|
|||
class ReStar(Re):
|
||||
child: Re
|
||||
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
end = self.child.to_nfa(start)
|
||||
end.epsilons.append(start)
|
||||
start.epsilons.append(end)
|
||||
return end
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
start = NFAState()
|
||||
|
||||
child_start, ends = self.child.to_nfa()
|
||||
start.epsilons.append(child_start)
|
||||
for end in ends:
|
||||
end.epsilons.append(start)
|
||||
|
||||
# TODO: Do I need to make an explicit end state here?
|
||||
return (start, [start])
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({self.child})*"
|
||||
|
|
@ -2312,10 +2321,14 @@ class ReStar(Re):
|
|||
class ReQuestion(Re):
|
||||
child: Re
|
||||
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
end = self.child.to_nfa(start)
|
||||
start.epsilons.append(end)
|
||||
return end
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
start = NFAState()
|
||||
|
||||
child_start, ends = self.child.to_nfa()
|
||||
start.epsilons.append(child_start)
|
||||
ends.append(start)
|
||||
|
||||
return (start, ends)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({self.child})?"
|
||||
|
|
@ -2326,9 +2339,12 @@ class ReSeq(Re):
|
|||
left: Re
|
||||
right: Re
|
||||
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
mid = self.left.to_nfa(start)
|
||||
return self.right.to_nfa(mid)
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
left_start, left_ends = self.left.to_nfa()
|
||||
right_start, right_ends = self.right.to_nfa()
|
||||
for end in left_ends:
|
||||
end.epsilons.append(right_start)
|
||||
return (left_start, right_ends)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.left}{self.right}"
|
||||
|
|
@ -2339,20 +2355,15 @@ class ReAlt(Re):
|
|||
left: Re
|
||||
right: Re
|
||||
|
||||
def to_nfa(self, start: NFAState) -> NFAState:
|
||||
left_start = NFAState()
|
||||
def to_nfa(self) -> tuple[NFAState, list[NFAState]]:
|
||||
left_start, left_ends = self.left.to_nfa()
|
||||
right_start, right_ends = self.right.to_nfa()
|
||||
|
||||
start = NFAState()
|
||||
start.epsilons.append(left_start)
|
||||
left_end = self.left.to_nfa(left_start)
|
||||
|
||||
right_start = NFAState()
|
||||
start.epsilons.append(right_start)
|
||||
right_end = self.right.to_nfa(right_start)
|
||||
|
||||
end = NFAState()
|
||||
left_end.epsilons.append(end)
|
||||
right_end.epsilons.append(end)
|
||||
|
||||
return end
|
||||
return (start, left_ends + right_ends)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"(({self.left})||({self.right}))"
|
||||
|
|
@ -2438,22 +2449,23 @@ class NFASuperState:
|
|||
return accept
|
||||
|
||||
|
||||
def compile_lexer(x: Grammar) -> LexerTable:
|
||||
def compile_terminals(terminals: typing.Iterable[Terminal]) -> LexerTable:
|
||||
# Parse the terminals all together into a big NFA rooted at `NFA`.
|
||||
NFA = NFAState()
|
||||
for terminal in x.terminals:
|
||||
start = NFAState()
|
||||
NFA.epsilons.append(start)
|
||||
|
||||
for terminal in terminals:
|
||||
pattern = terminal.pattern
|
||||
if isinstance(pattern, Re):
|
||||
ending = pattern.to_nfa(start)
|
||||
else:
|
||||
ending = start
|
||||
for c in pattern:
|
||||
ending = ending.add_edge(Span.from_str(c), NFAState())
|
||||
start, ends = pattern.to_nfa()
|
||||
for end in ends:
|
||||
end.accept.append(terminal)
|
||||
NFA.epsilons.append(start)
|
||||
|
||||
ending.accept.append(terminal)
|
||||
else:
|
||||
start = end = NFAState()
|
||||
for c in pattern:
|
||||
end = end.add_edge(Span.from_str(c), NFAState())
|
||||
end.accept.append(terminal)
|
||||
NFA.epsilons.append(start)
|
||||
|
||||
NFA.dump_graph()
|
||||
|
||||
|
|
@ -2482,6 +2494,10 @@ def compile_lexer(x: Grammar) -> LexerTable:
|
|||
]
|
||||
|
||||
|
||||
def compile_lexer(grammar: Grammar) -> LexerTable:
|
||||
return compile_terminals(grammar.terminals)
|
||||
|
||||
|
||||
def dump_lexer_table(table: LexerTable):
|
||||
with open("lexer.dot", "w", encoding="utf-8") as f:
|
||||
f.write("digraph G {\n")
|
||||
|
|
|
|||
|
|
@ -381,3 +381,37 @@ def test_lexer_compile():
|
|||
(LexTest.BLANKS, 5, 1),
|
||||
(LexTest.IDENTIFIER, 6, 3),
|
||||
]
|
||||
|
||||
|
||||
def test_lexer_numbers():
|
||||
class LexTest(Grammar):
|
||||
@rule
|
||||
def number(self):
|
||||
return self.NUMBER
|
||||
|
||||
start = number
|
||||
|
||||
NUMBER = Terminal(
|
||||
Re.seq(
|
||||
Re.set(("0", "9")).plus(),
|
||||
Re.seq(
|
||||
Re.literal("."),
|
||||
Re.set(("0", "9")).plus(),
|
||||
Re.seq(
|
||||
Re.set("e", "E"),
|
||||
Re.set("+", "-").question(),
|
||||
Re.set(("0", "9")).plus(),
|
||||
).question(),
|
||||
).question(),
|
||||
)
|
||||
)
|
||||
|
||||
lexer = compile_lexer(LexTest())
|
||||
dump_lexer_table(lexer)
|
||||
|
||||
number_string = "1234.12"
|
||||
|
||||
tokens = list(generic_tokenize(number_string, lexer))
|
||||
assert tokens == [
|
||||
(LexTest.NUMBER, 0, len(number_string)),
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue