faster: Be explicit about Configuration, cache hash

The next step though will be to replace the Configuration with an
integer, and intern all Configurations, along with all other objects.
This commit is contained in:
John Doty 2024-04-15 09:48:03 -07:00
parent be93498e96
commit 5f89f460e5

View file

@ -15,7 +15,6 @@ import typing
# #
# We start with LR0 parsers, because they form the basis of everything else. # We start with LR0 parsers, because they form the basis of everything else.
############################################################################### ###############################################################################
@dataclasses.dataclass(frozen=True, order=True)
class Configuration: class Configuration:
"""A rule being tracked in a state. """A rule being tracked in a state.
@ -23,10 +22,39 @@ class Configuration:
but if left at its default it's harmless. Ignore it until you get to but if left at its default it's harmless. Ignore it until you get to
the part about LR(1).) the part about LR(1).)
""" """
__slots__ = (
'name',
'symbols',
'position',
'lookahead',
'next',
'at_end',
'_vals',
'_hash',
)
name: str name: str
symbols: typing.Tuple[str, ...] symbols: typing.Tuple[str, ...]
position: int position: int
lookahead: typing.Tuple[str, ...] lookahead: typing.Tuple[str, ...]
next: str | None
at_end: bool
_vals: typing.Tuple
_hash: int
def __init__(self, name, symbols, position, lookahead) -> None:
self.name = name
self.symbols = symbols
self.position = position
self.lookahead = lookahead
at_end = position == len(symbols)
self.at_end = at_end
self.next = symbols[position] if not at_end else None
self._vals = (name, symbols, position, lookahead)
self._hash = hash(self._vals)
@classmethod @classmethod
def from_rule(cls, name: str, symbols: typing.Tuple[str, ...], lookahead=()): def from_rule(cls, name: str, symbols: typing.Tuple[str, ...], lookahead=()):
@ -37,13 +65,58 @@ class Configuration:
lookahead=lookahead, lookahead=lookahead,
) )
@property def __hash__(self) -> int:
def at_end(self): return self._hash
return self.position == len(self.symbols)
@property def __eq__(self, value: object, /) -> bool:
def next(self): if value is self:
return self.symbols[self.position] if not self.at_end else None return True
if not isinstance(value, Configuration):
return NotImplemented
return (
value._hash == self._hash and
value.name == self.name and
value.position == self.position and
value.symbols == self.symbols and
value.lookahead == self.lookahead
)
def __lt__(self, value) -> bool:
if not isinstance(value, Configuration):
return NotImplemented
return self._vals < value._vals
def __gt__(self, value) -> bool:
if not isinstance(value, Configuration):
return NotImplemented
return self._vals > value._vals
def __le__(self, value) -> bool:
if not isinstance(value, Configuration):
return NotImplemented
return self._vals <= value._vals
def __ge__(self, value) -> bool:
if not isinstance(value, Configuration):
return NotImplemented
return self._vals >= value._vals
def replace_position(self, new_position):
return Configuration(
name=self.name,
symbols=self.symbols,
position=new_position,
lookahead=self.lookahead,
)
def clear_lookahead(self):
return Configuration(
name=self.name,
symbols=self.symbols,
position=self.position,
lookahead=(),
)
@property @property
def rest(self): def rest(self):
@ -52,9 +125,6 @@ class Configuration:
def at_symbol(self, symbol): def at_symbol(self, symbol):
return self.next == symbol return self.next == symbol
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)
def __str__(self): def __str__(self):
la = ", " + str(self.lookahead) if self.lookahead != () else "" la = ", " + str(self.lookahead) if self.lookahead != () else ""
return "{name} -> {bits}{lookahead}".format( return "{name} -> {bits}{lookahead}".format(
@ -279,7 +349,7 @@ class GenerateLR0(object):
the symbol. the symbol.
""" """
seeds = tuple( seeds = tuple(
config.replace(position=config.position + 1) config.replace_position(config.position + 1)
for config in config_set for config in config_set
if config.at_symbol(symbol) if config.at_symbol(symbol)
) )
@ -745,17 +815,17 @@ class GenerateLALR(GenerateLR1):
merged = [] merged = []
for index, a in enumerate(config_set_a): for index, a in enumerate(config_set_a):
b = config_set_b[index] b = config_set_b[index]
assert a.replace(lookahead=()) == b.replace(lookahead=()) assert a.clear_lookahead() == b.clear_lookahead()
new_lookahead = a.lookahead + b.lookahead new_lookahead = a.lookahead + b.lookahead
new_lookahead = tuple(sorted(set(new_lookahead))) new_lookahead = tuple(sorted(set(new_lookahead)))
merged.append(a.replace(lookahead=new_lookahead)) merged.append(a.clear_lookahead())
return tuple(merged) return tuple(merged)
def sets_equal(self, a, b): def sets_equal(self, a, b):
a_no_la = tuple(s.replace(lookahead=()) for s in a) a_no_la = tuple(s.clear_lookahead() for s in a)
b_no_la = tuple(s.replace(lookahead=()) for s in b) b_no_la = tuple(s.clear_lookahead() for s in b)
return a_no_la == b_no_la return a_no_la == b_no_la
def gen_sets(self, config_set): def gen_sets(self, config_set):
@ -772,7 +842,7 @@ class GenerateLALR(GenerateLR1):
pending = [config_set] pending = [config_set]
while len(pending) > 0: while len(pending) > 0:
config_set = pending.pop() config_set = pending.pop()
config_set_no_la = tuple(s.replace(lookahead=()) for s in config_set) config_set_no_la = tuple(s.clear_lookahead() for s in config_set)
existing = F.get(config_set_no_la) existing = F.get(config_set_no_la)
if existing is not None: if existing is not None:
@ -786,10 +856,13 @@ class GenerateLALR(GenerateLR1):
# starting state! # starting state!
return tuple(F.values()) return tuple(F.values())
def set_without_lookahead(self, config_set: ConfigSet) -> ConfigSet:
return tuple(sorted(set(c.clear_lookahead() for c in config_set)))
def build_set_index(self, sets: typing.Tuple[ConfigSet, ...]) -> dict[ConfigSet, int]: def build_set_index(self, sets: typing.Tuple[ConfigSet, ...]) -> dict[ConfigSet, int]:
index = {} index = {}
for s in sets: for s in sets:
s_no_la = tuple(c.replace(lookahead=()) for c in s) s_no_la = self.set_without_lookahead(s)
if s_no_la not in index: if s_no_la not in index:
index[s_no_la] = len(index) index[s_no_la] = len(index)
return index return index
@ -798,7 +871,7 @@ class GenerateLALR(GenerateLR1):
"""Find the specified set in the set of sets, and return the """Find the specified set in the set of sets, and return the
index, or None if it is not found. index, or None if it is not found.
""" """
s_no_la = tuple(c.replace(lookahead=()) for c in s) s_no_la = self.set_without_lookahead(s)
return sets.get(s_no_la) return sets.get(s_no_la)