faster: All symbols are integers

This commit is contained in:
John Doty 2024-04-15 16:05:42 -07:00
parent ee60951ffc
commit a818a4a498

View file

@ -33,11 +33,11 @@ class Configuration:
'_hash', '_hash',
) )
name: str name: int
symbols: typing.Tuple[str, ...] symbols: typing.Tuple[int, ...]
position: int position: int
lookahead: typing.Tuple[str, ...] lookahead: typing.Tuple[int, ...]
next: str | None next: int | None
at_end: bool at_end: bool
_vals: typing.Tuple _vals: typing.Tuple
@ -57,7 +57,7 @@ class Configuration:
self._hash = hash(self._vals) self._hash = hash(self._vals)
@classmethod @classmethod
def from_rule(cls, name: str, symbols: typing.Tuple[str, ...], lookahead=()): def from_rule(cls, name: int, symbols: typing.Tuple[int, ...], lookahead=()):
return Configuration( return Configuration(
name=name, name=name,
symbols=symbols, symbols=symbols,
@ -122,12 +122,12 @@ class Configuration:
def rest(self): def rest(self):
return self.symbols[(self.position+1):] return self.symbols[(self.position+1):]
def __str__(self): def format(self, alphabet: list[str]) -> str:
la = ", " + str(self.lookahead) if self.lookahead != () else "" la = ", " + str(tuple(alphabet[i] for i in self.lookahead)) if self.lookahead != () else ""
return "{name} -> {bits}{lookahead}".format( return "{name} -> {bits}{lookahead}".format(
name=self.name, name=self.name,
bits=' '.join([ bits=' '.join([
'* ' + sym if i == self.position else sym '* ' + alphabet[sym] if i == self.position else alphabet[sym]
for i, sym in enumerate(self.symbols) for i, sym in enumerate(self.symbols)
]) + (' *' if self.at_end else ''), ]) + (' *' if self.at_end else ''),
lookahead=la, lookahead=la,
@ -136,9 +136,10 @@ class Configuration:
ConfigSet = typing.Tuple[Configuration, ...] ConfigSet = typing.Tuple[Configuration, ...]
class TableBuilder(object): class TableBuilder(object):
def __init__(self): def __init__(self, alphabet: list[str]):
self.errors = [] self.errors = []
self.table = [] self.table = []
self.alphabet = alphabet
self.row = None self.row = None
def flush(self): def flush(self):
@ -149,39 +150,45 @@ class TableBuilder(object):
def new_row(self, config_set): def new_row(self, config_set):
self._flush_row() self._flush_row()
self.row = {} self.row = [(None, None) for _ in self.alphabet]
self.current_config_set = config_set self.current_config_set = config_set
def _flush_row(self): def _flush_row(self):
if self.row: if self.row:
actions = {k: v[0] for k,v in self.row.items()} actions = {
self.alphabet[k]: v[0]
for k, v in enumerate(self.row)
if v[0] is not None
}
self.table.append(actions) self.table.append(actions)
def set_table_reduce(self, symbol, config): def set_table_reduce(self, symbol: int, config):
action = ('reduce', config.name, len(config.symbols)) action = ('reduce', self.alphabet[config.name], len(config.symbols))
self._set_table_action(symbol, action, config) self._set_table_action(symbol, action, config)
def set_table_accept(self, config): def set_table_accept(self, symbol: int, config: Configuration):
action = ('accept',) action = ('accept',)
self._set_table_action('$', action, config) self._set_table_action(symbol, action, config)
def set_table_shift(self, index, config): def set_table_shift(self, symbol: int, index: int, config: Configuration):
action = ('shift', index) action = ('shift', index)
self._set_table_action(config.next, action, config) self._set_table_action(symbol, action, config)
def set_table_goto(self, symbol, index): def set_table_goto(self, symbol: int, index: int):
action = ('goto', index) action = ('goto', index)
self._set_table_action(symbol, action, None) self._set_table_action(symbol, action, None)
def _set_table_action(self, symbol, action, config): def _set_table_action(self, symbol_id: int, action, config):
"""Set the action for 'symbol' in the table row to 'action'. """Set the action for 'symbol' in the table row to 'action'.
This is destructive; it changes the table. It raises an error if This is destructive; it changes the table. It raises an error if
there is already an action for the symbol in the row. there is already an action for the symbol in the row.
""" """
assert isinstance(symbol_id, int)
assert self.row is not None assert self.row is not None
existing, existing_config = self.row.get(symbol, (None, None)) existing, existing_config = self.row[symbol_id]
if existing is not None and existing != action: if existing is not None and existing != action:
config_old = str(existing_config) config_old = str(existing_config)
config_new = str(config) config_new = str(config)
@ -195,11 +202,11 @@ class TableBuilder(object):
max_len=max_len, max_len=max_len,
old=existing, old=existing,
new=action, new=action,
symbol=symbol, symbol=self.alphabet[symbol_id],
) )
) )
self.errors.append(error) self.errors.append(error)
self.row[symbol] = (action, config) self.row[symbol_id] = (action, config)
class GenerateLR0(object): class GenerateLR0(object):
@ -244,42 +251,29 @@ class GenerateLR0(object):
everywhere.) everywhere.)
""" """
grammar: dict[str, list[typing.Tuple[str, ...]]]
nonterminals: set[str]
terminals: set[str]
alphabet: list[str] alphabet: list[str]
grammar: list[list[typing.Tuple[int, ...]]]
nonterminals: typing.Tuple[bool, ...]
terminals: typing.Tuple[bool, ...]
symbol_key: dict[str, int]
start_symbol: int
end_symbol: int
def __init__(self, start: str, grammar: list[typing.Tuple[str, list[str]]]): def __init__(self, start: str, grammar: list[typing.Tuple[str, list[str]]]):
"""Initialize the parser generator with the specified grammar and """Initialize the parser generator with the specified grammar and
start symbol. start symbol.
""" """
# Turn the incoming grammar into a dictionary, indexed by nonterminal. # Work out the alphabet.
# alphabet = set()
# We count on python dictionaries retaining the insertion order, like
# it or not.
full_grammar = {}
for name, rule in grammar: for name, rule in grammar:
rules = full_grammar.get(name) alphabet.add(name)
if rules is None: alphabet.update(symbol for symbol in rule)
rules = []
full_grammar[name] = rules
rules.append(tuple(rule))
self.grammar = full_grammar
self.nonterminals = set(self.grammar.keys())
self.terminals = {
sym
for _, symbols in grammar
for sym in symbols
if sym not in self.nonterminals
}
self.alphabet = list(sorted(self.terminals | self.nonterminals))
# Check to make sure they didn't use anything that will give us # Check to make sure they didn't use anything that will give us
# heartburn later. # heartburn later.
reserved = [a for a in self.alphabet if a.startswith('__') or a == '$'] reserved = [a for a in alphabet if a.startswith('__') or a == '$']
if reserved: if reserved:
raise ValueError( raise ValueError(
"Can't use {symbols} in grammars, {what} reserved.".format( "Can't use {symbols} in grammars, {what} reserved.".format(
@ -288,10 +282,54 @@ class GenerateLR0(object):
) )
) )
alphabet.add('__start')
alphabet.add('$')
self.alphabet = list(sorted(alphabet))
self.grammar['__start'] = [(start,)] symbol_key = {
self.terminals.add('$') symbol: index
self.alphabet.append('$') for index, symbol in enumerate(self.alphabet)
}
start_symbol = symbol_key['__start']
end_symbol = symbol_key['$']
assert self.alphabet[start_symbol] == '__start'
assert self.alphabet[end_symbol] == '$'
# Turn the incoming grammar into a dictionary, indexed by nonterminal.
#
# We count on python dictionaries retaining the insertion order, like
# it or not.
full_grammar = [list() for _ in self.alphabet]
terminals = [True for _ in self.alphabet]
assert terminals[end_symbol]
nonterminals = [False for _ in self.alphabet]
for name, rule in grammar:
name_symbol = symbol_key[name]
terminals[name_symbol] = False
nonterminals[name_symbol] = True
rules = full_grammar[name_symbol]
rules.append(tuple(symbol_key[symbol] for symbol in rule))
self.grammar = full_grammar
self.grammar[start_symbol].append((symbol_key[start],))
terminals[start_symbol] = False
nonterminals[start_symbol] = True
self.terminals = tuple(terminals)
self.nonterminals = tuple(nonterminals)
assert self.terminals[end_symbol]
assert self.nonterminals[start_symbol]
self.symbol_key = symbol_key
self.start_symbol = start_symbol
self.end_symbol = end_symbol
@functools.cache @functools.cache
@ -311,7 +349,7 @@ class GenerateLR0(object):
else: else:
return tuple( return tuple(
Configuration.from_rule(next, rule) Configuration.from_rule(next, rule)
for rule in self.grammar.get(next, ()) for rule in self.grammar[next]
) )
@functools.cache @functools.cache
@ -393,8 +431,8 @@ class GenerateLR0(object):
def gen_all_sets(self) -> typing.Tuple[ConfigSet, ...]: def gen_all_sets(self) -> typing.Tuple[ConfigSet, ...]:
"""Generate all of the configuration sets for the grammar.""" """Generate all of the configuration sets for the grammar."""
seeds = tuple( seeds = tuple(
Configuration.from_rule('__start', rule) Configuration.from_rule(self.start_symbol, rule)
for rule in self.grammar['__start'] for rule in self.grammar[self.start_symbol]
) )
initial_set = self.gen_closure(seeds) initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set) return self.gen_sets(initial_set)
@ -408,13 +446,13 @@ class GenerateLR0(object):
""" """
return sets.get(s) return sets.get(s)
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]: def gen_reduce_set(self, config: Configuration) -> typing.Iterable[int]:
"""Return the set of symbols that indicate we should reduce the given """Return the set of symbols that indicate we should reduce the given
configuration. configuration.
In an LR0 parser, this is just the set of all terminals.""" In an LR0 parser, this is just the set of all terminals."""
del(config) del(config)
return self.terminals return [index for index, value in enumerate(self.terminals) if value]
def gen_table(self): def gen_table(self):
"""Generate the parse table. """Generate the parse table.
@ -444,7 +482,7 @@ class GenerateLR0(object):
Anything missing from the row indicates an error. Anything missing from the row indicates an error.
""" """
builder = TableBuilder() builder = TableBuilder(self.alphabet)
config_sets = self.gen_all_sets() config_sets = self.gen_all_sets()
set_index = self.build_set_index(config_sets) set_index = self.build_set_index(config_sets)
@ -454,56 +492,31 @@ class GenerateLR0(object):
# Actions # Actions
for config in config_set: for config in config_set:
if config.at_end: config_next = config.next
if config.name != '__start': if config_next is None:
if config.name != self.start_symbol:
for a in self.gen_reduce_set(config): for a in self.gen_reduce_set(config):
builder.set_table_reduce(a, config) builder.set_table_reduce(a, config)
else: else:
builder.set_table_accept(config) builder.set_table_accept(self.end_symbol, config)
else: elif self.terminals[config_next]:
if config.next in self.terminals: successor = self.gen_successor(config_set, config_next)
successor = self.gen_successor(config_set, config.next) index = self.find_set_index(set_index, successor)
index = self.find_set_index(set_index, successor) assert index is not None
assert index is not None builder.set_table_shift(config_next, index, config)
builder.set_table_shift(index, config)
# Gotos # Gotos
for symbol in self.nonterminals: for symbol, is_nonterminal in enumerate(self.nonterminals):
successor = self.gen_successor(config_set, symbol) if is_nonterminal:
index = self.find_set_index(set_index, successor) successor = self.gen_successor(config_set, symbol)
if index is not None: index = self.find_set_index(set_index, successor)
builder.set_table_goto(symbol, index) if index is not None:
builder.set_table_goto(symbol, index)
return builder.flush() return builder.flush()
def set_table_action(self, errors, row, symbol, action, config):
"""Set the action for 'symbol' in the table row to 'action'.
This is destructive; it changes the table. It raises an error if
there is already an action for the symbol in the row.
"""
existing, existing_config = row.get(symbol, (None, None))
if existing is not None and existing != action:
config_old = str(existing_config)
config_new = str(config)
max_len = max(len(config_old), len(config_new)) + 1
error = (
"Conflicting actions for token '{symbol}':\n"
" {config_old: <{max_len}}: {old}\n"
" {config_new: <{max_len}}: {new}\n".format(
config_old=config_old,
config_new=config_new,
max_len=max_len,
old=existing,
new=action,
symbol=symbol,
)
)
errors.append(error)
row[symbol] = (action, config)
def parse(table, input, trace=False): def parse(table, input, trace=False):
"""Parse the input with the generated parsing table and return the """Parse the input with the generated parsing table and return the
@ -577,45 +590,50 @@ def update_changed(items: set, other: set) -> bool:
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FirstInfo: class FirstInfo:
firsts: dict[str, set[str]] firsts: list[set[int]]
is_epsilon: set[str] is_epsilon: list[bool]
@classmethod @classmethod
def from_grammar( def from_grammar(
cls, cls,
grammar: dict[str, list[typing.Tuple[str,...]]], grammar: list[list[typing.Tuple[int,...]]],
terminals: set[str], terminals: typing.Tuple[bool, ...],
): ):
firsts = {name: set() for name in grammar.keys()} firsts = [set() for _ in grammar]
for t in terminals:
firsts[t] = {t}
epsilons = set() # Add all terminals to their own firsts
for index, is_terminal in enumerate(terminals):
if is_terminal:
firsts[index].add(index)
epsilons = [False] * len(grammar)
changed = True changed = True
while changed: while changed:
changed = False changed = False
for name, rules in grammar.items(): for name, rules in enumerate(grammar):
f = firsts[name] f = firsts[name]
for rule in rules: for rule in rules:
if len(rule) == 0: if len(rule) == 0:
changed = add_changed(epsilons, name) or changed changed = changed or not epsilons[name]
epsilons[name] = True
continue continue
for index, symbol in enumerate(rule): for index, symbol in enumerate(rule):
if symbol in terminals: if terminals[symbol]:
changed = add_changed(f, symbol) or changed changed = add_changed(f, symbol) or changed
else: else:
other_firsts = firsts[symbol] other_firsts = firsts[symbol]
changed = update_changed(f, other_firsts) or changed changed = update_changed(f, other_firsts) or changed
is_last = index == len(rule) - 1 is_last = index == len(rule) - 1
if is_last and symbol in epsilons: if is_last and epsilons[symbol]:
# If this is the last symbol and the last # If this is the last symbol and the last
# symbol can be empty then I can be empty # symbol can be empty then I can be empty
# too! :P # too! :P
changed = add_changed(epsilons, name) or changed changed = changed or not epsilons[name]
epsilons[name] = True
if symbol not in epsilons: if not epsilons[symbol]:
# If we believe that there is at least one # If we believe that there is at least one
# terminal in the first set of this # terminal in the first set of this
# nonterminal then I don't have to keep # nonterminal then I don't have to keep
@ -626,27 +644,30 @@ class FirstInfo:
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FollowInfo: class FollowInfo:
follows: dict[str, set[str]] follows: list[set[int]]
@classmethod @classmethod
def from_grammar( def from_grammar(
cls, cls,
grammar: dict[str, list[typing.Tuple[str,...]]], grammar: list[list[typing.Tuple[int,...]]],
terminals: typing.Tuple[bool, ...],
start_symbol: int,
end_symbol: int,
firsts: FirstInfo, firsts: FirstInfo,
): ):
follows = {name: set() for name in grammar.keys()} follows = [set() for _ in grammar]
follows["__start"].add('$') follows[start_symbol].add(end_symbol)
changed = True changed = True
while changed: while changed:
changed = False changed = False
for name, rules in grammar.items(): for name, rules in enumerate(grammar):
for rule in rules: for rule in rules:
epsilon = True epsilon = True
prev_symbol = None prev_symbol = None
for symbol in reversed(rule): for symbol in reversed(rule):
f = follows.get(symbol) f = follows[symbol]
if f is None: if terminals[symbol]:
# This particular rule can't produce epsilon. # This particular rule can't produce epsilon.
epsilon = False epsilon = False
prev_symbol = symbol prev_symbol = symbol
@ -668,7 +689,7 @@ class FollowInfo:
# Now if there's no epsilon in this symbol there's no # Now if there's no epsilon in this symbol there's no
# more epsilon in the rest of the sequence. # more epsilon in the rest of the sequence.
if symbol not in firsts.is_epsilon: if not firsts.is_epsilon[symbol]:
epsilon = False epsilon = False
prev_symbol = symbol prev_symbol = symbol
@ -695,9 +716,15 @@ class GenerateSLR1(GenerateLR0):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._firsts = FirstInfo.from_grammar(self.grammar, self.terminals) self._firsts = FirstInfo.from_grammar(self.grammar, self.terminals)
self._follows = FollowInfo.from_grammar(self.grammar, self._firsts) self._follows = FollowInfo.from_grammar(
self.grammar,
self.terminals,
self.start_symbol,
self.end_symbol,
self._firsts,
)
def gen_first(self, symbols: typing.Iterable[str]) -> typing.Tuple[set[str], bool]: def gen_first(self, symbols: typing.Iterable[int]) -> typing.Tuple[set[int], bool]:
"""Return the first set for a sequence of symbols. """Return the first set for a sequence of symbols.
Build the set by combining the first sets of the symbols from left to Build the set by combining the first sets of the symbols from left to
@ -715,7 +742,7 @@ class GenerateSLR1(GenerateLR0):
return (result, True) return (result, True)
def gen_follow(self, symbol: str) -> set[str]: def gen_follow(self, symbol: int) -> set[int]:
"""Generate the follow set for the given nonterminal. """Generate the follow set for the given nonterminal.
The follow set for a nonterminal is the set of terminals that can The follow set for a nonterminal is the set of terminals that can
@ -723,10 +750,9 @@ class GenerateSLR1(GenerateLR0):
contains epsilon and is never empty, since we should always at least contains epsilon and is never empty, since we should always at least
ground out at '$', which is the end-of-stream marker. ground out at '$', which is the end-of-stream marker.
""" """
assert symbol in self.grammar
return self._follows.follows[symbol] return self._follows.follows[symbol]
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]: def gen_reduce_set(self, config: Configuration) -> typing.Iterable[int]:
"""Return the set of symbols that indicate we should reduce the given """Return the set of symbols that indicate we should reduce the given
config. config.
@ -745,7 +771,7 @@ class GenerateLR1(GenerateSLR1):
details. (Except for the start configuration, which has '$' as its details. (Except for the start configuration, which has '$' as its
lookahead.) lookahead.)
""" """
def gen_reduce_set(self, config: Configuration) -> typing.Iterable[str]: def gen_reduce_set(self, config: Configuration) -> typing.Iterable[int]:
"""Return the set of symbols that indicate we should reduce the given """Return the set of symbols that indicate we should reduce the given
config. config.
@ -773,7 +799,7 @@ class GenerateLR1(GenerateSLR1):
return () return ()
else: else:
next = [] next = []
for rule in self.grammar.get(config_next, ()): for rule in self.grammar[config_next]:
lookahead, epsilon = self.gen_first(config.rest) lookahead, epsilon = self.gen_first(config.rest)
if epsilon: if epsilon:
lookahead.update(config.lookahead) lookahead.update(config.lookahead)
@ -789,8 +815,8 @@ class GenerateLR1(GenerateSLR1):
symbol to '$'. symbol to '$'.
""" """
seeds = tuple( seeds = tuple(
Configuration.from_rule('__start', rule, lookahead=('$',)) Configuration.from_rule(self.start_symbol, rule, lookahead=(self.end_symbol,))
for rule in self.grammar['__start'] for rule in self.grammar[self.start_symbol]
) )
initial_set = self.gen_closure(seeds) initial_set = self.gen_closure(seeds)
return self.gen_sets(initial_set) return self.gen_sets(initial_set)
@ -906,14 +932,24 @@ def format_table(generator, table):
elif action[0] == 'reduce': elif action[0] == 'reduce':
return 'r' + str(action[1]) return 'r' + str(action[1])
terminals = [
generator.alphabet[i]
for i,v in enumerate(generator.terminals)
if v
]
nonterminals = [
generator.alphabet[i]
for i,v in enumerate(generator.nonterminals)
if v
]
header = " | {terms} | {nts}".format( header = " | {terms} | {nts}".format(
terms=' '.join( terms=' '.join(
'{0: <6}'.format(terminal) '{0: <6}'.format(terminal)
for terminal in sorted(generator.terminals) for terminal in sorted(terminals)
), ),
nts=' '.join( nts=' '.join(
'{0: <5}'.format(nt) '{0: <5}'.format(nt)
for nt in sorted(generator.nonterminals) for nt in sorted(nonterminals)
), ),
) )
@ -925,11 +961,11 @@ def format_table(generator, table):
index=i, index=i,
actions=' '.join( actions=' '.join(
'{0: <6}'.format(format_action(row, terminal)) '{0: <6}'.format(format_action(row, terminal))
for terminal in sorted(generator.terminals) for terminal in sorted(terminals)
), ),
gotos=' '.join( gotos=' '.join(
'{0: <5}'.format(row.get(nt, ('error', ''))[1]) '{0: <5}'.format(row.get(nt, ('error', ''))[1])
for nt in sorted(generator.nonterminals) for nt in sorted(nonterminals)
), ),
) )
for i, row in enumerate(table) for i, row in enumerate(table)
@ -1001,9 +1037,9 @@ def examples():
print("grammar_lr0_shift_reduce (SLR1):") print("grammar_lr0_shift_reduce (SLR1):")
gen = GenerateSLR1('E', grammar_lr0_shift_reduce) gen = GenerateSLR1('E', grammar_lr0_shift_reduce)
first, epsilon=gen.gen_first(('E',)) first, epsilon=gen.gen_first((gen.symbol_key['E'],))
print(f"First: {str(first)} (epsilon={epsilon})") print(f"First: {str(first)} (epsilon={epsilon})")
print(f"Follow: {str(gen.gen_follow('E'))}") print(f"Follow: {str(gen.gen_follow(gen.symbol_key['E']))}")
table = gen.gen_table() table = gen.gen_table()
print(format_table(gen, table)) print(format_table(gen, table))
tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')'], trace=True) tree = parse(table, ['id', '+', '(', 'id', '[', 'id', ']', ')'], trace=True)