diff --git a/grammar.py b/grammar.py index 502c924..69f4de8 100644 --- a/grammar.py +++ b/grammar.py @@ -2,7 +2,17 @@ import re import typing -from parser import Assoc, Grammar, Nothing, rule, seq, Rule, Terminal +from parser import ( + Assoc, + Grammar, + Nothing, + rule, + seq, + Rule, + Terminal, + Re, +) +from parser.parser import compile_lexer, dump_lexer_table class FineGrammar(Grammar): @@ -321,7 +331,7 @@ class FineGrammar(Grammar): def field_value(self) -> Rule: return self.IDENTIFIER | seq(self.IDENTIFIER, self.COLON, self.expression) - BLANK = Terminal("[ \t\r\n]+", regex=True) + BLANK = Terminal(Re.set(" ", "\t", "\r", "\n").plus()) ARROW = Terminal("->") AS = Terminal("as") @@ -332,7 +342,12 @@ class FineGrammar(Grammar): ELSE = Terminal("else") FOR = Terminal("for") FUN = Terminal("fun") - IDENTIFIER = Terminal("[A-Za-z_][A-Za-z0-9_]*", regex=True) + IDENTIFIER = Terminal( + Re.seq( + Re.set(("a", "z"), ("A", "Z"), "_"), + Re.set(("a", "z"), ("A", "Z"), ("0", "9"), "_").star(), + ) + ) IF = Terminal("if") IMPORT = Terminal("import") IN = Terminal("in") @@ -341,7 +356,7 @@ class FineGrammar(Grammar): RCURLY = Terminal("}") RETURN = Terminal("return") SEMICOLON = Terminal(";") - STRING = Terminal('""', regex=True) + STRING = Terminal('""') # TODO WHILE = Terminal("while") EQUAL = Terminal("=") LPAREN = Terminal("(") @@ -361,7 +376,7 @@ class FineGrammar(Grammar): MINUS = Terminal("-") STAR = Terminal("*") SLASH = Terminal("/") - NUMBER = Terminal("[0-9]+", regex=True) + NUMBER = Terminal(Re.set(("0", "9")).plus()) TRUE = Terminal("true") FALSE = Terminal("false") BANG = Terminal("!") @@ -378,7 +393,6 @@ class FineGrammar(Grammar): # DORKY LEXER # ----------------------------------------------------------------------------- import bisect -import dataclasses NUMBER_RE = re.compile("[0-9]+(\\.[0-9]*([eE][-+]?[0-9]+)?)?") @@ -559,17 +573,5 @@ if __name__ == "__main__": grammar = FineGrammar() grammar.build_table() - class LexTest(Grammar): - @rule - def 
foo(self): - return self.IS - - start = foo - - IS = Terminal("is") - AS = Terminal("as") - IDENTIFIER = Terminal("[a-z]+", regex=True) - # IDENTIFIER = Terminal("[A-Za-z_][A-Za-z0-9_]*", regex=True) - - lexer = compile_lexer(LexTest()) + lexer = compile_lexer(grammar) dump_lexer_table(lexer) diff --git a/parser/parser.py b/parser/parser.py index 4d19e29..8a23d4e 100644 --- a/parser/parser.py +++ b/parser/parser.py @@ -131,13 +131,13 @@ May 2024 """ import abc +import bisect import collections import dataclasses import enum import functools import inspect import json -import sys import typing @@ -1607,18 +1607,19 @@ class Terminal(Rule): """A token, or terminal symbol in the grammar.""" value: str | None - pattern: str - regex: bool + pattern: "str | Re" - def __init__(self, pattern, name=None, regex=False): + def __init__(self, pattern, name=None): self.value = name self.pattern = pattern - self.regex = regex def flatten(self) -> typing.Generator[list["str | Terminal"], None, None]: # We are just ourselves when flattened. yield [self] + def __repr__(self) -> str: + return self.value or "???" + class NonTerminal(Rule): """A non-terminal, or a production, in the grammar. 
@@ -1945,14 +1946,65 @@ class Span: upper: int # exclusive @classmethod - def from_str(cls, c: str) -> "Span": - return Span(lower=ord(c), upper=ord(c) + 1) + def from_str(cls, lower: str, upper: str | None = None) -> "Span": + lo = ord(lower) + if upper is None: + hi = lo + 1 + else: + hi = ord(upper) + 1 + + return Span(lower=lo, upper=hi) + + def __len__(self) -> int: + return self.upper - self.lower def intersects(self, other: "Span") -> bool: + """Determine if this span intersects the other span.""" return self.lower < other.upper and self.upper > other.lower - def split(self, other: "Span") -> tuple["Span|None", "Span", "Span|None"]: - assert self.intersects(other) + def split(self, other: "Span") -> tuple["Span|None", "Span|None", "Span|None"]: + """Split two possibly-intersecting spans into three regions: a low + region, which covers just the lower part of the union, a mid region, + which covers the intersection, and a hi region, which covers just the + upper part of the union. + + Together, low and high cover the union of the two spans. Mid covers + the intersection. The implication is that if both spans are identical + then the low and high regions will both be None and mid will be equal + to both. + + Graphically, given two spans A and B: + + [ B ) + [ A ) + [ lo )[ mid )[ hi ) + + If the lower bounds align then the `lo` region is empty: + + [ B ) + [ A ) + [ mid )[ hi ) + + If the upper bounds align then the `hi` region is empty: + + [ B ) + [ A ) + [ lo )[ mid ) + + If both bounds align then both are empty: + + [ B ) + [ A ) + [ mid ) + + split is reflexive: it doesn't matter which order you split things in, + you will always get the same output spans, in the same order. 
+ """ + if not self.intersects(other): + if self.lower < other.lower: + return (self, None, other) + else: + return (other, None, self) first = min(self.lower, other.lower) second = max(self.lower, other.lower) @@ -1966,23 +2018,14 @@ class Span: return (low, mid, hi) def __str__(self) -> str: - if self.upper - self.lower == 1: - return str(self.lower) - - lower = str(self.lower) - upper = str(self.upper) - return f"[{lower}-{upper})" - - def __lt__(self, other: "Span") -> bool: - return self.lower < other.lower + return f"[{self.lower}-{self.upper})" ET = typing.TypeVar("ET") class EdgeList[ET]: - """A list of edge transitions, keyed by *span*. A given span can have - multiple targets, because this supports NFAs.""" + """A list of edge transitions, keyed by *span*.""" _edges: list[tuple[Span, list[ET]]] @@ -2000,80 +2043,415 @@ class EdgeList[ET]: spans that overlap this one, split and generating multiple distinct edges. """ - # print(f" Adding {c}->{s} to {self}...") - # Look to see where we would put this span based solely on a - # sort of lower bounds. - point = bisect.bisect_left(self._edges, c, key=lambda x: x[0]) + our_targets = [s] - # If this is not the first span in the list then we might - # overlap with the span to our left.... - if point > 0: - left_point = point - 1 - left_span, left_targets = self._edges[left_point] - if c.intersects(left_span): - # ...if we intersect with the span to our left then we - # must split the span to our left with regards to our - # span. Then we have three target spans: - # - # - The lo one, which just has the targets from the old - # left span. (This may be empty if we overlap the - # left one completely on the left side.) - # - # - The mid one, which has both the targets from the - # old left and the new target. - # - # - The hi one, which if it exists only has our target. - # If it exists it basically replaces the current span - # for our future processing. 
(If not, then our span - # is completely subsumed into the left span and we - # can stop.) - # - del self._edges[left_point] - lo, mid, hi = c.split(left_span) - # print(f" <- {c} splits {left_span} -> {lo}, {mid}, {hi} @{left_point}") - self._edges.insert(left_point, (mid, left_targets + [s])) - if lo is not None: - self._edges.insert(left_point, (lo, left_targets)) - if hi is None or not hi.intersects(c): - # Yup, completely subsumed. - # print(f" result: {self} (left out)") - return + # Look to see where we would put this span based solely on a sort of + # lower bounds: find the lowest upper bound that is greater than the + # lower bound of the incoming span. + point = bisect.bisect_right(self._edges, c.lower, key=lambda x: x[0].upper) - # Continue processing with `c` as the hi split from the - # left. If the left and right spans abut each other then - # `c` will be subsumed in our right span. - c = hi + # We might need to run this in multiple iterations because we keep + # splitting against the *lowest* matching span. + next_span: Span | None = c + while next_span is not None: + c = next_span + next_span = None - # If point is not at the very end of the list then it might - # overlap the span to our right... - if point < len(self._edges): + # print(f" incoming: {self} @ {point} <- {c}->[{s}]") + + # Check to see if we've run off the end of the list of spans. + if point == len(self._edges): + self._edges.insert(point, (c, [s])) + # print(f" trivial end: {self}") + return + + # Nope, pull out the span to the right of us. right_span, right_targets = self._edges[point] - if c.intersects(right_span): - # ...this is similar to the left case, above, except the - # lower bound has the targets that our only ours, etc. 
- del self._edges[point] - lo, mid, hi = c.split(right_span) - # print(f" -> {c} splits {right_span} -> {lo}, {mid}, {hi} @{point}") - if hi is not None: + + # Because we intersect at least a little bit we know that we need to + # split and keep processing. + del self._edges[point] + lo, mid, hi = c.split(right_span) # Remember the semantics + # print(f" -> {c} splits {right_span} -> {lo}, {mid}, {hi} @{point}") + + # We do this from lo to hi, lo first. + if lo is not None: + # NOTE: lo will never intersect both no matter what. + if lo.intersects(right_span): + assert not lo.intersects(c) + targets = right_targets + else: + assert lo.intersects(c) + targets = our_targets + + self._edges.insert(point, (lo, targets)) + point += 1 # Adjust the insertion point, important for us to keep running. + + if mid is not None: + # If mid exists it is known to intersect with both so we can just + # do it. + self._edges.insert(point, (mid, right_targets + our_targets)) + point += 1 # Adjust the insertion point, important for us to keep running. + + if hi is not None: + # NOTE: Just like lo, hi will never intersect both no matter what. + if hi.intersects(right_span): + # If hi intersects the right span then we're done, no + # need to keep running. + assert not hi.intersects(c) self._edges.insert(point, (hi, right_targets)) - self._edges.insert(point, (mid, right_targets + [s])) - if lo is None or not lo.intersects(c): - # Our span is completely subsumed on the lower side - # of the range; there is no lower side that just has - # our targets. Bail now. - # print(f" result: {self} (right out)") - return - # Continue processing with `c` as the lo split, since - # that's the one that has only the specified state as the - # target. - c = lo + else: + # BUT! 
If hi intersects the incoming span then what we + # need to do is to replace the incoming span with hi + # (having chopped off the lower part of the incoming + # span) and continue to execute with only the upper part + # of the incoming span. + # + # Why? Because the upper part of the incoming span might + # intersect *more* spans, in which case we need to keep + # splitting and merging targets. + assert hi.intersects(c) + next_span = hi - # If we made it here then either we have a point that does not - # intersect at all, or it only partially intersects on either the - # left or right. Either way, we have ensured that: - # - # - c doesn't intersect with left or right (any more) - # - point is where it should go - self._edges.insert(point, (c, [s])) - # print(f" result: {self} (done)") + # print(f" result: {self}") + + +class NFAState: + """An NFA state. Each state can be the accept state, with one or more + Terminals as the result.""" + + accept: list[Terminal] + epsilons: list["NFAState"] + _edges: EdgeList["NFAState"] + + def __init__(self): + self.accept = [] + self.epsilons = [] + self._edges = EdgeList() + + def __repr__(self): + return f"State{id(self)}" + + def edges(self) -> typing.Iterable[tuple[Span, list["NFAState"]]]: + return self._edges + + def add_edge(self, c: Span, s: "NFAState") -> "NFAState": + self._edges.add_edge(c, s) + return s + + def dump_graph(self, name="nfa.dot"): + with open(name, "w", encoding="utf8") as f: + f.write("digraph G {\n") + + stack: list[NFAState] = [self] + visited = set() + while len(stack) > 0: + state = stack.pop() + if state in visited: + continue + visited.add(state) + + label = ", ".join([t.value for t in state.accept if t.value is not None]) + f.write(f' {id(state)} [label="{label}"];\n') + for target in state.epsilons: + stack.append(target) + f.write(f' {id(state)} -> {id(target)} [label="\u03B5"];\n') + + for span, targets in state.edges(): + label = str(span).replace('"', '\\"') + for target in targets: + 
stack.append(target) + f.write(f' {id(state)} -> {id(target)} [label="{label}"];\n') + + f.write("}\n") + + +@dataclasses.dataclass +class Re: + def to_nfa(self, start: NFAState) -> NFAState: + del start + raise NotImplementedError() + + def __str__(self) -> str: + raise NotImplementedError() + + @classmethod + def seq(cls, *values: "Re") -> "Re": + result = values[0] + for v in values[1:]: + result = RegexSequence(result, v) + return result + + @classmethod + def literal(cls, value: str) -> "Re": + return cls.seq(*[RegexLiteral.from_ranges(c) for c in value]) + + @classmethod + def set(cls, *args: str | tuple[str, str]) -> "Re": + return RegexLiteral.from_ranges(*args) + + def plus(self) -> "Re": + return RegexPlus(self) + + def star(self) -> "Re": + return RegexStar(self) + + def question(self) -> "Re": + return RegexQuestion(self) + + def __or__(self, value: "Re", /) -> "Re": + return RegexAlternation(self, value) + + +@dataclasses.dataclass +class RegexLiteral(Re): + values: list[Span] + + @classmethod + def from_ranges(cls, *args: str | tuple[str, str]) -> "RegexLiteral": + values = [] + for a in args: + if isinstance(a, str): + values.append(Span.from_str(a)) + else: + values.append(Span.from_str(a[0], a[1])) + + return RegexLiteral(values) + + def to_nfa(self, start: NFAState) -> NFAState: + end = NFAState() + for span in self.values: + start.add_edge(span, end) + return end + + def __str__(self) -> str: + if len(self.values) == 1: + span = self.values[0] + if len(span) == 1: + return chr(span.lower) + + ranges = [] + for span in self.values: + start = chr(span.lower) + end = chr(span.upper - 1) + if start == end: + ranges.append(start) + else: + ranges.append(f"{start}-{end}") + return "[{}]".format("".join(ranges)) + + +@dataclasses.dataclass +class RegexPlus(Re): + child: Re + + def to_nfa(self, start: NFAState) -> NFAState: + end = self.child.to_nfa(start) + end.epsilons.append(start) + return end + + def __str__(self) -> str: + return f"({self.child})+" 
+ + +@dataclasses.dataclass +class RegexStar(Re): + child: Re + + def to_nfa(self, start: NFAState) -> NFAState: + end = self.child.to_nfa(start) + end.epsilons.append(start) + start.epsilons.append(end) + return end + + def __str__(self) -> str: + return f"({self.child})*" + + +@dataclasses.dataclass +class RegexQuestion(Re): + child: Re + + def to_nfa(self, start: NFAState) -> NFAState: + end = self.child.to_nfa(start) + start.epsilons.append(end) + return end + + def __str__(self) -> str: + return f"({self.child})?" + + +@dataclasses.dataclass +class RegexSequence(Re): + left: Re + right: Re + + def to_nfa(self, start: NFAState) -> NFAState: + mid = self.left.to_nfa(start) + return self.right.to_nfa(mid) + + def __str__(self) -> str: + return f"{self.left}{self.right}" + + +@dataclasses.dataclass +class RegexAlternation(Re): + left: Re + right: Re + + def to_nfa(self, start: NFAState) -> NFAState: + left_start = NFAState() + start.epsilons.append(left_start) + left_end = self.left.to_nfa(left_start) + + right_start = NFAState() + start.epsilons.append(right_start) + right_end = self.right.to_nfa(right_start) + + end = NFAState() + left_end.epsilons.append(end) + right_end.epsilons.append(end) + + return end + + def __str__(self) -> str: + return f"(({self.left})||({self.right}))" + + +LexerTable = list[tuple[Terminal | None, list[tuple[Span, int]]]] + + +class NFASuperState: + states: frozenset[NFAState] + + def __init__(self, states: typing.Iterable[NFAState]): + # Close over the given states, including every state that is + # reachable by epsilon-transition. 
+ stack = list(states) + result = set() + while len(stack) > 0: + st = stack.pop() + if st in result: + continue + result.add(st) + stack.extend(st.epsilons) + + self.states = frozenset(result) + + def __eq__(self, other): + if not isinstance(other, NFASuperState): + return False + return self.states == other.states + + def __hash__(self) -> int: + return hash(self.states) + + def edges(self) -> list[tuple[Span, "NFASuperState"]]: + working: EdgeList[list[NFAState]] = EdgeList() + for st in self.states: + for span, targets in st.edges(): + working.add_edge(span, targets) + + # EdgeList maps span to list[list[State]] which we want to flatten. + last_upper = None + result = [] + for span, stateses in working: + if last_upper is not None: + assert last_upper <= span.lower + last_upper = span.upper + + s: list[NFAState] = [] + for states in stateses: + s.extend(states) + + result.append((span, NFASuperState(s))) + + if len(result) > 0: + for i in range(0, len(result) - 1): + span = result[i][0] + next_span = result[i + 1][0] + assert span.upper <= next_span.lower + + # TODO: Merge spans that are adjacent and go to the same state. + + return result + + def accept_terminal(self) -> Terminal | None: + accept = None + for st in self.states: + for ac in st.accept: + if accept is None: + accept = ac + elif accept.value != ac.value: + accept_regex = isinstance(accept.pattern, Re) + ac_regex = isinstance(ac.pattern, Re) + + if accept_regex and not ac_regex: + accept = ac + elif ac_regex and not accept_regex: + pass + else: + raise ValueError( + f"Lexer is ambiguous: cannot distinguish between {accept.value} ('{accept.pattern}') and {ac.value} ('{ac.pattern}')" + ) + + return accept + + +def compile_lexer(x: Grammar) -> LexerTable: + # Parse the terminals all together into a big NFA rooted at `NFA`. 
+ NFA = NFAState() + for terminal in x.terminals: + start = NFAState() + NFA.epsilons.append(start) + + pattern = terminal.pattern + if isinstance(pattern, Re): + ending = pattern.to_nfa(start) + else: + ending = start + for c in pattern: + ending = ending.add_edge(Span.from_str(c), NFAState()) + + ending.accept.append(terminal) + + NFA.dump_graph() + + # Convert the NFA into a DFA in the most straightforward way (by tracking + # sets of state closures, called SuperStates.) + DFA: dict[NFASuperState, tuple[int, list[tuple[Span, NFASuperState]]]] = {} + + stack = [NFASuperState([NFA])] + while len(stack) > 0: + ss = stack.pop() + if ss in DFA: + continue + + edges = ss.edges() + + DFA[ss] = (len(DFA), edges) + for _, target in edges: + stack.append(target) + + return [ + ( + ss.accept_terminal(), + [(k, DFA[v][0]) for k, v in edges], + ) + for ss, (_, edges) in DFA.items() + ] + + +def dump_lexer_table(table: LexerTable): + with open("lexer.dot", "w", encoding="utf-8") as f: + f.write("digraph G {\n") + for index, (accept, edges) in enumerate(table): + label = accept.value if accept is not None else "" + f.write(f' {index} [label="{label}"];\n') + for span, target in edges: + label = str(span).replace('"', '\\"') + f.write(f' {index} -> {target} [label="{label}"];\n') + + pass + f.write("}\n") diff --git a/parser/runtime.py b/parser/runtime.py index f5be3a4..124bc7b 100644 --- a/parser/runtime.py +++ b/parser/runtime.py @@ -430,3 +430,58 @@ class Parser: error_strings.append(f"{line_index}:{column_index}: {parse_error.message}") return (result, error_strings) + + +def generic_tokenize( + src: str, table: parser.LexerTable +) -> typing.Iterable[tuple[parser.Terminal, int, int]]: + pos = 0 + state = 0 + start = 0 + last_accept = None + last_accept_pos = 0 + + print(f"LEXING: {src} ({len(src)})") + + while pos < len(src): + while state is not None: + accept, edges = table[state] + if accept is not None: + last_accept = accept + last_accept_pos = pos + + print(f" @ 
{pos} state: {state} ({accept})") + if pos >= len(src): + break + + char = ord(src[pos]) + print(f" -> char: {char} ({repr(src[pos])})") + + # Find the index of the span where the upper value is the tightest + # bound on the character. + state = None + index = bisect.bisect_right(edges, char, key=lambda x: x[0].upper) + print(f" -> {index}") + if index < len(edges): + span, target = edges[index] + print(f" -> {span}, {target}") + if char >= span.lower: + print(f" -> target: {target}") + state = target + pos += 1 + + else: + print(f" Nope (outside range)") + else: + print(f" Nope (at end)") + + if last_accept is None: + raise Exception(f"Token error at {pos}") + + yield (last_accept, start, last_accept_pos - start) + + print(f" Yield: {last_accept}, reset to {last_accept_pos}") + last_accept = None + pos = last_accept_pos + start = pos + state = 0 diff --git a/pdm.lock b/pdm.lock index b80bf6d..a937da9 100644 --- a/pdm.lock +++ b/pdm.lock @@ -3,9 +3,26 @@ [metadata] groups = ["default", "dev"] -strategy = ["cross_platform", "inherit_metadata"] -lock_version = "4.4.1" -content_hash = "sha256:143b06c001132ba589a47b2b3a498dd54f4840d95d216c794068089fcea48d4d" +strategy = ["inherit_metadata"] +lock_version = "4.5.0" +content_hash = "sha256:c4fec06f95402db1e9843df4a8a4a275273c6ec4f41f192f30d8a92ee52d15ea" + +[[metadata.targets]] +requires_python = ">=3.12" + +[[package]] +name = "attrs" +version = "24.2.0" +requires_python = ">=3.7" +summary = "Classes Without Boilerplate" +groups = ["dev"] +dependencies = [ + "importlib-metadata; python_version < \"3.8\"", +] +files = [ + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, +] [[package]] name = "colorama" @@ -19,6 +36,22 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = 
"sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "hypothesis" +version = "6.111.1" +requires_python = ">=3.8" +summary = "A library for property-based testing" +groups = ["dev"] +dependencies = [ + "attrs>=22.2.0", + "exceptiongroup>=1.0.0; python_version < \"3.11\"", + "sortedcontainers<3.0.0,>=2.1.0", +] +files = [ + {file = "hypothesis-6.111.1-py3-none-any.whl", hash = "sha256:9422adbac4b2104f6cf92dc6604b5c9df975efc08ffc7145ecc39bc617243835"}, + {file = "hypothesis-6.111.1.tar.gz", hash = "sha256:6ab6185a858fa692bf125c0d0a936134edc318bee01c05e407c71c9ead0b61c5"}, +] + [[package]] name = "iniconfig" version = "2.0.0" @@ -60,11 +93,23 @@ summary = "pytest: simple powerful testing with Python" groups = ["dev"] dependencies = [ "colorama; sys_platform == \"win32\"", + "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"", "iniconfig", "packaging", "pluggy<2.0,>=1.5", + "tomli>=1; python_version < \"3.11\"", ] files = [ {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] + +[[package]] +name = "sortedcontainers" +version = "2.4.0" +summary = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" +groups = ["dev"] +files = [ + {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, + {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, +] diff --git a/pyproject.toml b/pyproject.toml index 1e28adc..c7721e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ distribution = true [tool.pdm.dev-dependencies] dev = [ "pytest>=8.2.2", + "hypothesis>=6.111.1", ] [tool.pyright] diff --git a/tests/test_lexer.py b/tests/test_lexer.py index b082889..fe442d8 100644 --- 
a/tests/test_lexer.py +++ b/tests/test_lexer.py @@ -1,439 +1,22 @@ -from parser import Span +import collections -# LexerTable = list[tuple[Terminal | None, list[tuple[Span, int]]]] +from hypothesis import assume, example, given +from hypothesis.strategies import integers, lists, tuples +import pytest -# def compile_lexer(x: Grammar) -> LexerTable: +from parser import ( + EdgeList, + Span, + Grammar, + rule, + Terminal, + compile_lexer, + dump_lexer_table, + Re, +) -# class State: -# """An NFA state. Each state can be the accept state, with one or more -# Terminals as the result.""" - -# accept: list[Terminal] -# epsilons: list["State"] -# _edges: EdgeList["State"] - -# def __init__(self): -# self.accept = [] -# self.epsilons = [] -# self._edges = EdgeList() - -# def __repr__(self): -# return f"State{id(self)}" - -# def edges(self) -> typing.Iterable[tuple[Span, list["State"]]]: -# return self._edges - -# def add_edge(self, c: Span, s: "State") -> "State": -# self._edges.add_edge(c, s) -# return s - -# def dump_graph(self, name="nfa.dot"): -# with open(name, "w", encoding="utf8") as f: -# f.write("digraph G {\n") - -# stack: list[State] = [self] -# visited = set() -# while len(stack) > 0: -# state = stack.pop() -# if state in visited: -# continue -# visited.add(state) - -# label = ", ".join([t.value for t in state.accept if t.value is not None]) -# f.write(f' {id(state)} [label="{label}"];\n') -# for target in state.epsilons: -# stack.append(target) -# f.write(f' {id(state)} -> {id(target)} [label="\u03B5"];\n') - -# for span, targets in state.edges(): -# label = str(span).replace('"', '\\"') -# for target in targets: -# stack.append(target) -# f.write(f' {id(state)} -> {id(target)} [label="{label}"];\n') - -# f.write("}\n") - -# @dataclasses.dataclass -# class RegexNode: -# def to_nfa(self, start: State) -> State: -# del start -# raise NotImplementedError() - -# def __str__(self) -> str: -# raise NotImplementedError() - -# @dataclasses.dataclass -# class 
RegexLiteral(RegexNode): -# values: list[tuple[str, str]] - -# def to_nfa(self, start: State) -> State: -# end = State() -# for s, e in self.values: -# start.add_edge(Span(ord(s), ord(e)), end) -# return end - -# def __str__(self) -> str: -# if len(self.values) == 1: -# start, end = self.values[0] -# if start == end: -# return start - -# ranges = [] -# for start, end in self.values: -# if start == end: -# ranges.append(start) -# else: -# ranges.append(f"{start}-{end}") -# return "![{}]".format("".join(ranges)) - -# @dataclasses.dataclass -# class RegexPlus(RegexNode): -# child: RegexNode - -# def to_nfa(self, start: State) -> State: -# end = self.child.to_nfa(start) -# end.epsilons.append(start) -# return end - -# def __str__(self) -> str: -# return f"({self.child})+" - -# @dataclasses.dataclass -# class RegexStar(RegexNode): -# child: RegexNode - -# def to_nfa(self, start: State) -> State: -# end = self.child.to_nfa(start) -# end.epsilons.append(start) -# start.epsilons.append(end) -# return end - -# def __str__(self) -> str: -# return f"({self.child})*" - -# @dataclasses.dataclass -# class RegexQuestion(RegexNode): -# child: RegexNode - -# def to_nfa(self, start: State) -> State: -# end = self.child.to_nfa(start) -# start.epsilons.append(end) -# return end - -# def __str__(self) -> str: -# return f"({self.child})?" 
- -# @dataclasses.dataclass -# class RegexSequence(RegexNode): -# left: RegexNode -# right: RegexNode - -# def to_nfa(self, start: State) -> State: -# mid = self.left.to_nfa(start) -# return self.right.to_nfa(mid) - -# def __str__(self) -> str: -# return f"{self.left}{self.right}" - -# @dataclasses.dataclass -# class RegexAlternation(RegexNode): -# left: RegexNode -# right: RegexNode - -# def to_nfa(self, start: State) -> State: -# left_start = State() -# start.epsilons.append(left_start) -# left_end = self.left.to_nfa(left_start) - -# right_start = State() -# start.epsilons.append(right_start) -# right_end = self.right.to_nfa(right_start) - -# end = State() -# left_end.epsilons.append(end) -# right_end.epsilons.append(end) - -# return end - -# def __str__(self) -> str: -# return f"(({self.left})||({self.right}))" - -# class RegexParser: -# # TODO: HANDLE ALTERNATION AND PRECEDENCE (CONCAT HAS HIGHEST PRECEDENCE) -# PREFIX: dict[str, typing.Callable[[str], RegexNode]] -# POSTFIX: dict[str, typing.Callable[[RegexNode, int], RegexNode]] -# BINDING: dict[str, tuple[int, int]] - -# index: int -# pattern: str - -# def __init__(self, pattern: str): -# self.PREFIX = { -# "(": self.parse_group, -# "[": self.parse_set, -# } -# self.POSTFIX = { -# "+": self.parse_plus, -# "*": self.parse_star, -# "?": self.parse_question, -# "|": self.parse_alternation, -# } - -# self.BINDING = { -# "|": (1, 1), -# "+": (2, 2), -# "*": (2, 2), -# "?": (2, 2), -# ")": (-1, -1), # Always stop parsing on ) -# } - -# self.index = 0 -# self.pattern = pattern - -# def consume(self) -> str: -# if self.index >= len(self.pattern): -# raise ValueError(f"Unable to parse regular expression '{self.pattern}'") -# result = self.pattern[self.index] -# self.index += 1 -# return result - -# def peek(self) -> str | None: -# if self.index >= len(self.pattern): -# return None -# return self.pattern[self.index] - -# def eof(self) -> bool: -# return self.index >= len(self.pattern) - -# def expect(self, ch: str): 
-# actual = self.consume() -# if ch != actual: -# raise ValueError(f"Expected '{ch}'") - -# def parse_regex(self, minimum_binding=0) -> RegexNode: -# ch = self.consume() -# parser = self.PREFIX.get(ch, self.parse_single) -# node = parser(ch) - -# while not self.eof(): -# ch = self.peek() -# assert ch is not None - -# lp, rp = self.BINDING.get(ch, (minimum_binding, minimum_binding)) -# if lp < minimum_binding: -# break - -# parser = self.POSTFIX.get(ch, self.parse_concat) -# node = parser(node, rp) - -# return node - -# def parse_single(self, ch: str) -> RegexNode: -# return RegexLiteral(values=[(ch, ch)]) - -# def parse_group(self, ch: str) -> RegexNode: -# del ch - -# node = self.parse_regex() -# self.expect(")") -# return node - -# def parse_set(self, ch: str) -> RegexNode: -# del ch - -# # TODO: INVERSION? -# ranges = [] -# while self.peek() not in (None, "]"): -# start = self.consume() -# if self.peek() == "-": -# self.consume() -# end = self.consume() -# else: -# end = start -# ranges.append((start, end)) - -# self.expect("]") -# return RegexLiteral(values=ranges) - -# def parse_alternation(self, node: RegexNode, rp: int) -> RegexNode: -# return RegexAlternation(left=node, right=self.parse_regex(rp)) - -# def parse_plus(self, left: RegexNode, rp: int) -> RegexNode: -# del rp -# self.expect("+") -# return RegexPlus(child=left) - -# def parse_star(self, left: RegexNode, rp: int) -> RegexNode: -# del rp -# self.expect("*") -# return RegexStar(child=left) - -# def parse_question(self, left: RegexNode, rp: int) -> RegexNode: -# del rp -# self.expect("?") -# return RegexQuestion(child=left) - -# def parse_concat(self, left: RegexNode, rp: int) -> RegexNode: -# return RegexSequence(left, self.parse_regex(rp)) - -# class SuperState: -# states: frozenset[State] -# index: int - -# def __init__(self, states: typing.Iterable[State]): -# # Close over the given states, including every state that is -# # reachable by epsilon-transition. 
-# stack = list(states) -# result = set() -# while len(stack) > 0: -# st = stack.pop() -# if st in result: -# continue -# result.add(st) -# stack.extend(st.epsilons) - -# self.states = frozenset(result) -# self.index = -1 - -# def __eq__(self, other): -# if not isinstance(other, SuperState): -# return False -# return self.states == other.states - -# def __hash__(self) -> int: -# return hash(self.states) - -# def edges(self) -> list[tuple[Span, "SuperState"]]: -# working: EdgeList[list[State]] = EdgeList() -# for st in self.states: -# for span, targets in st.edges(): -# working.add_edge(span, targets) - -# # EdgeList maps span to list[list[State]] which we want to flatten. -# result = [] -# for span, stateses in working: -# s: list[State] = [] -# for states in stateses: -# s.extend(states) - -# result.append((span, SuperState(s))) - -# return result - -# def accept_terminal(self) -> Terminal | None: -# accept = None -# for st in self.states: -# for ac in st.accept: -# if accept is None: -# accept = ac -# elif accept.value != ac.value: -# if accept.regex and not ac.regex: -# accept = ac -# elif ac.regex and not accept.regex: -# pass -# else: -# raise ValueError( -# f"Lexer is ambiguous: cannot distinguish between {accept.value} ('{accept.pattern}') and {ac.value} ('{ac.pattern}')" -# ) - -# return accept - -# # Parse the terminals all together into a big NFA rooted at `NFA`. -# NFA = State() -# for token in x.terminals: -# start = State() -# NFA.epsilons.append(start) - -# if token.regex: -# node = RegexParser(token.pattern).parse_regex() -# print(f" Parsed {token.pattern} to {node}") -# ending = node.to_nfa(start) - -# else: -# ending = start -# for c in token.pattern: -# ending = ending.add_edge(Span.from_str(c), State()) - -# ending.accept.append(token) - -# NFA.dump_graph() - -# # Convert the NFA into a DFA in the most straightforward way (by tracking -# # sets of state closures, called SuperStates.) 
-# DFA: dict[SuperState, list[tuple[Span, SuperState]]] = {} -# stack = [SuperState([NFA])] -# while len(stack) > 0: -# ss = stack.pop() -# if ss in DFA: -# continue - -# edges = ss.edges() - -# DFA[ss] = edges -# for _, target in edges: -# stack.append(target) - -# for i, k in enumerate(DFA): -# k.index = i - -# return [ -# ( -# ss.accept_terminal(), -# [(k, v.index) for k, v in edges], -# ) -# for ss, edges in DFA.items() -# ] - - -# def dump_lexer_table(table: LexerTable): -# with open("lexer.dot", "w", encoding="utf-8") as f: -# f.write("digraph G {\n") -# for index, (accept, edges) in enumerate(table): -# label = accept.value if accept is not None else "" -# f.write(f' {index} [label="{label}"];\n') -# for span, target in edges: -# label = str(span).replace('"', '\\"') -# f.write(f' {index} -> {target} [label="{label}"];\n') - -# pass -# f.write("}\n") - - -# def generic_tokenize(src: str, table: LexerTable): -# pos = 0 -# state = 0 -# start = 0 -# last_accept = None -# last_accept_pos = 0 - -# while pos < len(src): -# accept, edges = table[state] -# if accept is not None: -# last_accept = accept -# last_accept_pos = pos + 1 - -# char = ord(src[pos]) - -# # Find the index of the span where the upper value is the tightest -# # bound on the character. -# index = bisect.bisect_left(edges, char, key=lambda x: x[0].upper) -# # If the character is greater than or equal to the lower bound we -# # found then we have a hit, otherwise no. 
-# state = edges[index][1] if index < len(edges) and char >= edges[index][0].lower else None -# if state is None: -# if last_accept is None: -# raise Exception(f"Token error at {pos}") - -# yield (last_accept, start, last_accept_pos - start) - -# last_accept = None -# pos = last_accept_pos -# start = pos -# state = 0 - -# else: -# pos += 1 +from parser.runtime import generic_tokenize def test_span_intersection(): @@ -450,3 +33,352 @@ def test_span_intersection(): right = Span(*b) assert left.intersects(right) assert right.intersects(left) + + +def test_span_no_intersection(): + pairs = [ + ((1, 2), (3, 4)), + ] + + for a, b in pairs: + left = Span(*a) + right = Span(*b) + assert not left.intersects(right) + assert not right.intersects(left) + + +def test_span_split(): + TC = collections.namedtuple("TC", ["left", "right", "expected"]) + cases = [ + TC( + left=Span(1, 4), + right=Span(2, 3), + expected=(Span(1, 2), Span(2, 3), Span(3, 4)), + ), + TC( + left=Span(1, 4), + right=Span(1, 2), + expected=(None, Span(1, 2), Span(2, 4)), + ), + TC( + left=Span(1, 4), + right=Span(3, 4), + expected=(Span(1, 3), Span(3, 4), None), + ), + TC( + left=Span(1, 4), + right=Span(1, 4), + expected=(None, Span(1, 4), None), + ), + ] + + for left, right, expected in cases: + result = left.split(right) + assert result == expected + + result = right.split(left) + assert result == expected + + +@given(integers(), integers()) +def test_equal_span_mid_only(x, y): + """Splitting spans against themselves results in an empty lo and hi bound.""" + assume(x < y) + span = Span(x, y) + lo, mid, hi = span.split(span) + assert lo is None + assert hi is None + assert mid == span + + +three_distinct_points = lists( + integers(), + min_size=3, + max_size=3, + unique=True, +).map(sorted) + + +@given(three_distinct_points) +def test_span_low_align_lo_none(vals): + """Splitting spans with aligned lower bounds results in an empty lo bound.""" + # x y z + # [ a ) + # [ b ) + x, y, z = vals + + a = Span(x, 
 y)
+    b = Span(x, z)
+    lo, _, _ = a.split(b)
+
+    assert lo is None
+
+
+@given(three_distinct_points)
+def test_span_high_align_hi_none(vals):
+    """Splitting spans with aligned upper bounds results in an empty hi bound."""
+    # x y z
+    # [ a )
+    # [ b )
+    x, y, z = vals
+
+    a = Span(y, z)
+    b = Span(x, z)
+    _, _, hi = a.split(b)
+
+    assert hi is None
+
+
+four_distinct_points = lists(
+    integers(),
+    min_size=4,
+    max_size=4,
+    unique=True,
+).map(sorted)
+
+
+@given(four_distinct_points)
+def test_span_split_overlapping_lo_left(vals):
+    """Splitting two overlapping spans results in lo overlapping left."""
+    a, b, c, d = vals
+
+    left = Span(a, c)
+    right = Span(b, d)
+
+    lo, _, _ = left.split(right)
+    assert lo is not None
+    assert lo.intersects(left)
+
+
+@given(four_distinct_points)
+def test_span_split_overlapping_lo_not_right(vals):
+    """Splitting two overlapping spans results in lo NOT overlapping right."""
+    a, b, c, d = vals
+
+    left = Span(a, c)
+    right = Span(b, d)
+
+    lo, _, _ = left.split(right)
+    assert lo is not None
+    assert not lo.intersects(right)
+
+
+@given(four_distinct_points)
+def test_span_split_overlapping_mid_left(vals):
+    """Splitting two overlapping spans results in mid overlapping left."""
+    a, b, c, d = vals
+
+    left = Span(a, c)
+    right = Span(b, d)
+
+    _, mid, _ = left.split(right)
+    assert mid is not None
+    assert mid.intersects(left)
+
+
+@given(four_distinct_points)
+def test_span_split_overlapping_mid_right(vals):
+    """Splitting two overlapping spans results in mid overlapping right."""
+    a, b, c, d = vals
+
+    left = Span(a, c)
+    right = Span(b, d)
+
+    _, mid, _ = left.split(right)
+    assert mid is not None
+    assert mid.intersects(right)
+
+
+@given(four_distinct_points)
+def test_span_split_overlapping_hi_right(vals):
+    """Splitting two overlapping spans results in hi overlapping right."""
+    a, b, c, d = vals
+
+    left = Span(a, c)
+    right = Span(b, d)
+
+    _, _, hi = left.split(right)
+    assert hi is not None
+    assert 
hi.intersects(right) + + +@given(four_distinct_points) +def test_span_split_overlapping_hi_not_left(vals): + """Splitting two overlapping spans results in hi NOT overlapping left.""" + a, b, c, d = vals + + left = Span(a, c) + right = Span(b, d) + + _, _, hi = left.split(right) + assert hi is not None + assert not hi.intersects(left) + + +@given(four_distinct_points) +def test_span_split_embedded(vals): + """Splitting two spans where one overlaps the other.""" + a, b, c, d = vals + + outer = Span(a, d) + inner = Span(b, c) + + lo, mid, hi = outer.split(inner) + + assert lo is not None + assert mid is not None + assert hi is not None + + assert lo.intersects(outer) + assert not lo.intersects(inner) + + assert mid.intersects(outer) + assert mid.intersects(inner) + + assert hi.intersects(outer) + assert not hi.intersects(inner) + + +def test_edge_list_single(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(1, 4), "A") + + edges = list(el) + assert edges == [ + (Span(1, 4), ["A"]), + ] + + +def test_edge_list_fully_enclosed(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(1, 4), "A") + el.add_edge(Span(2, 3), "B") + + edges = list(el) + assert edges == [ + (Span(1, 2), ["A"]), + (Span(2, 3), ["A", "B"]), + (Span(3, 4), ["A"]), + ] + + +def test_edge_list_overlap(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(1, 4), "A") + el.add_edge(Span(2, 5), "B") + + edges = list(el) + assert edges == [ + (Span(1, 2), ["A"]), + (Span(2, 4), ["A", "B"]), + (Span(4, 5), ["B"]), + ] + + +def test_edge_list_no_overlap(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(1, 4), "A") + el.add_edge(Span(5, 8), "B") + + edges = list(el) + assert edges == [ + (Span(1, 4), ["A"]), + (Span(5, 8), ["B"]), + ] + + +def test_edge_list_no_overlap_ordered(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(5, 8), "B") + el.add_edge(Span(1, 4), "A") + + edges = list(el) + assert edges == [ + (Span(1, 4), ["A"]), + (Span(5, 8), ["B"]), + ] + + +def 
test_edge_list_overlap_span(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(1, 3), "A") + el.add_edge(Span(4, 6), "B") + el.add_edge(Span(2, 5), "C") + + edges = list(el) + assert edges == [ + (Span(1, 2), ["A"]), + (Span(2, 3), ["A", "C"]), + (Span(3, 4), ["C"]), + (Span(4, 5), ["B", "C"]), + (Span(5, 6), ["B"]), + ] + + +def test_edge_list_overlap_span_big(): + el: EdgeList[str] = EdgeList() + el.add_edge(Span(2, 3), "A") + el.add_edge(Span(4, 5), "B") + el.add_edge(Span(6, 7), "C") + el.add_edge(Span(1, 8), "D") + + edges = list(el) + assert edges == [ + (Span(1, 2), ["D"]), + (Span(2, 3), ["A", "D"]), + (Span(3, 4), ["D"]), + (Span(4, 5), ["B", "D"]), + (Span(5, 6), ["D"]), + (Span(6, 7), ["C", "D"]), + (Span(7, 8), ["D"]), + ] + + +@given(lists(lists(integers(), min_size=2, max_size=2, unique=True), min_size=1)) +@example(points=[[0, 1], [1, 2]]) +def test_edge_list_always_sorted(points: list[tuple[int, int]]): + # OK this is weird but stick with me. + el: EdgeList[str] = EdgeList() + for i, (a, b) in enumerate(points): + lower = min(a, b) + upper = max(a, b) + + span = Span(lower, upper) + + el.add_edge(span, str(i)) + + last_upper = None + for span, _ in el: + if last_upper is not None: + assert last_upper <= span.lower, "Edges from list are not sorted" + last_upper = span.upper + + +def test_lexer_compile(): + class LexTest(Grammar): + @rule + def foo(self): + return self.IS + + start = foo + + IS = Terminal("is") + AS = Terminal("as") + IDENTIFIER = Terminal( + Re.seq( + Re.set(("a", "z"), ("A", "Z"), "_"), + Re.set(("a", "z"), ("A", "Z"), ("0", "9"), "_").star(), + ) + ) + BLANKS = Terminal(Re.set("\r", "\n", "\t", " ").plus()) + + lexer = compile_lexer(LexTest()) + dump_lexer_table(lexer) + tokens = list(generic_tokenize("xy is ass", lexer)) + assert tokens == [ + (LexTest.IDENTIFIER, 0, 2), + (LexTest.BLANKS, 2, 1), + (LexTest.IS, 3, 2), + (LexTest.BLANKS, 5, 1), + (LexTest.IDENTIFIER, 6, 3), + ]