faster: Type annotation and grammar -> dict

This commit is contained in:
John Doty 2024-04-14 09:10:07 -07:00
parent 7ea47075ea
commit f4ac1a4cc0

View file

@ -5,30 +5,34 @@ This version has some performance work done.
2023
"""
import dataclasses
import functools
from collections import namedtuple
import typing
###############################################################################
# LR0
#
# We start with LR0 parsers, because they form the basis of everything else.
###############################################################################
class Configuration(
namedtuple('Configuration', ['name', 'symbols', 'position', 'lookahead'])
):
@dataclasses.dataclass(frozen=True)
class Configuration:
"""A rule being tracked in a state.
(Note: technically, lookahead isn't used until we get to LR(1) parsers,
but if left at its default it's harmless. Ignore it until you get to
the part about LR(1).)
"""
__slots__ = ()
name: str
symbols: typing.Tuple[str, ...]
position: int
lookahead: typing.Tuple[str, ...]
@classmethod
def from_rule(cls, rule, lookahead=()):
def from_rule(cls, name: str, symbols: typing.Tuple[str, ...], lookahead=()):
return Configuration(
name=rule[0],
symbols=rule[1],
name=name,
symbols=symbols,
position=0,
lookahead=lookahead,
)
@ -49,7 +53,7 @@ class Configuration(
return self.next == symbol
def replace(self, **kwargs):
return self._replace(**kwargs)
return dataclasses.replace(self, **kwargs)
def __str__(self):
la = ", " + str(self.lookahead) if self.lookahead != () else ""
@ -62,6 +66,8 @@ class Configuration(
lookahead=la,
)
ConfigSet = typing.Tuple[Configuration, ...]
class TableBuilder(object):
def __init__(self):
self.errors = []
@ -107,6 +113,7 @@ class TableBuilder(object):
This is destructive; it changes the table. It raises an error if
there is already an action for the symbol in the row.
"""
assert self.row is not None
existing, existing_config = self.row.get(symbol, (None, None))
if existing is not None and existing != action:
config_old = str(existing_config)
@ -127,9 +134,6 @@ class TableBuilder(object):
self.errors.append(error)
self.row[symbol] = (action, config)
def get_table_action(self, symbol):
return self.row[symbol][0]
class GenerateLR0(object):
"""Generate parser tables for an LR0 parser.
@ -172,31 +176,39 @@ class GenerateLR0(object):
long way by memoizing results, which is much easier if we have tuples
everywhere.)
"""
def __init__(self, start, grammar):
grammar: dict[str, list[typing.Tuple[str, ...]]]
nonterminals: set[str]
terminals: set[str]
alphabet: list[str]
def __init__(self, start: str, grammar: list[typing.Tuple[str, list[str]]]):
"""Initialize the parser generator with the specified grammar and
start symbol.
"""
# We always store the "augmented" grammar, which contains an initial
# production for the start state. grammar[0] is always the start
# rule, and in the set of states and table and whatever the first
# element is always the starting state/position.
self.grammar = [('__start', [start])] + grammar
# Convert the grammar into fully immutable tuples so we can hash
# everything.
self.grammar = tuple(
(name, tuple(symbols))
for name, symbols in self.grammar
)
# Turn the incoming grammar into a dictionary, indexed by nonterminal.
#
# We count on Python dictionaries retaining the insertion order, like
# it or not.
full_grammar = {}
for name, rule in grammar:
rules = full_grammar.get(name)
if rules is None:
rules = []
full_grammar[name] = rules
rules.append(tuple(rule))
self.grammar = full_grammar
self.nonterminals = {rule[0] for rule in grammar}
self.nonterminals = set(self.grammar.keys())
self.terminals = {
sym
for name, symbols in grammar
for _, symbols in grammar
for sym in symbols
if sym not in self.nonterminals
}
self.alphabet = self.terminals | self.nonterminals
self.alphabet = list(sorted(self.terminals | self.nonterminals))
# Check to make sure they didn't use anything that will give us
# heartburn later.
@ -209,11 +221,14 @@ class GenerateLR0(object):
)
)
self.grammar['__start'] = [(start,)]
self.terminals.add('$')
self.alphabet.add('$')
self.alphabet.append('$')
@functools.cache
def gen_closure_next(self, config):
def gen_closure_next(self, config: Configuration):
"""Return the next set of configurations in the closure for
config.
@ -223,19 +238,25 @@ class GenerateLR0(object):
beginning. (If the position for config is just before a terminal,
or at the end of the production, then the next set is empty.)
"""
if config.at_end:
next = config.next
if next is None:
return ()
else:
return tuple(
Configuration.from_rule(rule)
for rule in self.grammar
if rule[0] == config.next
Configuration.from_rule(next, rule)
for rule in self.grammar.get(next, ())
)
@functools.cache
def gen_closure(self, seeds):
"""Compute the closure for the specified configs. We have replaced a
recursive version with an iterative one."""
def gen_closure(self, seeds: typing.Iterable[Configuration]) -> ConfigSet:
"""Compute the closure for the specified configs. The closure is all
of the configurations we could be in. Specifically, if the position
for a config is just before a non-terminal then we must also consider
configurations where the rule is the rule for the non-terminal and
the position is just before the beginning of the rule.
(We have replaced a recursive version with an iterative one.)
"""
closure = set()
pending = list(seeds)
while len(pending) > 0:
@ -250,7 +271,7 @@ class GenerateLR0(object):
return tuple(closure) # TODO: Why tuple?
@functools.cache
def gen_successor(self, config_set, symbol):
def gen_successor(self, config_set: typing.Iterable[Configuration], symbol: str) -> ConfigSet:
"""Compute the successor state for the given config set and the
given symbol.
@ -266,7 +287,7 @@ class GenerateLR0(object):
closure = self.gen_closure(seeds)
return closure
def gen_all_successors(self, config_set):
def gen_all_successors(self, config_set: typing.Iterable[Configuration]) -> list[ConfigSet]:
"""Return all of the non-empty successors for the given config set."""
next = []
for symbol in self.alphabet:
@ -274,12 +295,10 @@ class GenerateLR0(object):
if len(successor) > 0:
next.append(successor)
return tuple(next)
return next
def gen_sets(self, config_set):
"""Recursively generate all configuration sets starting from the
provided set, and merge them with the provided set 'F'.
"""
def gen_sets(self, config_set: typing.Tuple[Configuration,...]) -> typing.Tuple[ConfigSet, ...]:
"""Generate all configuration sets starting from the provided set."""
# NOTE: Not a set because we need to maintain insertion order!
# The first element in the dictionary needs to be the initial
# set.
@ -298,11 +317,13 @@ class GenerateLR0(object):
return tuple(F.keys())
def gen_all_sets(self):
def gen_all_sets(self) -> typing.Tuple[ConfigSet, ...]:
"""Generate all of the configuration sets for the grammar."""
initial_set = self.gen_closure(
( Configuration.from_rule(self.grammar[0]), )
seeds = tuple(
Configuration.from_rule('__start', rule)
for rule in self.grammar['__start']
)
initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set)
def find_set_index(self, sets, set):
@ -314,11 +335,12 @@ class GenerateLR0(object):
return i
return None
def gen_reduce_set(self, config):
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
"""Return the set of symbols that indicate we should reduce the given
configuration.
In an LR0 parser, this is just the set of all terminals."""
del(config)
return self.terminals
def gen_table(self):
@ -406,9 +428,6 @@ class GenerateLR0(object):
errors.append(error)
row[symbol] = (action, config)
def get_table_action(self, row, symbol):
return row[symbol][0]
def parse(table, input, trace=False):
"""Parse the input with the generated parsing table and return the
@ -428,7 +447,7 @@ def parse(table, input, trace=False):
# Our stack is a stack of tuples, where the first entry is the state number
# and the second entry is the 'value' that was generated when the state was
# pushed.
stack = [(0, None)]
stack : list[typing.Tuple[int, typing.Any]] = [(0, None)]
while True:
current_state = stack[-1][0]
current_token = input[input_index]
@ -483,11 +502,13 @@ class GenerateSLR1(GenerateLR0):
means they need to know how to generate 'first(A)', which is most of the
code in this class.
"""
_first_symbol_cache: dict[str, typing.Tuple[str|None, ...]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._first_symbol_cache = {}
def gen_first_symbol(self, symbol, visited):
def gen_first_symbol(self, symbol: str, visited: set[str]) -> typing.Tuple[str|None, ...]:
"""Compute the first set for a single symbol.
If a symbol can be empty, then the set contains epsilon, which we
@ -517,9 +538,8 @@ class GenerateSLR1(GenerateLR0):
# All the firsts from all the productions.
firsts = [
self.gen_first(rule[1], visited)
for rule in self.grammar
if rule[0] == symbol
self.gen_first(rule, visited)
for rule in self.grammar.get(symbol, ())
]
result = {f for fs in firsts for f in fs}
@ -527,7 +547,10 @@ class GenerateSLR1(GenerateLR0):
self._first_symbol_cache[symbol] = result
return result
def gen_first(self, symbols, visited=None):
# TODO: Cache
# TODO: Use sets man
# TODO: Iterative not recursive
def gen_first(self, symbols: typing.Tuple[str, ...], visited=None) -> typing.Tuple[str|None, ...]:
"""Compute the first set for a sequence of symbols.
The first set is the set of tokens that can appear as the first token
@ -557,7 +580,7 @@ class GenerateSLR1(GenerateLR0):
result = tuple(sorted(set(result), key=lambda x: (x is None, x)))
return result
def gen_follow(self, symbol, visited=None):
def gen_follow(self, symbol: str, visited=None):
"""Generate the follow set for the given nonterminal.
The follow set for a nonterminal is the set of terminals that can
@ -578,20 +601,21 @@ class GenerateSLR1(GenerateLR0):
visited.add(symbol)
follow = ()
for production in self.grammar:
for index, prod_symbol in enumerate(production[1]):
for name, rule_set in self.grammar.items():
for production in rule_set:
for index, prod_symbol in enumerate(production):
if prod_symbol != symbol:
continue
first = self.gen_first(production[1][index+1:])
first = self.gen_first(production[index+1:])
follow = follow + tuple(f for f in first if f is not None)
if None in first:
follow = follow + self.gen_follow(production[0], visited)
follow = follow + self.gen_follow(name, visited)
assert None not in follow # Should always ground out at __start
return follow
def gen_reduce_set(self, config):
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
"""Return the set of symbols that indicate we should reduce the given
config.
@ -610,14 +634,15 @@ class GenerateLR1(GenerateSLR1):
details. (Except for the start configuration, which has '$' as its
lookahead.)
"""
def gen_reduce_set(self, config):
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
"""Return the set of symbols that indicate we should reduce the given
config.
In an LR1 parser, this is the lookahead of the configuration."""
return config.lookahead
def gen_closure_next(self, config):
@functools.cache
def gen_closure_next(self, config: Configuration):
"""Return the next set of configurations in the closure for
config.
@ -632,14 +657,12 @@ class GenerateLR1(GenerateSLR1):
(See the documentation in GenerateLR0 for more information on how
this function fits into the whole process.)
"""
if config.at_end:
config_next = config.next
if config_next is None:
return ()
else:
next = []
for rule in self.grammar:
if rule[0] != config.next:
continue
for rule in self.grammar.get(config_next, ()):
# N.B.: We can't just append config.lookahead to config.rest
# and compute first(), because lookahead is a *set*. So
# in this case we just say if 'first' contains epsilon,
@ -650,7 +673,7 @@ class GenerateLR1(GenerateSLR1):
lookahead = tuple(l for l in lookahead if l is not None)
lookahead = lookahead + config.lookahead
lookahead = tuple(sorted(set(lookahead)))
next.append(Configuration.from_rule(rule, lookahead=lookahead))
next.append(Configuration.from_rule(config_next, rule, lookahead=lookahead))
return tuple(next)
@ -660,9 +683,11 @@ class GenerateLR1(GenerateSLR1):
In LR1 parsers, we must remember to set the lookahead of the start
symbol to '$'.
"""
initial_set = self.gen_closure(
( Configuration.from_rule(self.grammar[0], lookahead=('$',)), ),
seeds = tuple(
Configuration.from_rule('__start', rule, lookahead=('$',))
for rule in self.grammar['__start']
)
initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set)
@ -862,11 +887,11 @@ def examples():
print("grammar_lr0_shift_reduce (SLR1):")
gen = GenerateSLR1('E', grammar_lr0_shift_reduce)
print("First: {first}".format(first=str(gen.gen_first(['E']))))
print("First: {first}".format(first=str(gen.gen_first(('E',)))))
print("Follow: {follow}".format(follow=str(gen.gen_follow('E'))))
table = gen.gen_table()
print(format_table(gen, table))
tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')'])
tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')'], trace=True)
print(format_node(tree) + "\n")
print()