Generated lexers actually kinda work

But regular expressions are underpowered and verbose
This commit is contained in:
John Doty 2024-08-23 15:32:35 -07:00
parent 58c3004702
commit 72052645d6
6 changed files with 957 additions and 544 deletions

View file

@ -131,13 +131,13 @@ May 2024
"""
import abc
import bisect
import collections
import dataclasses
import enum
import functools
import inspect
import json
import sys
import typing
@ -1607,18 +1607,19 @@ class Terminal(Rule):
"""A token, or terminal symbol in the grammar."""
value: str | None
pattern: str
regex: bool
pattern: "str | Re"
def __init__(self, pattern, name=None, regex=False):
def __init__(self, pattern, name=None):
self.value = name
self.pattern = pattern
self.regex = regex
def flatten(self) -> typing.Generator[list["str | Terminal"], None, None]:
    """Yield the flattened forms of this rule.

    A terminal has exactly one form: a single-element list holding the
    terminal itself.
    """
    only_form = [self]
    yield only_form
def __repr__(self) -> str:
    """Return the terminal's name, or the placeholder "???" when `value`
    is None or empty."""
    return self.value or "???"
class NonTerminal(Rule):
"""A non-terminal, or a production, in the grammar.
@ -1945,14 +1946,65 @@ class Span:
upper: int # exclusive
@classmethod
def from_str(cls, c: str) -> "Span":
return Span(lower=ord(c), upper=ord(c) + 1)
def from_str(cls, lower: str, upper: str | None = None) -> "Span":
lo = ord(lower)
if upper is None:
hi = lo + 1
else:
hi = ord(upper) + 1
return Span(lower=lo, upper=hi)
def __len__(self) -> int:
    """Return the width of the span. `upper` is exclusive, so the width is
    simply `upper - lower`."""
    return self.upper - self.lower
def intersects(self, other: "Span") -> bool:
    """Report whether this span and `other` overlap.

    Both spans are half-open (`upper` is exclusive), so two spans that
    merely touch end-to-start do not intersect.
    """
    # Two half-open spans are disjoint exactly when one ends at or before
    # the point where the other begins.
    disjoint = self.lower >= other.upper or self.upper <= other.lower
    return not disjoint
def split(self, other: "Span") -> tuple["Span|None", "Span", "Span|None"]:
assert self.intersects(other)
def split(self, other: "Span") -> tuple["Span|None", "Span|None", "Span|None"]:
"""Split two possibly-intersecting spans into three regions: a low
region, which covers just the lower part of the union, a mid region,
which covers the intersection, and a hi region, which covers just the
upper part of the union.
Together, low and high cover the union of the two spans. Mid covers
the intersection. The implication is that if both spans are identical
then the low and high regions will both be None and mid will be equal
to both.
Graphically, given two spans A and B:
[ B )
[ A )
[ lo )[ mid )[ hi )
If the lower bounds align then the `lo` region is empty:
[ B )
[ A )
[ mid )[ hi )
If the upper bounds align then the `hi` region is empty:
[ B )
[ A )
[ lo )[ mid )
If both bounds align then both are empty:
[ B )
[ A )
[ mid )
split is symmetric: a.split(b) and b.split(a) produce the same output
spans, in the same order.
"""
if not self.intersects(other):
if self.lower < other.lower:
return (self, None, other)
else:
return (other, None, self)
first = min(self.lower, other.lower)
second = max(self.lower, other.lower)
@ -1966,23 +2018,14 @@ class Span:
return (low, mid, hi)
def __str__(self) -> str:
if self.upper - self.lower == 1:
return str(self.lower)
lower = str(self.lower)
upper = str(self.upper)
return f"[{lower}-{upper})"
def __lt__(self, other: "Span") -> bool:
return self.lower < other.lower
return f"[{self.lower}-{self.upper})"
ET = typing.TypeVar("ET")
class EdgeList[ET]:
"""A list of edge transitions, keyed by *span*. A given span can have
multiple targets, because this supports NFAs."""
"""A list of edge transitions, keyed by *span*."""
_edges: list[tuple[Span, list[ET]]]
@ -2000,80 +2043,415 @@ class EdgeList[ET]:
spans that overlap this one, split and generating multiple distinct
edges.
"""
# print(f" Adding {c}->{s} to {self}...")
# Look to see where we would put this span based solely on a
# sort of lower bounds.
point = bisect.bisect_left(self._edges, c, key=lambda x: x[0])
our_targets = [s]
# If this is not the first span in the list then we might
# overlap with the span to our left....
if point > 0:
left_point = point - 1
left_span, left_targets = self._edges[left_point]
if c.intersects(left_span):
# ...if we intersect with the span to our left then we
# must split the span to our left with regards to our
# span. Then we have three target spans:
#
# - The lo one, which just has the targets from the old
# left span. (This may be empty if we overlap the
# left one completely on the left side.)
#
# - The mid one, which has both the targets from the
# old left and the new target.
#
# - The hi one, which if it exists only has our target.
# If it exists it basically replaces the current span
# for our future processing. (If not, then our span
# is completely subsumed into the left span and we
# can stop.)
#
del self._edges[left_point]
lo, mid, hi = c.split(left_span)
# print(f" <- {c} splits {left_span} -> {lo}, {mid}, {hi} @{left_point}")
self._edges.insert(left_point, (mid, left_targets + [s]))
if lo is not None:
self._edges.insert(left_point, (lo, left_targets))
if hi is None or not hi.intersects(c):
# Yup, completely subsumed.
# print(f" result: {self} (left out)")
return
# Look to see where we would put this span based solely on a sort of
# lower bounds: find the lowest upper bound that is greater than the
# lower bound of the incoming span.
point = bisect.bisect_right(self._edges, c.lower, key=lambda x: x[0].upper)
# Continue processing with `c` as the hi split from the
# left. If the left and right spans abut each other then
# `c` will be subsumed in our right span.
c = hi
# We might need to run this in multiple iterations because we keep
# splitting against the *lowest* matching span.
next_span: Span | None = c
while next_span is not None:
c = next_span
next_span = None
# If point is not at the very end of the list then it might
# overlap the span to our right...
if point < len(self._edges):
# print(f" incoming: {self} @ {point} <- {c}->[{s}]")
# Check to see if we've run off the end of the list of spans.
if point == len(self._edges):
self._edges.insert(point, (c, [s]))
# print(f" trivial end: {self}")
return
# Nope, pull out the span to the right of us.
right_span, right_targets = self._edges[point]
if c.intersects(right_span):
# ...this is similar to the left case, above, except the
# lower bound has the targets that are only ours, etc.
del self._edges[point]
lo, mid, hi = c.split(right_span)
# print(f" -> {c} splits {right_span} -> {lo}, {mid}, {hi} @{point}")
if hi is not None:
# Because we intersect at least a little bit we know that we need to
# split and keep processing.
del self._edges[point]
lo, mid, hi = c.split(right_span) # Remember the semantics
# print(f" -> {c} splits {right_span} -> {lo}, {mid}, {hi} @{point}")
# We do this from lo to hi, lo first.
if lo is not None:
# NOTE: lo will never intersect both no matter what.
if lo.intersects(right_span):
assert not lo.intersects(c)
targets = right_targets
else:
assert lo.intersects(c)
targets = our_targets
self._edges.insert(point, (lo, targets))
point += 1 # Adjust the insertion point, important for us to keep running.
if mid is not None:
# If mid exists it is known to intersect with both so we can just
# do it.
self._edges.insert(point, (mid, right_targets + our_targets))
point += 1 # Adjust the insertion point, important for us to keep running.
if hi is not None:
# NOTE: Just like lo, hi will never intersect both no matter what.
if hi.intersects(right_span):
# If hi intersects the right span then we're done, no
# need to keep running.
assert not hi.intersects(c)
self._edges.insert(point, (hi, right_targets))
self._edges.insert(point, (mid, right_targets + [s]))
if lo is None or not lo.intersects(c):
# Our span is completely subsumed on the lower side
# of the range; there is no lower side that just has
# our targets. Bail now.
# print(f" result: {self} (right out)")
return
# Continue processing with `c` as the lo split, since
# that's the one that has only the specified state as the
# target.
c = lo
else:
# BUT! If hi intersects the incoming span then what we
# need to do is to replace the incoming span with hi
# (having chopped off the lower part of the incoming
# span) and continue to execute with only the upper part
# of the incoming span.
#
# Why? Because the upper part of the incoming span might
# intersect *more* spans, in which case we need to keep
# splitting and merging targets.
assert hi.intersects(c)
next_span = hi
# If we made it here then either we have a point that does not
# intersect at all, or it only partially intersects on either the
# left or right. Either way, we have ensured that:
#
# - c doesn't intersect with left or right (any more)
# - point is where it should go
self._edges.insert(point, (c, [s]))
# print(f" result: {self} (done)")
# print(f" result: {self}")
class NFAState:
    """A single state of the NFA.

    A state may be accepting, in which case `accept` lists the Terminal(s)
    recognized on reaching it. Transitions are either epsilon moves
    (`epsilons`) or span-labelled edges (`_edges`).
    """

    accept: list[Terminal]
    epsilons: list["NFAState"]
    _edges: EdgeList["NFAState"]

    def __init__(self):
        self._edges = EdgeList()
        self.epsilons = []
        self.accept = []

    def __repr__(self):
        return f"State{id(self)}"

    def edges(self) -> typing.Iterable[tuple[Span, list["NFAState"]]]:
        """Return the span-labelled transitions out of this state."""
        return self._edges

    def add_edge(self, c: Span, s: "NFAState") -> "NFAState":
        """Add a transition on span `c` to state `s`.

        Returns `s` so that calls can be chained when building a path.
        """
        self._edges.add_edge(c, s)
        return s

    def dump_graph(self, name="nfa.dot"):
        """Write every state reachable from this one to the file `name` as
        a Graphviz digraph. Accepting states are labelled with the names of
        the terminals they accept."""
        with open(name, "w", encoding="utf8") as out:
            out.write("digraph G {\n")

            visited = set()
            pending: list[NFAState] = [self]
            while pending:
                state = pending.pop()
                if state in visited:
                    continue
                visited.add(state)

                names = (t.value for t in state.accept if t.value is not None)
                label = ", ".join(names)
                out.write(f' {id(state)} [label="{label}"];\n')

                # Epsilon transitions first, then the labelled edges.
                for target in state.epsilons:
                    out.write(f' {id(state)} -> {id(target)} [label="\u03B5"];\n')
                pending.extend(state.epsilons)

                for span, targets in state.edges():
                    label = str(span).replace('"', '\\"')
                    for target in targets:
                        out.write(f' {id(state)} -> {id(target)} [label="{label}"];\n')
                    pending.extend(targets)

            out.write("}\n")
@dataclasses.dataclass
class Re:
    """Base class for the regular-expression nodes that describe terminals.

    Subclasses implement `to_nfa` and `__str__`. The class methods and
    combinators here are a small builder API for composing expressions.
    """

    def to_nfa(self, start: NFAState) -> NFAState:
        """Build this expression's NFA fragment beginning at `start` and
        return the fragment's end state. Subclasses must override."""
        del start
        raise NotImplementedError()

    def __str__(self) -> str:
        raise NotImplementedError()

    @classmethod
    def seq(cls, *values: "Re") -> "Re":
        """Chain `values` left to right into nested sequences.

        At least one value is required; an empty call raises IndexError.
        """
        return functools.reduce(RegexSequence, values[1:], values[0])

    @classmethod
    def literal(cls, value: str) -> "Re":
        """Match the exact string `value`, one character after another."""
        return cls.seq(*map(RegexLiteral.from_ranges, value))

    @classmethod
    def set(cls, *args: str | tuple[str, str]) -> "Re":
        """Match any one character from `args`, where each argument is a
        single character or a (first, last) inclusive range."""
        return RegexLiteral.from_ranges(*args)

    def plus(self) -> "Re":
        """One or more repetitions of this expression."""
        return RegexPlus(self)

    def star(self) -> "Re":
        """Zero or more repetitions of this expression."""
        return RegexStar(self)

    def question(self) -> "Re":
        """Zero or one occurrence of this expression."""
        return RegexQuestion(self)

    def __or__(self, value: "Re", /) -> "Re":
        """`a | b`: match either this expression or `value`."""
        return RegexAlternation(self, value)
@dataclasses.dataclass
class RegexLiteral(Re):
    """Match exactly one character drawn from a set of spans."""

    values: list[Span]

    @classmethod
    def from_ranges(cls, *args: str | tuple[str, str]) -> "RegexLiteral":
        """Build a literal from single characters and (first, last)
        inclusive character ranges."""
        spans = [
            Span.from_str(a) if isinstance(a, str) else Span.from_str(a[0], a[1])
            for a in args
        ]
        return RegexLiteral(spans)

    def to_nfa(self, start: NFAState) -> NFAState:
        """Add one edge per span from `start` to a new end state and return
        that end state."""
        end = NFAState()
        for span in self.values:
            start.add_edge(span, end)
        return end

    def __str__(self) -> str:
        # A lone single-character span prints as the bare character.
        if len(self.values) == 1 and len(self.values[0]) == 1:
            return chr(self.values[0].lower)

        # Everything else prints as a bracketed character class.
        ranges = []
        for span in self.values:
            start, end = chr(span.lower), chr(span.upper - 1)
            ranges.append(start if start == end else f"{start}-{end}")
        return "[{}]".format("".join(ranges))
@dataclasses.dataclass
class RegexPlus(Re):
    """One or more repetitions of `child`."""

    child: Re

    def to_nfa(self, start: NFAState) -> NFAState:
        """Build the NFA fragment for `child+` beginning at `start` and
        return its end state.

        The loop gets its own entry and exit states. Looping straight back
        to the caller's `start`, or handing back an end state that still
        has the loop's back-edge on it, lets the repetition leak into
        neighbouring fragments: with shared states `a+b+` accepts "abab",
        because the end of `b+` can step back to the end of `a+`, which in
        turn steps back to the start of `a+`.
        """
        loop_start = NFAState()
        start.epsilons.append(loop_start)

        loop_end = self.child.to_nfa(loop_start)
        # Repeat: after one pass through the child, go around again.
        loop_end.epsilons.append(loop_start)

        # Leave: the returned state has no outgoing edges of its own, so
        # whatever is built after us cannot re-enter this loop.
        end = NFAState()
        loop_end.epsilons.append(end)
        return end

    def __str__(self) -> str:
        return f"({self.child})+"
@dataclasses.dataclass
class RegexStar(Re):
    """Zero or more repetitions of `child`."""

    child: Re

    def to_nfa(self, start: NFAState) -> NFAState:
        """Build the NFA fragment for `child*` beginning at `start` and
        return its end state.

        The loop gets its own entry and exit states. Putting the loop and
        skip edges directly on the caller's `start` and on the child's end
        lets them combine with neighbouring or nested fragments: with
        shared states `a*b*` accepts "ba" and `(a*b)*` accepts "a".
        """
        loop_start = NFAState()
        start.epsilons.append(loop_start)

        loop_end = self.child.to_nfa(loop_start)
        # Repeat: after one pass through the child, go around again.
        loop_end.epsilons.append(loop_start)

        end = NFAState()
        # Leave after one or more passes...
        loop_end.epsilons.append(end)
        # ...or skip the child entirely (zero repetitions).
        start.epsilons.append(end)
        return end

    def __str__(self) -> str:
        return f"({self.child})*"
@dataclasses.dataclass
class RegexQuestion(Re):
    """Zero or one occurrence of `child`."""

    child: Re

    def to_nfa(self, start: NFAState) -> NFAState:
        """Build the NFA fragment for `child?` beginning at `start` and
        return its end state.

        The child is built from a private entry state and the fragment ends
        in a private exit state. If the skip edge sat on the same state the
        child starts from, a child that loops back to its own start could
        take the skip edge part-way through a match: with shared states
        `(a+b)?` accepts "a".
        """
        inner_start = NFAState()
        start.epsilons.append(inner_start)

        inner_end = self.child.to_nfa(inner_start)

        end = NFAState()
        # Either the child matched once...
        inner_end.epsilons.append(end)
        # ...or it was skipped entirely.
        start.epsilons.append(end)
        return end

    def __str__(self) -> str:
        return f"({self.child})?"
@dataclasses.dataclass
class RegexSequence(Re):
    """Match `left` immediately followed by `right`."""

    left: Re
    right: Re

    def to_nfa(self, start: NFAState) -> NFAState:
        """Build `left` from `start`, then `right` from wherever `left`
        ended; return the end state of `right`."""
        return self.right.to_nfa(self.left.to_nfa(start))

    def __str__(self) -> str:
        return f"{self.left}{self.right}"
@dataclasses.dataclass
class RegexAlternation(Re):
    """Match either `left` or `right`."""

    left: Re
    right: Re

    def to_nfa(self, start: NFAState) -> NFAState:
        """Build both branches from their own start states (reached from
        `start` by epsilon) and join them at a new shared end state."""
        branch_ends = []
        for branch in (self.left, self.right):
            branch_start = NFAState()
            start.epsilons.append(branch_start)
            branch_ends.append(branch.to_nfa(branch_start))

        end = NFAState()
        for branch_end in branch_ends:
            branch_end.epsilons.append(end)
        return end

    def __str__(self) -> str:
        return f"(({self.left})||({self.right}))"
# The compiled lexer: one entry per DFA state, indexed by state number. Each
# entry pairs the Terminal accepted in that state (None when the state is not
# accepting) with the state's outgoing transitions, given as
# (span, target state number) pairs.
LexerTable = list[tuple[Terminal | None, list[tuple[Span, int]]]]
class NFASuperState:
    """A DFA state, represented as an epsilon-closed set of NFA states.

    Two super states are equal (and hash alike) exactly when they contain
    the same NFA states, which is what lets them be used as dictionary keys
    during subset construction.
    """

    states: frozenset[NFAState]

    def __init__(self, states: typing.Iterable[NFAState]):
        """Close over `states`: include every state reachable from them by
        following epsilon transitions only."""
        closure = set()
        stack = list(states)
        while stack:
            st = stack.pop()
            if st in closure:
                continue
            closure.add(st)
            stack.extend(st.epsilons)

        self.states = frozenset(closure)

    def __eq__(self, other):
        if not isinstance(other, NFASuperState):
            return False
        return self.states == other.states

    def __hash__(self) -> int:
        return hash(self.states)

    def edges(self) -> list[tuple[Span, "NFASuperState"]]:
        """Return the outgoing transitions of this super state as
        (span, target super state) pairs with non-overlapping spans in
        ascending order."""
        # Feed every member state's edges through one EdgeList so that
        # overlapping spans are split and their targets merged.
        working: EdgeList[list[NFAState]] = EdgeList()
        for st in self.states:
            for span, targets in st.edges():
                working.add_edge(span, targets)

        # EdgeList maps span to list[list[State]] which we want to flatten.
        result = []
        last_upper = None
        for span, stateses in working:
            # Sanity check: spans come out sorted and disjoint.
            assert last_upper is None or last_upper <= span.lower
            last_upper = span.upper

            merged: list[NFAState] = []
            for states in stateses:
                merged.extend(states)
            result.append((span, NFASuperState(merged)))

        # TODO: Merge spans that are adjacent and go to the same state.
        return result

    def accept_terminal(self) -> Terminal | None:
        """Return the terminal this super state accepts, or None if it is
        not an accepting state.

        When several terminals are accepted, terminals with a plain-string
        pattern take priority over terminals with a regular-expression
        pattern. Terminals that share a name are treated as the same
        terminal.

        The decision does not depend on the iteration order of
        `self.states`, a frozenset of identity-hashed objects whose order
        is arbitrary. Comparing candidates pairwise in iteration order
        could raise or succeed for the same set of terminals.

        Raises:
            ValueError: if two differently-named terminals of the winning
                kind are both accepted here.
        """
        candidates = [ac for st in self.states for ac in st.accept]
        literals = [ac for ac in candidates if not isinstance(ac.pattern, Re)]
        contenders = literals if literals else candidates

        # One representative per distinct terminal name.
        by_name: dict[str | None, Terminal] = {}
        for ac in contenders:
            by_name.setdefault(ac.value, ac)

        if not by_name:
            return None

        if len(by_name) > 1:
            # Sort so the error message is stable from run to run.
            accept, ac = sorted(by_name.values(), key=lambda t: t.value or "")[:2]
            raise ValueError(
                f"Lexer is ambiguous: cannot distinguish between {accept.value} ('{accept.pattern}') and {ac.value} ('{ac.pattern}')"
            )

        (accept,) = by_name.values()
        return accept
def compile_lexer(x: Grammar) -> LexerTable:
    """Compile the terminals of grammar `x` into a DFA transition table.

    All terminals are combined into one NFA, which is then converted to a
    DFA by subset construction. The result has one entry per DFA state; the
    start state is entry 0.
    """
    # Build one big NFA: a root state with an epsilon edge to the start of
    # each terminal's own fragment.
    root = NFAState()
    for terminal in x.terminals:
        entry = NFAState()
        root.epsilons.append(entry)

        pattern = terminal.pattern
        if isinstance(pattern, Re):
            final = pattern.to_nfa(entry)
        else:
            # A plain string is a chain of single-character edges.
            final = entry
            for c in pattern:
                final = final.add_edge(Span.from_str(c), NFAState())

        final.accept.append(terminal)

    # NOTE(review): this writes nfa.dot into the current directory on every
    # call; looks like debugging output -- confirm it is meant to stay.
    root.dump_graph()

    # Subset construction: each DFA state is the closure of a set of NFA
    # states. States are numbered in the order they are first reached.
    dfa: dict[NFASuperState, tuple[int, list[tuple[Span, NFASuperState]]]] = {}
    pending = [NFASuperState([root])]
    while pending:
        ss = pending.pop()
        if ss in dfa:
            continue

        edges = ss.edges()
        dfa[ss] = (len(dfa), edges)
        pending.extend(target for _, target in edges)

    # Dict order is insertion order, so list position matches state number.
    table = []
    for ss, (_, edges) in dfa.items():
        accepted = ss.accept_terminal()
        transitions = [(span, dfa[target][0]) for span, target in edges]
        table.append((accepted, transitions))
    return table
def dump_lexer_table(table: LexerTable):
    """Write `table` to lexer.dot as a Graphviz digraph: one node per
    state, labelled with the accepted terminal's name (blank when the state
    accepts nothing), and one edge per transition, labelled with its span."""
    with open("lexer.dot", "w", encoding="utf-8") as out:
        out.write("digraph G {\n")

        for index, (accept, edges) in enumerate(table):
            label = "" if accept is None else accept.value
            out.write(f' {index} [label="{label}"];\n')

            for span, target in edges:
                label = str(span).replace('"', '\\"')
                out.write(f' {index} -> {target} [label="{label}"];\n')

        out.write("}\n")