From a837d662dda210456e44ad2862d16e13d8ef2727 Mon Sep 17 00:00:00 2001 From: John Doty Date: Sat, 8 Jun 2024 17:30:59 -0700 Subject: [PATCH] Error recovery: lame version of CPCT+ --- harness.py | 421 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 320 insertions(+), 101 deletions(-) diff --git a/harness.py b/harness.py index f053d77..ab0e4bc 100644 --- a/harness.py +++ b/harness.py @@ -59,56 +59,276 @@ class ParseError: end: int -@dataclass -class StopResult: - result: Tree | None - errors: list[ParseError] - score: int +ParseStack = list[typing.Tuple[int, TokenValue | Tree | None]] -@dataclass -class ContinueResult: - threads: list["ParserThread"] +RECOVER_TRACE: list[str] = [] -StepResult = StopResult | ContinueResult +def clear_recover_trace(): + RECOVER_TRACE.clear() -class ParserThread: +def recover_trace(s: str): + # RECOVER_TRACE.append(s) + del s + pass + + +class RepairAction(enum.Enum): + Base = "bas" + Insert = "ins" + Delete = "del" + Shift = "sft" + + +class RepairStack(typing.NamedTuple): + state: int + parent: "RepairStack | None" + + @classmethod + def from_stack(cls, stack: ParseStack) -> "RepairStack": + if len(stack) == 0: + raise ValueError("Empty stack") + + result: RepairStack | None = None + for item in stack: + result = RepairStack(state=item[0], parent=result) + + assert result is not None + return result + + def pop(self, n: int) -> "RepairStack": + s = self + while n > 0: + s = s.parent + n -= 1 + assert s is not None, "Stack underflow" + + return s + + def flatten(self) -> list[int]: + stack = self + result: list[int] = [] + while stack is not None: + result.append(stack.state) + stack = stack.parent + return result + + def push(self, state: int) -> "RepairStack": + return RepairStack(state, self) + + def handle_token( + self, table: parser.ParseTable, token: str + ) -> typing.Tuple["RepairStack | None", bool]: + stack = self + while True: + action = table.actions[stack.state].get(token) + if action is None: + 
return None, False + + match action: + case parser.Shift(): + recover_trace(f" {stack.state}: SHIFT -> {action.state}") + return stack.push(action.state), False + + case parser.Accept(): + recover_trace(f" {stack.state}: ACCEPT") + return stack, True # ? + + case parser.Reduce(): + recover_trace(f" {stack.state}: REDUCE {action.name} {action.count} ") + new_stack = stack.pop(action.count) + recover_trace(f" -> {new_stack.state}") + new_state = table.gotos[new_stack.state][action.name] + recover_trace(f" goto {new_state}") + stack = new_stack.push(new_state) + + case parser.Error(): + assert False, "Explicit error found in repair" + + case _: + typing.assert_never(action) + + +class Repair: + repair: RepairAction + cost: int + stack: RepairStack + value: str | None + parent: "Repair | None" + shifts: int + success: bool + + def __init__(self, repair, cost, stack, parent, advance=0, value=None, success=False): + self.repair = repair + self.cost = cost + self.stack = stack + self.parent = parent + self.value = value + self.success = success + self.advance = advance + + if parent is not None: + self.cost += parent.cost + self.advance += parent.advance + + if self.advance >= 3: + self.success = True + + def neighbors( + self, + table: parser.ParseTable, + input: list[TokenValue], + start: int, + ): + input_index = start + self.advance + if input_index >= len(input): + return + + valstr = f"({self.value})" if self.value is not None else "" + recover_trace(f"{self.repair.value}{valstr} @ {self.cost} input:{input_index}") + recover_trace(f" {','.join(str(s) for s in self.stack.flatten())}") + + state = self.stack.state + + # For insert: go through all the actions and run all the possible + # reduce/accepts on them. This will generate a *new stack* which we + # then capture with an "Insert" repair action. Do not manipulate the + # input stream. 
+ # + # For shift: produce a repair that consumes the current input token, + # advancing the input stream, and manipulating the stack as + # necessary, producing a new version of the stack. Count up the + # number of successful shifts. + for token in table.actions[state].keys(): + recover_trace(f" token: {token}") + new_stack, success = self.stack.handle_token(table, token) + if new_stack is None: + # Not clear why this is necessary, but I think state merging + # causes us to occasionally have reduce actions that lead to + # errors. + continue + + if token == input[input_index].kind: + recover_trace(f" generate shift {token}") + yield Repair( + repair=RepairAction.Shift, + parent=self, + stack=new_stack, + cost=0, # Shifts are free. + advance=1, # Move forward by one. + ) + + recover_trace(f" generate insert {token}") + yield Repair( + repair=RepairAction.Insert, + value=token, + parent=self, + stack=new_stack, + cost=1, # TODO: Configurable token costs + success=success, + ) + + # For delete: produce a repair that just advances the input token + # stream, but does not manipulate the stack at all. Obviously we can + # only do this if we aren't at the end of the stream. Do not generate + # a "delete" if the previous repair was an "insert". (Only allow + # delete-insert pairs, not insert-delete, because they are + # symmetrical and therefore a waste of time and memory.) 
+ if self.repair != RepairAction.Insert: + recover_trace(f" generate delete") + yield Repair( + repair=RepairAction.Delete, + parent=self, + stack=self.stack, + cost=3, # TODO: Configurable token costs + advance=1, + ) + + +def recover(table: parser.ParseTable, input: list[TokenValue], start: int, stack: ParseStack): + initial = Repair( + repair=RepairAction.Base, + cost=0, + stack=RepairStack.from_stack(stack), + parent=None, + ) + + todo_queue = [[initial]] + level = 0 + while level < len(todo_queue): + queue_index = 0 + queue = todo_queue[level] + while queue_index < len(queue): + repair = queue[queue_index] + + # NOTE: This is guaranteed to be the cheapest possible success- + # there can be no success cheaper than this one. Since + # we're going to pick one arbitrarily, this one might as + # well be it. + if repair.success: + repairs: list[Repair] = [] + while repair is not None: + repairs.append(repair) + repair = repair.parent + repairs.reverse() + return repairs + + for neighbor in repair.neighbors(table, input, start): + for _ in range((neighbor.cost - len(todo_queue)) + 1): + todo_queue.append([]) + todo_queue[neighbor.cost].append(neighbor) + + queue_index += 1 + level += 1 + + +class Parser: # Our stack is a stack of tuples, where the first entry is the state # number and the second entry is the 'value' that was generated when the # state was pushed. 
table: parser.ParseTable - stack: list[typing.Tuple[int, TokenValue | Tree | None]] - errors: list[ParseError] - score: int - def __init__(self, id, trace, table, stack): - self.id = id + def __init__(self, table, trace): self.trace = trace self.table = table - self.stack = stack - self.errors = [] - self.score = 0 - def step(self, current_token: TokenValue) -> StepResult: - stack = self.stack - table = self.table + def parse(self, tokens) -> typing.Tuple[Tree | None, list[str]]: + clear_recover_trace() + + input_tokens = tokens.tokens() + input: list[TokenValue] = [ + TokenValue(kind=kind.value, start=start, end=start + length) + for (kind, start, length) in input_tokens + ] + + eof = 0 if len(input) == 0 else input[-1].end + input = input + [TokenValue(kind="$", start=eof, end=eof)] + input_index = 0 + + stack: ParseStack = [(0, None)] + result: Tree | None = None + errors: list[ParseError] = [] while True: + current_token = input[input_index] current_state = stack[-1][0] - action = table.actions[current_state].get(current_token.kind, parser.Error()) + action = self.table.actions[current_state].get(current_token.kind, parser.Error()) if self.trace: - self.trace(self.id, stack, current_token, action) + self.trace(stack, current_token, action) match action: case parser.Accept(): - result = stack[-1][1] - assert isinstance(result, Tree) - return StopResult(result, self.errors, self.score) + r = stack[-1][1] + assert isinstance(r, Tree) + result = r + break case parser.Reduce(name=name, count=size, transparent=transparent): + recover_trace(f" {current_token.kind}") + recover_trace(f" {current_state}: REDUCE {name} {size}") children: list[TokenValue | Tree] = [] for _, c in stack[-size:]: if c is None: @@ -125,15 +345,15 @@ class ParserThread: children=tuple(children), ) del stack[-size:] - - goto = table.gotos[stack[-1][0]].get(name) + recover_trace(f" -> {stack[-1][0]}") + goto = self.table.gotos[stack[-1][0]].get(name) assert goto is not None + recover_trace(f" -> 
{goto}") stack.append((goto, value)) - continue - case parser.Shift(state): - stack.append((state, current_token)) - return ContinueResult([self]) + case parser.Shift(): + stack.append((action.state, current_token)) + input_index += 1 case parser.Error(): if current_token.kind == "$": @@ -141,73 +361,70 @@ class ParserThread: else: message = f"Syntax error: unexpected symbol {current_token.kind}" - self.errors.append( + errors.append( ParseError( - message=message, start=current_token.start, end=current_token.end + message=message, + start=current_token.start, + end=current_token.end, ) ) - # TODO: Error Recovery Here - return StopResult(None, self.errors, self.score) + + repairs = recover(self.table, input, input_index, stack) + + # If we were unable to find a repair sequence, then just + # quit here; we have what we have. We *should* do our + # best to generate a tree, but I'm not sure if we can? + if repairs is None: + break + + # If we *were* able to find a repair, apply it to + # the token stream and continue moving. It is guaranteed + # that we will not generate an error until we get to the + # end of the stream that we found. + cursor = input_index + for repair in repairs: + match repair.repair: + case RepairAction.Base: + # Don't need to do anything here, this is + # where we started. + pass + + case RepairAction.Insert: + # Insert a token into the stream. + # Need to advance the cursor to compensate. + assert repair.value is not None + input.insert( + cursor, TokenValue(kind=repair.value, start=-1, end=-1) + ) + cursor += 1 + + case RepairAction.Delete: + del input[cursor] + + case RepairAction.Shift: + # Just consume the token where we are. + cursor += 1 + + case _: + typing.assert_never(repair.repair) case _: - raise ValueError(f"Unknown action type: {action}") + typing.assert_never(action) + # All done. 
+ error_strings = [] + for parse_error in errors: + line_index = bisect.bisect_left(tokens.lines, parse_error.start) + if line_index == 0: + col_start = 0 + else: + col_start = tokens.lines[line_index - 1] + 1 + column_index = parse_error.start - col_start + line_index += 1 -def parse(table: parser.ParseTable, tokens, trace=None) -> typing.Tuple[Tree | None, list[str]]: - input_tokens = tokens.tokens() - input: list[TokenValue] = [ - TokenValue(kind=kind.value, start=start, end=start + length) - for (kind, start, length) in input_tokens - ] + error_strings.append(f"{line_index}:{column_index}: {parse_error.message}") - eof = 0 if len(input) == 0 else input[-1].end - input = input + [TokenValue(kind="$", start=eof, end=eof)] - input_index = 0 - - threads = [ - ParserThread(0, trace, table, [(0, None)]), - ] - results: list[StopResult] = [] - - while len(threads) > 0: - current_token = input[input_index] - next_threads: list[ParserThread] = [] - for thread in threads: - sr = thread.step(current_token) - match sr: - case StopResult(): - results.append(sr) - break - - case ContinueResult(threads): - assert len(threads) > 0 - next_threads.extend(threads) - break - - case _: - typing.assert_never(sr) - - # All threads have accepted or errored or consumed input. 
- threads = next_threads - input_index += 1 - - assert len(results) > 0 - results.sort(key=lambda x: x.score) - result = results[0] - - error_strings = [] - for parse_error in result.errors: - line_index = bisect.bisect_left(tokens.lines, parse_error.start) - if line_index == 0: - col_start = 0 - else: - col_start = tokens.lines[line_index - 1] + 1 - column_index = parse_error.start - col_start - line_index += 1 - - error_strings.append(f"{line_index}:{column_index}: {parse_error.message}") - - return (result.result, error_strings) + return (result, error_strings) ############################################################################### @@ -421,7 +638,7 @@ class Harness: # print(f"{tokens.lines}") # tokens.dump(end=5) - (tree, errors) = parse(table, self.tokens, trace=None) + (tree, errors) = Parser(table, trace=None).parse(self.tokens) parse_time = time.time() self.tree = tree self.errors = errors @@ -454,16 +671,18 @@ class Harness: print(f"No table\r") print(("\u2500" * cols) + "\r") - if self.tree is not None: - lines = [] - self.format_node(lines, self.tree) - for line in lines[: rows - 3]: - print(line[:cols] + "\r") - else: + lines = list(RECOVER_TRACE) + + if self.errors is not None: wrapper = textwrap.TextWrapper(width=cols, drop_whitespace=False) - lines = [line for error in self.errors for line in wrapper.wrap(error)] - for line in lines[: rows - 3]: - print(line + "\r") + lines.extend(line for error in self.errors for line in wrapper.wrap(error)) + lines.append("") + + if self.tree is not None: + self.format_node(lines, self.tree) + + for line in lines[: rows - 3]: + print(line[:cols] + "\r") sys.stdout.flush() sys.stdout.buffer.flush()