faster: All symbols are integers

This commit is contained in:
John Doty 2024-04-15 16:05:42 -07:00
parent ee60951ffc
commit a818a4a498

View file

@ -33,11 +33,11 @@ class Configuration:
'_hash',
)
name: str
symbols: typing.Tuple[str, ...]
name: int
symbols: typing.Tuple[int, ...]
position: int
lookahead: typing.Tuple[str, ...]
next: str | None
lookahead: typing.Tuple[int, ...]
next: int | None
at_end: bool
_vals: typing.Tuple
@ -57,7 +57,7 @@ class Configuration:
self._hash = hash(self._vals)
@classmethod
def from_rule(cls, name: str, symbols: typing.Tuple[str, ...], lookahead=()):
def from_rule(cls, name: int, symbols: typing.Tuple[int, ...], lookahead=()):
return Configuration(
name=name,
symbols=symbols,
@ -122,12 +122,12 @@ class Configuration:
def rest(self):
return self.symbols[(self.position+1):]
def __str__(self):
la = ", " + str(self.lookahead) if self.lookahead != () else ""
def format(self, alphabet: list[str]) -> str:
la = ", " + str(tuple(alphabet[i] for i in self.lookahead)) if self.lookahead != () else ""
return "{name} -> {bits}{lookahead}".format(
name=self.name,
bits=' '.join([
'* ' + sym if i == self.position else sym
'* ' + alphabet[sym] if i == self.position else alphabet[sym]
for i, sym in enumerate(self.symbols)
]) + (' *' if self.at_end else ''),
lookahead=la,
@ -136,9 +136,10 @@ class Configuration:
ConfigSet = typing.Tuple[Configuration, ...]
class TableBuilder(object):
def __init__(self):
def __init__(self, alphabet: list[str]):
self.errors = []
self.table = []
self.alphabet = alphabet
self.row = None
def flush(self):
@ -149,39 +150,45 @@ class TableBuilder(object):
def new_row(self, config_set):
self._flush_row()
self.row = {}
self.row = [(None, None) for _ in self.alphabet]
self.current_config_set = config_set
def _flush_row(self):
if self.row:
actions = {k: v[0] for k,v in self.row.items()}
actions = {
self.alphabet[k]: v[0]
for k, v in enumerate(self.row)
if v[0] is not None
}
self.table.append(actions)
def set_table_reduce(self, symbol, config):
action = ('reduce', config.name, len(config.symbols))
def set_table_reduce(self, symbol: int, config):
action = ('reduce', self.alphabet[config.name], len(config.symbols))
self._set_table_action(symbol, action, config)
def set_table_accept(self, config):
def set_table_accept(self, symbol: int, config: Configuration):
action = ('accept',)
self._set_table_action('$', action, config)
self._set_table_action(symbol, action, config)
def set_table_shift(self, index, config):
def set_table_shift(self, symbol: int, index: int, config: Configuration):
action = ('shift', index)
self._set_table_action(config.next, action, config)
self._set_table_action(symbol, action, config)
def set_table_goto(self, symbol, index):
def set_table_goto(self, symbol: int, index: int):
action = ('goto', index)
self._set_table_action(symbol, action, None)
def _set_table_action(self, symbol, action, config):
def _set_table_action(self, symbol_id: int, action, config):
"""Set the action for the given symbol in the current table row to 'action'.
This is destructive; it changes the row. If the row already holds a
different action for the symbol, a conflict message is recorded in
self.errors and the new action replaces the old one.
"""
assert isinstance(symbol_id, int)
assert self.row is not None
existing, existing_config = self.row.get(symbol, (None, None))
existing, existing_config = self.row[symbol_id]
if existing is not None and existing != action:
config_old = str(existing_config)
config_new = str(config)
@ -195,11 +202,11 @@ class TableBuilder(object):
max_len=max_len,
old=existing,
new=action,
symbol=symbol,
symbol=self.alphabet[symbol_id],
)
)
self.errors.append(error)
self.row[symbol] = (action, config)
self.row[symbol_id] = (action, config)
class GenerateLR0(object):
@ -244,42 +251,29 @@ class GenerateLR0(object):
everywhere.)
"""
grammar: dict[str, list[typing.Tuple[str, ...]]]
nonterminals: set[str]
terminals: set[str]
alphabet: list[str]
grammar: list[list[typing.Tuple[int, ...]]]
nonterminals: typing.Tuple[bool, ...]
terminals: typing.Tuple[bool, ...]
symbol_key: dict[str, int]
start_symbol: int
end_symbol: int
def __init__(self, start: str, grammar: list[typing.Tuple[str, list[str]]]):
"""Initialize the parser generator with the specified grammar and
start symbol.
"""
# Turn the incoming grammar into a dictionary, indexed by nonterminal.
#
# We count on python dictionaries retaining the insertion order, like
# it or not.
full_grammar = {}
# Work out the alphabet.
alphabet = set()
for name, rule in grammar:
rules = full_grammar.get(name)
if rules is None:
rules = []
full_grammar[name] = rules
rules.append(tuple(rule))
self.grammar = full_grammar
self.nonterminals = set(self.grammar.keys())
self.terminals = {
sym
for _, symbols in grammar
for sym in symbols
if sym not in self.nonterminals
}
self.alphabet = list(sorted(self.terminals | self.nonterminals))
alphabet.add(name)
alphabet.update(symbol for symbol in rule)
# Check to make sure they didn't use anything that will give us
# heartburn later.
reserved = [a for a in self.alphabet if a.startswith('__') or a == '$']
reserved = [a for a in alphabet if a.startswith('__') or a == '$']
if reserved:
raise ValueError(
"Can't use {symbols} in grammars, {what} reserved.".format(
@ -288,10 +282,54 @@ class GenerateLR0(object):
)
)
alphabet.add('__start')
alphabet.add('$')
self.alphabet = list(sorted(alphabet))
self.grammar['__start'] = [(start,)]
self.terminals.add('$')
self.alphabet.append('$')
symbol_key = {
symbol: index
for index, symbol in enumerate(self.alphabet)
}
start_symbol = symbol_key['__start']
end_symbol = symbol_key['$']
assert self.alphabet[start_symbol] == '__start'
assert self.alphabet[end_symbol] == '$'
# Turn the incoming grammar into a list indexed by symbol id: each
# entry holds the rules for that symbol, with every rule stored as a
# tuple of symbol ids.
#
# Symbols that never appear as a rule name keep an empty rule list.
full_grammar = [list() for _ in self.alphabet]
terminals = [True for _ in self.alphabet]
assert terminals[end_symbol]
nonterminals = [False for _ in self.alphabet]
for name, rule in grammar:
name_symbol = symbol_key[name]
terminals[name_symbol] = False
nonterminals[name_symbol] = True
rules = full_grammar[name_symbol]
rules.append(tuple(symbol_key[symbol] for symbol in rule))
self.grammar = full_grammar
self.grammar[start_symbol].append((symbol_key[start],))
terminals[start_symbol] = False
nonterminals[start_symbol] = True
self.terminals = tuple(terminals)
self.nonterminals = tuple(nonterminals)
assert self.terminals[end_symbol]
assert self.nonterminals[start_symbol]
self.symbol_key = symbol_key
self.start_symbol = start_symbol
self.end_symbol = end_symbol
@functools.cache
@ -311,7 +349,7 @@ class GenerateLR0(object):
else:
return tuple(
Configuration.from_rule(next, rule)
for rule in self.grammar.get(next, ())
for rule in self.grammar[next]
)
@functools.cache
@ -393,8 +431,8 @@ class GenerateLR0(object):
def gen_all_sets(self) -> typing.Tuple[ConfigSet, ...]:
"""Generate all of the configuration sets for the grammar."""
seeds = tuple(
Configuration.from_rule('__start', rule)
for rule in self.grammar['__start']
Configuration.from_rule(self.start_symbol, rule)
for rule in self.grammar[self.start_symbol]
)
initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set)
@ -408,13 +446,13 @@ class GenerateLR0(object):
"""
return sets.get(s)
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[int]:
"""Return the set of symbols that indicate we should reduce the given
configuration.
In an LR0 parser, this is just the set of all terminals."""
del(config)
return self.terminals
return [index for index, value in enumerate(self.terminals) if value]
def gen_table(self):
"""Generate the parse table.
@ -444,7 +482,7 @@ class GenerateLR0(object):
Anything missing from the row indicates an error.
"""
builder = TableBuilder()
builder = TableBuilder(self.alphabet)
config_sets = self.gen_all_sets()
set_index = self.build_set_index(config_sets)
@ -454,56 +492,31 @@ class GenerateLR0(object):
# Actions
for config in config_set:
if config.at_end:
if config.name != '__start':
config_next = config.next
if config_next is None:
if config.name != self.start_symbol:
for a in self.gen_reduce_set(config):
builder.set_table_reduce(a, config)
else:
builder.set_table_accept(config)
builder.set_table_accept(self.end_symbol, config)
else:
if config.next in self.terminals:
successor = self.gen_successor(config_set, config.next)
index = self.find_set_index(set_index, successor)
assert index is not None
builder.set_table_shift(index, config)
elif self.terminals[config_next]:
successor = self.gen_successor(config_set, config_next)
index = self.find_set_index(set_index, successor)
assert index is not None
builder.set_table_shift(config_next, index, config)
# Gotos
for symbol in self.nonterminals:
successor = self.gen_successor(config_set, symbol)
index = self.find_set_index(set_index, successor)
if index is not None:
builder.set_table_goto(symbol, index)
for symbol, is_nonterminal in enumerate(self.nonterminals):
if is_nonterminal:
successor = self.gen_successor(config_set, symbol)
index = self.find_set_index(set_index, successor)
if index is not None:
builder.set_table_goto(symbol, index)
return builder.flush()
def set_table_action(self, errors, row, symbol, action, config):
"""Set the action for 'symbol' in the table row to 'action'.
This is destructive; it changes the table. It raises an error if
there is already an action for the symbol in the row.
"""
existing, existing_config = row.get(symbol, (None, None))
if existing is not None and existing != action:
config_old = str(existing_config)
config_new = str(config)
max_len = max(len(config_old), len(config_new)) + 1
error = (
"Conflicting actions for token '{symbol}':\n"
" {config_old: <{max_len}}: {old}\n"
" {config_new: <{max_len}}: {new}\n".format(
config_old=config_old,
config_new=config_new,
max_len=max_len,
old=existing,
new=action,
symbol=symbol,
)
)
errors.append(error)
row[symbol] = (action, config)
def parse(table, input, trace=False):
"""Parse the input with the generated parsing table and return the
@ -577,45 +590,50 @@ def update_changed(items: set, other: set) -> bool:
@dataclasses.dataclass(frozen=True)
class FirstInfo:
firsts: dict[str, set[str]]
is_epsilon: set[str]
firsts: list[set[int]]
is_epsilon: list[bool]
@classmethod
def from_grammar(
cls,
grammar: dict[str, list[typing.Tuple[str,...]]],
terminals: set[str],
grammar: list[list[typing.Tuple[int,...]]],
terminals: typing.Tuple[bool, ...],
):
firsts = {name: set() for name in grammar.keys()}
for t in terminals:
firsts[t] = {t}
firsts = [set() for _ in grammar]
epsilons = set()
# Add all terminals to their own firsts
for index, is_terminal in enumerate(terminals):
if is_terminal:
firsts[index].add(index)
epsilons = [False] * len(grammar)
changed = True
while changed:
changed = False
for name, rules in grammar.items():
for name, rules in enumerate(grammar):
f = firsts[name]
for rule in rules:
if len(rule) == 0:
changed = add_changed(epsilons, name) or changed
changed = changed or not epsilons[name]
epsilons[name] = True
continue
for index, symbol in enumerate(rule):
if symbol in terminals:
if terminals[symbol]:
changed = add_changed(f, symbol) or changed
else:
other_firsts = firsts[symbol]
changed = update_changed(f, other_firsts) or changed
is_last = index == len(rule) - 1
if is_last and symbol in epsilons:
if is_last and epsilons[symbol]:
# If this is the last symbol and the last
# symbol can be empty then I can be empty
# too! :P
changed = add_changed(epsilons, name) or changed
changed = changed or not epsilons[name]
epsilons[name] = True
if symbol not in epsilons:
if not epsilons[symbol]:
# If we believe that there is at least one
# terminal in the first set of this
# nonterminal then I don't have to keep
@ -626,27 +644,30 @@ class FirstInfo:
@dataclasses.dataclass(frozen=True)
class FollowInfo:
follows: dict[str, set[str]]
follows: list[set[int]]
@classmethod
def from_grammar(
cls,
grammar: dict[str, list[typing.Tuple[str,...]]],
grammar: list[list[typing.Tuple[int,...]]],
terminals: typing.Tuple[bool, ...],
start_symbol: int,
end_symbol: int,
firsts: FirstInfo,
):
follows = {name: set() for name in grammar.keys()}
follows["__start"].add('$')
follows = [set() for _ in grammar]
follows[start_symbol].add(end_symbol)
changed = True
while changed:
changed = False
for name, rules in grammar.items():
for name, rules in enumerate(grammar):
for rule in rules:
epsilon = True
prev_symbol = None
for symbol in reversed(rule):
f = follows.get(symbol)
if f is None:
f = follows[symbol]
if terminals[symbol]:
# This particular rule can't produce epsilon.
epsilon = False
prev_symbol = symbol
@ -668,7 +689,7 @@ class FollowInfo:
# Now if there's no epsilon in this symbol there's no
# more epsilon in the rest of the sequence.
if symbol not in firsts.is_epsilon:
if not firsts.is_epsilon[symbol]:
epsilon = False
prev_symbol = symbol
@ -695,9 +716,15 @@ class GenerateSLR1(GenerateLR0):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._firsts = FirstInfo.from_grammar(self.grammar, self.terminals)
self._follows = FollowInfo.from_grammar(self.grammar, self._firsts)
self._follows = FollowInfo.from_grammar(
self.grammar,
self.terminals,
self.start_symbol,
self.end_symbol,
self._firsts,
)
def gen_first(self, symbols: typing.Iterable[str]) -> typing.Tuple[set[str], bool]:
def gen_first(self, symbols: typing.Iterable[int]) -> typing.Tuple[set[int], bool]:
"""Return the first set for a sequence of symbols.
Build the set by combining the first sets of the symbols from left to
@ -715,7 +742,7 @@ class GenerateSLR1(GenerateLR0):
return (result, True)
def gen_follow(self, symbol: str) -> set[str]:
def gen_follow(self, symbol: int) -> set[int]:
"""Generate the follow set for the given nonterminal.
The follow set for a nonterminal is the set of terminals that can
@ -723,10 +750,9 @@ class GenerateSLR1(GenerateLR0):
contains epsilon and is never empty, since we should always at least
ground out at '$', which is the end-of-stream marker.
"""
assert symbol in self.grammar
return self._follows.follows[symbol]
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[int]:
"""Return the set of symbols that indicate we should reduce the given
config.
@ -745,7 +771,7 @@ class GenerateLR1(GenerateSLR1):
details. (Except for the start configuration, which has '$' as its
lookahead.)
"""
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]:
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[int]:
"""Return the set of symbols that indicate we should reduce the given
config.
@ -773,7 +799,7 @@ class GenerateLR1(GenerateSLR1):
return ()
else:
next = []
for rule in self.grammar.get(config_next, ()):
for rule in self.grammar[config_next]:
lookahead, epsilon = self.gen_first(config.rest)
if epsilon:
lookahead.update(config.lookahead)
@ -789,8 +815,8 @@ class GenerateLR1(GenerateSLR1):
symbol to '$'.
"""
seeds = tuple(
Configuration.from_rule('__start', rule, lookahead=('$',))
for rule in self.grammar['__start']
Configuration.from_rule(self.start_symbol, rule, lookahead=(self.end_symbol,))
for rule in self.grammar[self.start_symbol]
)
initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set)
@ -906,14 +932,24 @@ def format_table(generator, table):
elif action[0] == 'reduce':
return 'r' + str(action[1])
terminals = [
generator.alphabet[i]
for i,v in enumerate(generator.terminals)
if v
]
nonterminals = [
generator.alphabet[i]
for i,v in enumerate(generator.nonterminals)
if v
]
header = " | {terms} | {nts}".format(
terms=' '.join(
'{0: <6}'.format(terminal)
for terminal in sorted(generator.terminals)
for terminal in sorted(terminals)
),
nts=' '.join(
'{0: <5}'.format(nt)
for nt in sorted(generator.nonterminals)
for nt in sorted(nonterminals)
),
)
@ -925,11 +961,11 @@ def format_table(generator, table):
index=i,
actions=' '.join(
'{0: <6}'.format(format_action(row, terminal))
for terminal in sorted(generator.terminals)
for terminal in sorted(terminals)
),
gotos=' '.join(
'{0: <5}'.format(row.get(nt, ('error', ''))[1])
for nt in sorted(generator.nonterminals)
for nt in sorted(nonterminals)
),
)
for i, row in enumerate(table)
@ -1001,9 +1037,9 @@ def examples():
print("grammar_lr0_shift_reduce (SLR1):")
gen = GenerateSLR1('E', grammar_lr0_shift_reduce)
first, epsilon=gen.gen_first(('E',))
first, epsilon=gen.gen_first((gen.symbol_key['E'],))
print(f"First: {str(first)} (epsilon={epsilon})")
print(f"Follow: {str(gen.gen_follow('E'))}")
print(f"Follow: {str(gen.gen_follow(gen.symbol_key['E']))}")
table = gen.gen_table()
print(format_table(gen, table))
tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')'], trace=True)