Transparent rules

Better parsing/action types
Good grief
This commit is contained in:
John Doty 2024-05-29 09:07:19 -07:00
parent 4f8aef3f89
commit 45a9303a27
3 changed files with 198 additions and 102 deletions

View file

@ -78,11 +78,11 @@ class FineGrammar(Grammar):
@rule @rule
def file(self): def file(self):
return self.file_statement_list return self._file_statement_list
@rule @rule
def file_statement_list(self): def _file_statement_list(self):
return self.file_statement | (self.file_statement_list + self.file_statement) return self.file_statement | (self._file_statement_list + self.file_statement)
@rule @rule
def file_statement(self): def file_statement(self):

View file

@ -1,4 +1,5 @@
import bisect import bisect
from dataclasses import dataclass
import enum import enum
import select import select
import sys import sys
@ -22,7 +23,13 @@ def trace_state(stack, input, input_index, action):
) )
def parse(table, tokens, trace=None): @dataclass
class Tree:
name: str | None
children: typing.Tuple["Tree | str", ...]
def parse(table: parser.ParseTable, tokens, trace=None) -> typing.Tuple[Tree | None, list[str]]:
"""Parse the input with the generated parsing table and return the """Parse the input with the generated parsing table and return the
concrete syntax tree. concrete syntax tree.
@ -36,7 +43,7 @@ def parse(table, tokens, trace=None):
This is not a *great* parser, it's really just a demo for what you can This is not a *great* parser, it's really just a demo for what you can
do with the table. do with the table.
""" """
input = [t.value for (t, _, _) in tokens.tokens] input: list[str] = [t.value for (t, _, _) in tokens.tokens]
assert "$" not in input assert "$" not in input
input = input + ["$"] input = input + ["$"]
@ -45,38 +52,50 @@ def parse(table, tokens, trace=None):
# Our stack is a stack of tuples, where the first entry is the state number # Our stack is a stack of tuples, where the first entry is the state number
# and the second entry is the 'value' that was generated when the state was # and the second entry is the 'value' that was generated when the state was
# pushed. # pushed.
stack: list[typing.Tuple[int, typing.Any]] = [(0, None)] stack: list[typing.Tuple[int, str | Tree | None]] = [(0, None)]
while True: while True:
current_state = stack[-1][0] current_state = stack[-1][0]
current_token = input[input_index] current_token = input[input_index]
action = table[current_state].get(current_token, ("error",)) action = table.states[current_state].get(current_token, parser.Error())
if trace: if trace:
trace(stack, input, input_index, action) trace(stack, input, input_index, action)
if action[0] == "accept": match action:
return (stack[-1][1], []) case parser.Accept():
result = stack[-1][1]
assert isinstance(result, Tree)
return (result, [])
elif action[0] == "reduce": case parser.Reduce(name=name, count=size, transparent=transparent):
name = action[1] children: list[str | Tree] = []
size = action[2] for _, c in stack[-size:]:
if c is None:
continue
elif isinstance(c, Tree) and c.name is None:
children.extend(c.children)
else:
children.append(c)
value = (name, tuple(s[1] for s in stack[-size:])) value = Tree(name=name if not transparent else None, children=tuple(children))
stack = stack[:-size] stack = stack[:-size]
goto = table[stack[-1][0]].get(name, ("error",)) goto = table.states[stack[-1][0]].get(name, parser.Error())
assert goto[0] == "goto" # Corrupt table? assert isinstance(goto, parser.Goto)
stack.append((goto[1], value)) stack.append((goto.state, value))
elif action[0] == "shift": case parser.Shift(state):
stack.append((action[1], (current_token, ()))) stack.append((state, current_token))
input_index += 1 input_index += 1
elif action[0] == "error": case parser.Error():
if input_index >= len(tokens.tokens): if input_index >= len(tokens.tokens):
raise ValueError("Unexpected end of file") message = "Unexpected end of file"
start = tokens.tokens[-1][1]
else: else:
message = f"Syntax error: unexpected symbol {current_token}"
(_, start, _) = tokens.tokens[input_index] (_, start, _) = tokens.tokens[input_index]
line_index = bisect.bisect_left(tokens.lines, start) line_index = bisect.bisect_left(tokens.lines, start)
if line_index == 0: if line_index == 0:
col_start = 0 col_start = 0
@ -85,12 +104,11 @@ def parse(table, tokens, trace=None):
column_index = start - col_start column_index = start - col_start
line_index += 1 line_index += 1
return ( error = f"{line_index}:{column_index}: {message}"
None, return (None, [error])
[
f"{line_index}:{column_index}: Syntax error: unexpected symbol {current_token}" case _:
], raise ValueError(f"Unknown action type: {action}")
)
# https://en.wikipedia.org/wiki/ANSI_escape_code # https://en.wikipedia.org/wiki/ANSI_escape_code
@ -138,6 +156,8 @@ def leave_alt_screen():
class Harness: class Harness:
source: str | None source: str | None
table: parser.ParseTable | None
tree: Tree | None
def __init__(self, lexer_func, grammar_func, start_rule, source_path): def __init__(self, lexer_func, grammar_func, start_rule, source_path):
# self.generator = parser.GenerateLR1 # self.generator = parser.GenerateLR1
@ -168,6 +188,7 @@ class Harness:
self.table = self.grammar_func().build_table( self.table = self.grammar_func().build_table(
start=self.start_rule, generator=self.generator start=self.start_rule, generator=self.generator
) )
assert self.table is not None
if self.tokens is None: if self.tokens is None:
with open(self.source_path, "r", encoding="utf-8") as f: with open(self.source_path, "r", encoding="utf-8") as f:
@ -184,9 +205,10 @@ class Harness:
sys.stdout.buffer.write(CLEAR) sys.stdout.buffer.write(CLEAR)
rows, cols = termios.tcgetwinsize(sys.stdout.fileno()) rows, cols = termios.tcgetwinsize(sys.stdout.fileno())
average_entries = sum(len(row) for row in self.table) / len(self.table) states = self.table.states
max_entries = max(len(row) for row in self.table) average_entries = sum(len(row) for row in states) / len(states)
print(f"{len(self.table)} states - {average_entries} average, {max_entries} max\r") max_entries = max(len(row) for row in states)
print(f"{len(states)} states - {average_entries} average, {max_entries} max\r")
if self.tree is not None: if self.tree is not None:
lines = [] lines = []
@ -197,11 +219,15 @@ class Harness:
sys.stdout.flush() sys.stdout.flush()
sys.stdout.buffer.flush() sys.stdout.buffer.flush()
def format_node(self, lines, node, indent=0): def format_node(self, lines, node: Tree | str, indent=0):
"""Print out an indented concrete syntax tree, from parse().""" """Print out an indented concrete syntax tree, from parse()."""
lines.append((" " * indent) + node[0]) match node:
for child in node[1]: case Tree(name, children):
lines.append((" " * indent) + (name or "???"))
for child in children:
self.format_node(lines, child, indent + 2) self.format_node(lines, child, indent + 2)
case _:
lines.append((" " * indent) + str(node))
if __name__ == "__main__": if __name__ == "__main__":

176
parser.py
View file

@ -393,13 +393,45 @@ class Assoc(enum.Enum):
RIGHT = 2 RIGHT = 2
@dataclasses.dataclass
class Action:
pass
@dataclasses.dataclass
class Reduce(Action):
name: str
count: int
transparent: bool
@dataclasses.dataclass
class Shift(Action):
state: int
@dataclasses.dataclass
class Goto(Action):
state: int
@dataclasses.dataclass
class Accept(Action):
pass
@dataclasses.dataclass
class Error(Action):
pass
class ErrorCollection: class ErrorCollection:
"""A collection of errors. The errors are grouped by config set and alphabet """A collection of errors. The errors are grouped by config set and alphabet
symbol, so that we can group the error strings appropriately when we format symbol, so that we can group the error strings appropriately when we format
the error. the error.
""" """
errors: dict[ConfigSet, dict[int, dict[Configuration, typing.Tuple]]] errors: dict[ConfigSet, dict[int, dict[Configuration, Action]]]
def __init__(self): def __init__(self):
self.errors = {} self.errors = {}
@ -413,7 +445,7 @@ class ErrorCollection:
config_set: ConfigSet, config_set: ConfigSet,
symbol: int, symbol: int,
config: Configuration, config: Configuration,
action: typing.Tuple, action: Action,
): ):
"""Add an error to the collection. """Add an error to the collection.
@ -470,14 +502,16 @@ class ErrorCollection:
if config.next is None: if config.next is None:
rule += " *" rule += " *"
if action[0] == "reduce": match action:
action_str = f"pop {action[2]} values off the stack and make a {action[1]}" case Reduce(name=name, count=count, transparent=transparent):
elif action[0] == "shift": name_str = name if not transparent else "transparent node"
action_str = f"pop {count} values off the stack and make a {name_str}"
case Shift():
action_str = "consume the token and keep going" action_str = "consume the token and keep going"
elif action[0] == "accept": case Accept():
action_str = "accept the parse" action_str = "accept the parse"
else: case _:
assert action[0] == "goto", f"Unknown action {action[0]}" assert isinstance(action, Goto)
raise Exception("Shouldn't conflict on goto ever") raise Exception("Shouldn't conflict on goto ever")
lines.append( lines.append(
@ -489,6 +523,11 @@ class ErrorCollection:
return "\n\n".join(errors) return "\n\n".join(errors)
@dataclasses.dataclass
class ParseTable:
states: list[dict[str, Action]]
class TableBuilder(object): class TableBuilder(object):
"""A helper object to assemble actions into build parse tables. """A helper object to assemble actions into build parse tables.
@ -497,23 +536,27 @@ class TableBuilder(object):
""" """
errors: ErrorCollection errors: ErrorCollection
table: list[dict[str, typing.Tuple]] table: list[dict[str, Action]]
alphabet: list[str] alphabet: list[str]
precedence: typing.Tuple[typing.Tuple[Assoc, int], ...] precedence: typing.Tuple[typing.Tuple[Assoc, int], ...]
row: None | list[typing.Tuple[None | typing.Tuple, None | Configuration]] transparents: set[str]
row: None | list[typing.Tuple[None | Action, None | Configuration]]
def __init__( def __init__(
self, self,
alphabet: list[str], alphabet: list[str],
precedence: typing.Tuple[typing.Tuple[Assoc, int], ...], precedence: typing.Tuple[typing.Tuple[Assoc, int], ...],
transparents: set[str],
): ):
self.errors = ErrorCollection() self.errors = ErrorCollection()
self.table = [] self.table = []
self.alphabet = alphabet self.alphabet = alphabet
self.precedence = precedence self.precedence = precedence
self.transparents = transparents
self.row = None self.row = None
def flush(self, all_sets: ConfigurationSetInfo) -> list[dict[str, typing.Tuple]]: def flush(self, all_sets: ConfigurationSetInfo) -> ParseTable:
"""Finish building the table and return it. """Finish building the table and return it.
Raises ValueError if there were any conflicts during construction. Raises ValueError if there were any conflicts during construction.
@ -522,7 +565,7 @@ class TableBuilder(object):
if self.errors.any(): if self.errors.any():
errors = self.errors.format(self.alphabet, all_sets) errors = self.errors.format(self.alphabet, all_sets)
raise ValueError(f"Errors building the table:\n\n{errors}") raise ValueError(f"Errors building the table:\n\n{errors}")
return self.table return ParseTable(states=self.table)
def new_row(self, config_set: ConfigSet): def new_row(self, config_set: ConfigSet):
"""Start a new row, processing the given config set. Call this before """Start a new row, processing the given config set. Call this before
@ -541,36 +584,35 @@ class TableBuilder(object):
"""Mark a reduce of the given configuration for the given symbol in the """Mark a reduce of the given configuration for the given symbol in the
current row. current row.
""" """
action = ("reduce", self.alphabet[config.name], len(config.symbols)) name = self.alphabet[config.name]
transparent = name in self.transparents
action = Reduce(name, len(config.symbols), transparent)
self._set_table_action(symbol, action, config) self._set_table_action(symbol, action, config)
def set_table_accept(self, symbol: int, config: Configuration): def set_table_accept(self, symbol: int, config: Configuration):
"""Mark a accept of the given configuration for the given symbol in the """Mark a accept of the given configuration for the given symbol in the
current row. current row.
""" """
action = ("accept",) self._set_table_action(symbol, Accept(), config)
self._set_table_action(symbol, action, config)
def set_table_shift(self, symbol: int, index: int, config: Configuration): def set_table_shift(self, symbol: int, index: int, config: Configuration):
"""Mark a shift in the current row of the given given symbol to the """Mark a shift in the current row of the given given symbol to the
given index. The configuration here provides debugging informtion for given index. The configuration here provides debugging informtion for
conflicts. conflicts.
""" """
action = ("shift", index) self._set_table_action(symbol, Shift(index), config)
self._set_table_action(symbol, action, config)
def set_table_goto(self, symbol: int, index: int): def set_table_goto(self, symbol: int, index: int):
"""Set the goto for the given nonterminal symbol in the current row.""" """Set the goto for the given nonterminal symbol in the current row."""
action = ("goto", index) self._set_table_action(symbol, Goto(index), None)
self._set_table_action(symbol, action, None)
def _action_precedence(self, symbol: int, action: typing.Tuple, config: Configuration): def _action_precedence(self, symbol: int, action: Action, config: Configuration):
if action[0] == "shift": if isinstance(action, Shift):
return self.precedence[symbol] return self.precedence[symbol]
else: else:
return self.precedence[config.name] return self.precedence[config.name]
def _set_table_action(self, symbol_id: int, action: typing.Tuple, config: Configuration | None): def _set_table_action(self, symbol_id: int, action: Action, config: Configuration | None):
"""Set the action for 'symbol' in the table row to 'action'. """Set the action for 'symbol' in the table row to 'action'.
This is destructive; it changes the table. It records an error if This is destructive; it changes the table. It records an error if
@ -607,17 +649,17 @@ class TableBuilder(object):
resolved = False resolved = False
if assoc == Assoc.LEFT: if assoc == Assoc.LEFT:
# Prefer reduce over shift # Prefer reduce over shift
if action[0] == "shift" and existing[0] == "reduce": if isinstance(action, Shift) and isinstance(existing, Reduce):
action = existing action = existing
resolved = True resolved = True
elif action[0] == "reduce" and existing[0] == "shift": elif isinstance(action, Reduce) and isinstance(existing, Shift):
resolved = True resolved = True
elif assoc == Assoc.RIGHT: elif assoc == Assoc.RIGHT:
# Prefer shift over reduce # Prefer shift over reduce
if action[0] == "shift" and existing[0] == "reduce": if isinstance(action, Shift) and isinstance(existing, Reduce):
resolved = True resolved = True
elif action[0] == "reduce" and existing[0] == "shift": elif isinstance(action, Reduce) and isinstance(existing, Shift):
action = existing action = existing
resolved = True resolved = True
@ -636,7 +678,7 @@ class TableBuilder(object):
self.row[symbol_id] = (action, config) self.row[symbol_id] = (action, config)
class GenerateLR0(object): class GenerateLR0:
"""Generate parser tables for an LR0 parser.""" """Generate parser tables for an LR0 parser."""
# Internally we use integers as symbols, not strings. Mostly this is fine, # Internally we use integers as symbols, not strings. Mostly this is fine,
@ -659,6 +701,10 @@ class GenerateLR0(object):
# for a symbol, then its entry in this tuple will be (NONE, 0). # for a symbol, then its entry in this tuple will be (NONE, 0).
precedence: typing.Tuple[typing.Tuple[Assoc, int], ...] precedence: typing.Tuple[typing.Tuple[Assoc, int], ...]
# The set of symbols for which we should reduce "transparently." This doesn't
# affect state generation at all, only the generation of the final table.
transparents: set[str]
# The lookup that maps a particular symbol to an integer. (Only really used # The lookup that maps a particular symbol to an integer. (Only really used
# for debugging.) # for debugging.)
symbol_key: dict[str, int] symbol_key: dict[str, int]
@ -675,6 +721,7 @@ class GenerateLR0(object):
start: str, start: str,
grammar: list[typing.Tuple[str, list[str]]], grammar: list[typing.Tuple[str, list[str]]],
precedence: None | dict[str, typing.Tuple[Assoc, int]] = None, precedence: None | dict[str, typing.Tuple[Assoc, int]] = None,
transparents: None | set[str] = None,
): ):
"""Initialize the parser generator with the specified grammar and """Initialize the parser generator with the specified grammar and
start symbol. start symbol.
@ -777,6 +824,10 @@ class GenerateLR0(object):
precedence = {} precedence = {}
self.precedence = tuple(precedence.get(a, (Assoc.NONE, 0)) for a in self.alphabet) self.precedence = tuple(precedence.get(a, (Assoc.NONE, 0)) for a in self.alphabet)
if transparents is None:
transparents = set()
self.transparents = transparents
self.symbol_key = symbol_key self.symbol_key = symbol_key
self.start_symbol = start_symbol self.start_symbol = start_symbol
self.end_symbol = end_symbol self.end_symbol = end_symbol
@ -903,7 +954,7 @@ class GenerateLR0(object):
del config del config
return [index for index, value in enumerate(self.terminal) if value] return [index for index, value in enumerate(self.terminal) if value]
def gen_table(self): def gen_table(self) -> ParseTable:
"""Generate the parse table. """Generate the parse table.
The parse table is a list of states. The first state in the list is The parse table is a list of states. The first state in the list is
@ -932,7 +983,7 @@ class GenerateLR0(object):
Anything missing from the row indicates an error. Anything missing from the row indicates an error.
""" """
config_sets = self.gen_all_sets() config_sets = self.gen_all_sets()
builder = TableBuilder(self.alphabet, self.precedence) builder = TableBuilder(self.alphabet, self.precedence, self.transparents)
for config_set_id, config_set in enumerate(config_sets.sets): for config_set_id, config_set in enumerate(config_sets.sets):
builder.new_row(config_set) builder.new_row(config_set)
@ -959,7 +1010,7 @@ class GenerateLR0(object):
return builder.flush(config_sets) return builder.flush(config_sets)
def parse(table, input, trace=False): def parse(table: ParseTable, input, trace=False):
"""Parse the input with the generated parsing table and return the """Parse the input with the generated parsing table and return the
concrete syntax tree. concrete syntax tree.
@ -985,7 +1036,7 @@ def parse(table, input, trace=False):
current_state = stack[-1][0] current_state = stack[-1][0]
current_token = input[input_index] current_token = input[input_index]
action = table[current_state].get(current_token, ("error",)) action = table.states[current_state].get(current_token, Error())
if trace: if trace:
print( print(
"{stack: <20} {input: <50} {action: <5}".format( "{stack: <20} {input: <50} {action: <5}".format(
@ -995,25 +1046,30 @@ def parse(table, input, trace=False):
) )
) )
if action[0] == "accept": match action:
case Accept():
return stack[-1][1] return stack[-1][1]
elif action[0] == "reduce": case Reduce(name=name, count=size, transparent=transparent):
name = action[1] children = []
size = action[2] for _, c in stack[-size:]:
if isinstance(c, tuple) and c[0] is None:
children.extend(c[1])
else:
children.append(c)
value = (name, tuple(s[1] for s in stack[-size:])) value = (name if not transparent else None, tuple(children))
stack = stack[:-size] stack = stack[:-size]
goto = table[stack[-1][0]].get(name, ("error",)) goto = table.states[stack[-1][0]].get(name, Error())
assert goto[0] == "goto" # Corrupt table? assert isinstance(goto, Goto)
stack.append((goto[1], value)) stack.append((goto.state, value))
elif action[0] == "shift": case Shift(state):
stack.append((action[1], (current_token, ()))) stack.append((state, (current_token, ())))
input_index += 1 input_index += 1
elif action[0] == "error": case Error():
raise ValueError( raise ValueError(
"Syntax error: unexpected symbol {sym}".format( "Syntax error: unexpected symbol {sym}".format(
sym=current_token, sym=current_token,
@ -1539,7 +1595,16 @@ class NonTerminal(Rule):
grammar class. grammar class.
""" """
def __init__(self, fn: typing.Callable[["Grammar"], Rule], name: str | None = None): fn: typing.Callable[["Grammar"], Rule]
name: str
transparent: bool
def __init__(
self,
fn: typing.Callable[["Grammar"], Rule],
name: str | None = None,
transparent: bool = False,
):
"""Create a new NonTerminal. """Create a new NonTerminal.
`fn` is the function that will yield the `Rule` which is the `fn` is the function that will yield the `Rule` which is the
@ -1549,6 +1614,7 @@ class NonTerminal(Rule):
""" """
self.fn = fn self.fn = fn
self.name = name or fn.__name__ self.name = name or fn.__name__
self.transparent = transparent
def generate_body(self, grammar) -> list[list[str | Token]]: def generate_body(self, grammar) -> list[list[str | Token]]:
"""Generate the body of the non-terminal. """Generate the body of the non-terminal.
@ -1638,7 +1704,8 @@ def rule(f: typing.Callable) -> Rule:
of the nonterminal, which defaults to the name of the function. of the nonterminal, which defaults to the name of the function.
""" """
name = f.__name__ name = f.__name__
return NonTerminal(f, name) transparent = name.startswith("_")
return NonTerminal(f, name, transparent)
PrecedenceList = list[typing.Tuple[Assoc, list[Rule]]] PrecedenceList = list[typing.Tuple[Assoc, list[Rule]]]
@ -1689,7 +1756,9 @@ class Grammar:
self._precedence = precedence_table self._precedence = precedence_table
def generate_nonterminal_dict(self, start: str) -> dict[str, list[list[str | Token]]]: def generate_nonterminal_dict(
self, start: str
) -> typing.Tuple[dict[str, list[list[str | Token]]], set[str]]:
"""Convert the rules into a dictionary of productions. """Convert the rules into a dictionary of productions.
Our table generators work on a very flat set of productions. This is the Our table generators work on a very flat set of productions. This is the
@ -1700,6 +1769,7 @@ class Grammar:
""" """
rules = inspect.getmembers(self, lambda x: isinstance(x, NonTerminal)) rules = inspect.getmembers(self, lambda x: isinstance(x, NonTerminal))
nonterminals = {rule.name: rule for _, rule in rules} nonterminals = {rule.name: rule for _, rule in rules}
transparents = {rule.name for _, rule in rules if rule.transparent}
grammar = {} grammar = {}
@ -1724,9 +1794,9 @@ class Grammar:
grammar[rule.name] = body grammar[rule.name] = body
return grammar return (grammar, transparents)
def desugar(self, start: str) -> list[typing.Tuple[str, list[str]]]: def desugar(self, start: str) -> typing.Tuple[list[typing.Tuple[str, list[str]]], set[str]]:
"""Convert the rules into a flat list of productions. """Convert the rules into a flat list of productions.
Our table generators work from a very flat set of productions. The form Our table generators work from a very flat set of productions. The form
@ -1734,7 +1804,7 @@ class Grammar:
generate_nonterminal_dict- less useful to people, probably, but it is generate_nonterminal_dict- less useful to people, probably, but it is
the input form needed by the Generator. the input form needed by the Generator.
""" """
temp_grammar = self.generate_nonterminal_dict(start) temp_grammar, transparents = self.generate_nonterminal_dict(start)
grammar = [] grammar = []
for rule_name, clauses in temp_grammar.items(): for rule_name, clauses in temp_grammar.items():
@ -1748,15 +1818,15 @@ class Grammar:
grammar.append((rule_name, new_clause)) grammar.append((rule_name, new_clause))
return grammar return grammar, transparents
def build_table(self, start: str, generator=GenerateLALR): def build_table(self, start: str, generator=GenerateLALR):
"""Construct a parse table for this grammar, starting at the named """Construct a parse table for this grammar, starting at the named
nonterminal rule. nonterminal rule.
""" """
desugared = self.desugar(start) desugared, transparents = self.desugar(start)
gen = generator(start, desugared, precedence=self._precedence) gen = generator(start, desugared, precedence=self._precedence, transparents=transparents)
table = gen.gen_table() table = gen.gen_table()
return table return table
@ -1772,7 +1842,7 @@ def format_node(node):
return "\n".join(lines) return "\n".join(lines)
def format_table(generator, table): def format_table(generator, table: ParseTable):
"""Format a parser table so pretty.""" """Format a parser table so pretty."""
def format_action(state, terminal): def format_action(state, terminal):
@ -1806,7 +1876,7 @@ def format_table(generator, table):
), ),
gotos=" ".join("{0: <5}".format(row.get(nt, ("error", ""))[1]) for nt in nonterminals), gotos=" ".join("{0: <5}".format(row.get(nt, ("error", ""))[1]) for nt in nonterminals),
) )
for i, row in enumerate(table) for i, row in enumerate(table.states)
] ]
return "\n".join(lines) return "\n".join(lines)