faster: Type annotation and grammar -> dict

This commit is contained in:
John Doty 2024-04-14 09:10:07 -07:00
parent 7ea47075ea
commit f4ac1a4cc0

View file

@ -5,30 +5,34 @@ This version has some performance work done.
2023 2023
""" """
import dataclasses
import functools import functools
from collections import namedtuple import typing
############################################################################### ###############################################################################
# LR0 # LR0
# #
# We start with LR0 parsers, because they form the basis of everything else. # We start with LR0 parsers, because they form the basis of everything else.
############################################################################### ###############################################################################
class Configuration( @dataclasses.dataclass(frozen=True)
namedtuple('Configuration', ['name', 'symbols', 'position', 'lookahead']) class Configuration:
):
"""A rule being tracked in a state. """A rule being tracked in a state.
(Note: technically, lookahead isn't used until we get to LR(1) parsers, (Note: technically, lookahead isn't used until we get to LR(1) parsers,
but if left at its default it's harmless. Ignore it until you get to but if left at its default it's harmless. Ignore it until you get to
the part about LR(1).) the part about LR(1).)
""" """
__slots__ = () name: str
symbols: typing.Tuple[str, ...]
position: int
lookahead: typing.Tuple[str, ...]
@classmethod @classmethod
def from_rule(cls, rule, lookahead=()): def from_rule(cls, name: str, symbols: typing.Tuple[str, ...], lookahead=()):
return Configuration( return Configuration(
name=rule[0], name=name,
symbols=rule[1], symbols=symbols,
position=0, position=0,
lookahead=lookahead, lookahead=lookahead,
) )
@ -49,7 +53,7 @@ class Configuration(
return self.next == symbol return self.next == symbol
def replace(self, **kwargs): def replace(self, **kwargs):
return self._replace(**kwargs) return dataclasses.replace(self, **kwargs)
def __str__(self): def __str__(self):
la = ", " + str(self.lookahead) if self.lookahead != () else "" la = ", " + str(self.lookahead) if self.lookahead != () else ""
@ -62,6 +66,8 @@ class Configuration(
lookahead=la, lookahead=la,
) )
ConfigSet = typing.Tuple[Configuration, ...]
class TableBuilder(object): class TableBuilder(object):
def __init__(self): def __init__(self):
self.errors = [] self.errors = []
@ -107,6 +113,7 @@ class TableBuilder(object):
This is destructive; it changes the table. It raises an error if This is destructive; it changes the table. It raises an error if
there is already an action for the symbol in the row. there is already an action for the symbol in the row.
""" """
assert self.row is not None
existing, existing_config = self.row.get(symbol, (None, None)) existing, existing_config = self.row.get(symbol, (None, None))
if existing is not None and existing != action: if existing is not None and existing != action:
config_old = str(existing_config) config_old = str(existing_config)
@ -127,9 +134,6 @@ class TableBuilder(object):
self.errors.append(error) self.errors.append(error)
self.row[symbol] = (action, config) self.row[symbol] = (action, config)
def get_table_action(self, symbol):
return self.row[symbol][0]
class GenerateLR0(object): class GenerateLR0(object):
"""Generate parser tables for an LR0 parser. """Generate parser tables for an LR0 parser.
@ -172,31 +176,39 @@ class GenerateLR0(object):
long way by memoizing results, which is much easier if we have tuples long way by memoizing results, which is much easier if we have tuples
everywhere.) everywhere.)
""" """
def __init__(self, start, grammar):
grammar: dict[str, list[typing.Tuple[str, ...]]]
nonterminals: set[str]
terminals: set[str]
alphabet: list[str]
def __init__(self, start: str, grammar: list[typing.Tuple[str, list[str]]]):
"""Initialize the parser generator with the specified grammar and """Initialize the parser generator with the specified grammar and
start symbol. start symbol.
""" """
# We always store the "augmented" grammar, which contains an initial
# production for the start state. grammar[0] is always the start
# rule, and in the set of states and table and whatever the first
# element is always the starting state/position.
self.grammar = [('__start', [start])] + grammar
# Convert the grammar into fully immutable tuples so we can hash # Turn the incoming grammar into a dictionary, indexed by nonterminal.
# everything. #
self.grammar = tuple( # We count on python dictionaries retaining the insertion order, like
(name, tuple(symbols)) # it or not.
for name, symbols in self.grammar full_grammar = {}
) for name, rule in grammar:
rules = full_grammar.get(name)
if rules is None:
rules = []
full_grammar[name] = rules
rules.append(tuple(rule))
self.grammar = full_grammar
self.nonterminals = {rule[0] for rule in grammar}
self.nonterminals = set(self.grammar.keys())
self.terminals = { self.terminals = {
sym sym
for name, symbols in grammar for _, symbols in grammar
for sym in symbols for sym in symbols
if sym not in self.nonterminals if sym not in self.nonterminals
} }
self.alphabet = self.terminals | self.nonterminals self.alphabet = list(sorted(self.terminals | self.nonterminals))
# Check to make sure they didn't use anything that will give us # Check to make sure they didn't use anything that will give us
# heartburn later. # heartburn later.
@ -209,11 +221,14 @@ class GenerateLR0(object):
) )
) )
self.grammar['__start'] = [(start,)]
self.terminals.add('$') self.terminals.add('$')
self.alphabet.add('$') self.alphabet.append('$')
@functools.cache @functools.cache
def gen_closure_next(self, config): def gen_closure_next(self, config: Configuration):
"""Return the next set of configurations in the closure for """Return the next set of configurations in the closure for
config. config.
@ -223,19 +238,25 @@ class GenerateLR0(object):
beginning. (If the position for config is just before a terminal, beginning. (If the position for config is just before a terminal,
or at the end of the production, then the next set is empty.) or at the end of the production, then the next set is empty.)
""" """
if config.at_end: next = config.next
if next is None:
return () return ()
else: else:
return tuple( return tuple(
Configuration.from_rule(rule) Configuration.from_rule(next, rule)
for rule in self.grammar for rule in self.grammar.get(next, ())
if rule[0] == config.next
) )
@functools.cache @functools.cache
def gen_closure(self, seeds): def gen_closure(self, seeds: typing.Iterable[Configuration]) -> ConfigSet:
"""Compute the closure for the specified configs. We have replaced a """Compute the closure for the specified configs. The closure is all
recursive version with an iterative one.""" of the configurations we could be in. Specifically, if the position
for a config is just before a non-terminal then we must also consider
configurations where the rule is the rule for the non-terminal and
the position is just before the beginning of the rule.
(We have replaced a recursive version with an iterative one.)
"""
closure = set() closure = set()
pending = list(seeds) pending = list(seeds)
while len(pending) > 0: while len(pending) > 0:
@ -250,7 +271,7 @@ class GenerateLR0(object):
return tuple(closure) # TODO: Why tuple? return tuple(closure) # TODO: Why tuple?
@functools.cache @functools.cache
def gen_successor(self, config_set, symbol): def gen_successor(self, config_set: typing.Iterable[Configuration], symbol: str) -> ConfigSet:
"""Compute the successor state for the given config set and the """Compute the successor state for the given config set and the
given symbol. given symbol.
@ -266,7 +287,7 @@ class GenerateLR0(object):
closure = self.gen_closure(seeds) closure = self.gen_closure(seeds)
return closure return closure
def gen_all_successors(self, config_set): def gen_all_successors(self, config_set: typing.Iterable[Configuration]) -> list[ConfigSet]:
"""Return all of the non-empty successors for the given config set.""" """Return all of the non-empty successors for the given config set."""
next = [] next = []
for symbol in self.alphabet: for symbol in self.alphabet:
@ -274,12 +295,10 @@ class GenerateLR0(object):
if len(successor) > 0: if len(successor) > 0:
next.append(successor) next.append(successor)
return tuple(next) return next
def gen_sets(self, config_set): def gen_sets(self, config_set: typing.Tuple[Configuration,...]) -> typing.Tuple[ConfigSet, ...]:
"""Recursively generate all configuration sets starting from the """Generate all configuration sets starting from the provided set."""
provided set, and merge them with the provided set 'F'.
"""
# NOTE: Not a set because we need to maintain insertion order! # NOTE: Not a set because we need to maintain insertion order!
# The first element in the dictionary needs to be the intial # The first element in the dictionary needs to be the intial
# set. # set.
@ -298,11 +317,13 @@ class GenerateLR0(object):
return tuple(F.keys()) return tuple(F.keys())
def gen_all_sets(self): def gen_all_sets(self) -> typing.Tuple[ConfigSet, ...]:
"""Generate all of the configuration sets for the grammar.""" """Generate all of the configuration sets for the grammar."""
initial_set = self.gen_closure( seeds = tuple(
( Configuration.from_rule(self.grammar[0]), ) Configuration.from_rule('__start', rule)
for rule in self.grammar['__start']
) )
initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set) return self.gen_sets(initial_set)
def find_set_index(self, sets, set): def find_set_index(self, sets, set):
@ -314,11 +335,12 @@ class GenerateLR0(object):
return i return i
return None return None
def gen_reduce_set(self, config): def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
"""Return the set of symbols that indicate we should reduce the given """Return the set of symbols that indicate we should reduce the given
configuration. configuration.
In an LR0 parser, this is just the set of all terminals.""" In an LR0 parser, this is just the set of all terminals."""
del(config)
return self.terminals return self.terminals
def gen_table(self): def gen_table(self):
@ -406,9 +428,6 @@ class GenerateLR0(object):
errors.append(error) errors.append(error)
row[symbol] = (action, config) row[symbol] = (action, config)
def get_table_action(self, row, symbol):
return row[symbol][0]
def parse(table, input, trace=False): def parse(table, input, trace=False):
"""Parse the input with the generated parsing table and return the """Parse the input with the generated parsing table and return the
@ -428,7 +447,7 @@ def parse(table, input, trace=False):
# Our stack is a stack of tuples, where the first entry is the state number # Our stack is a stack of tuples, where the first entry is the state number
# and the second entry is the 'value' that was generated when the state was # and the second entry is the 'value' that was generated when the state was
# pushed. # pushed.
stack = [(0, None)] stack : list[typing.Tuple[int, typing.Any]] = [(0, None)]
while True: while True:
current_state = stack[-1][0] current_state = stack[-1][0]
current_token = input[input_index] current_token = input[input_index]
@ -483,11 +502,13 @@ class GenerateSLR1(GenerateLR0):
means they need to know how to generate 'first(A)', which is most of the means they need to know how to generate 'first(A)', which is most of the
code in this class. code in this class.
""" """
_first_symbol_cache: dict[str, typing.Tuple[str|None, ...]]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._first_symbol_cache = {} self._first_symbol_cache = {}
def gen_first_symbol(self, symbol, visited): def gen_first_symbol(self, symbol: str, visited: set[str]) -> typing.Tuple[str|None, ...]:
"""Compute the first set for a single symbol. """Compute the first set for a single symbol.
If a symbol can be empty, then the set contains epsilon, which we If a symbol can be empty, then the set contains epsilon, which we
@ -517,9 +538,8 @@ class GenerateSLR1(GenerateLR0):
# All the firsts from all the productions. # All the firsts from all the productions.
firsts = [ firsts = [
self.gen_first(rule[1], visited) self.gen_first(rule, visited)
for rule in self.grammar for rule in self.grammar.get(symbol, ())
if rule[0] == symbol
] ]
result = {f for fs in firsts for f in fs} result = {f for fs in firsts for f in fs}
@ -527,7 +547,10 @@ class GenerateSLR1(GenerateLR0):
self._first_symbol_cache[symbol] = result self._first_symbol_cache[symbol] = result
return result return result
def gen_first(self, symbols, visited=None): # TODO: Cache
# TODO: Use sets man
# TODO: Iterative not recursive
def gen_first(self, symbols: typing.Tuple[str, ...], visited=None) -> typing.Tuple[str|None, ...]:
"""Compute the first set for a sequence of symbols. """Compute the first set for a sequence of symbols.
The first set is the set of tokens that can appear as the first token The first set is the set of tokens that can appear as the first token
@ -557,7 +580,7 @@ class GenerateSLR1(GenerateLR0):
result = tuple(sorted(set(result), key=lambda x: (x is None, x))) result = tuple(sorted(set(result), key=lambda x: (x is None, x)))
return result return result
def gen_follow(self, symbol, visited=None): def gen_follow(self, symbol: str, visited=None):
"""Generate the follow set for the given nonterminal. """Generate the follow set for the given nonterminal.
The follow set for a nonterminal is the set of terminals that can The follow set for a nonterminal is the set of terminals that can
@ -578,20 +601,21 @@ class GenerateSLR1(GenerateLR0):
visited.add(symbol) visited.add(symbol)
follow = () follow = ()
for production in self.grammar: for name, rule_set in self.grammar.items():
for index, prod_symbol in enumerate(production[1]): for production in rule_set:
if prod_symbol != symbol: for index, prod_symbol in enumerate(production):
continue if prod_symbol != symbol:
continue
first = self.gen_first(production[1][index+1:]) first = self.gen_first(production[index+1:])
follow = follow + tuple(f for f in first if f is not None) follow = follow + tuple(f for f in first if f is not None)
if None in first: if None in first:
follow = follow + self.gen_follow(production[0], visited) follow = follow + self.gen_follow(name, visited)
assert None not in follow # Should always ground out at __start assert None not in follow # Should always ground out at __start
return follow return follow
def gen_reduce_set(self, config): def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
"""Return the set of symbols that indicate we should reduce the given """Return the set of symbols that indicate we should reduce the given
config. config.
@ -610,14 +634,15 @@ class GenerateLR1(GenerateSLR1):
details. (Except for the start configuration, which has '$' as its details. (Except for the start configuration, which has '$' as its
lookahead.) lookahead.)
""" """
def gen_reduce_set(self, config): def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
"""Return the set of symbols that indicate we should reduce the given """Return the set of symbols that indicate we should reduce the given
config. config.
In an LR1 parser, this is the lookahead of the configuration.""" In an LR1 parser, this is the lookahead of the configuration."""
return config.lookahead return config.lookahead
def gen_closure_next(self, config): @functools.cache
def gen_closure_next(self, config: Configuration):
"""Return the next set of configurations in the closure for """Return the next set of configurations in the closure for
config. config.
@ -632,14 +657,12 @@ class GenerateLR1(GenerateSLR1):
(See the documentation in GenerateLR0 for more information on how (See the documentation in GenerateLR0 for more information on how
this function fits into the whole process.) this function fits into the whole process.)
""" """
if config.at_end: config_next = config.next
if config_next is None:
return () return ()
else: else:
next = [] next = []
for rule in self.grammar: for rule in self.grammar.get(config_next, ()):
if rule[0] != config.next:
continue
# N.B.: We can't just append config.lookahead to config.rest # N.B.: We can't just append config.lookahead to config.rest
# and compute first(), because lookahead is a *set*. So # and compute first(), because lookahead is a *set*. So
# in this case we just say if 'first' contains epsilon, # in this case we just say if 'first' contains epsilon,
@ -650,7 +673,7 @@ class GenerateLR1(GenerateSLR1):
lookahead = tuple(l for l in lookahead if l is not None) lookahead = tuple(l for l in lookahead if l is not None)
lookahead = lookahead + config.lookahead lookahead = lookahead + config.lookahead
lookahead = tuple(sorted(set(lookahead))) lookahead = tuple(sorted(set(lookahead)))
next.append(Configuration.from_rule(rule, lookahead=lookahead)) next.append(Configuration.from_rule(config_next, rule, lookahead=lookahead))
return tuple(next) return tuple(next)
@ -660,9 +683,11 @@ class GenerateLR1(GenerateSLR1):
In LR1 parsers, we must remember to set the lookahead of the start In LR1 parsers, we must remember to set the lookahead of the start
symbol to '$'. symbol to '$'.
""" """
initial_set = self.gen_closure( seeds = tuple(
( Configuration.from_rule(self.grammar[0], lookahead=('$',)), ), Configuration.from_rule('__start', rule, lookahead=('$',))
for rule in self.grammar['__start']
) )
initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set) return self.gen_sets(initial_set)
@ -862,11 +887,11 @@ def examples():
print("grammar_lr0_shift_reduce (SLR1):") print("grammar_lr0_shift_reduce (SLR1):")
gen = GenerateSLR1('E', grammar_lr0_shift_reduce) gen = GenerateSLR1('E', grammar_lr0_shift_reduce)
print("First: {first}".format(first=str(gen.gen_first(['E'])))) print("First: {first}".format(first=str(gen.gen_first(('E',)))))
print("Follow: {follow}".format(follow=str(gen.gen_follow('E')))) print("Follow: {follow}".format(follow=str(gen.gen_follow('E'))))
table = gen.gen_table() table = gen.gen_table()
print(format_table(gen, table)) print(format_table(gen, table))
tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')']) tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')'], trace=True)
print(format_node(tree) + "\n") print(format_node(tree) + "\n")
print() print()