diff --git a/parser/wadler.py b/parser/wadler.py index 5100241..db4c66b 100644 --- a/parser/wadler.py +++ b/parser/wadler.py @@ -336,6 +336,30 @@ def child_to_name(child: runtime.Tree | runtime.TokenValue) -> str: return f"token_{child.kind}" +def slice_pre_post_trivia( + trivia_mode: dict[str, parser.TriviaMode], + trivia_tokens: list[runtime.TokenValue], +) -> tuple[ + list[tuple[parser.TriviaMode, runtime.TokenValue]], + list[tuple[parser.TriviaMode, runtime.TokenValue]], +]: + tokens = [ + (trivia_mode.get(token.kind, parser.TriviaMode.Blank), token) for token in trivia_tokens + ] + + for index, (mode, token) in enumerate(tokens): + if token.start == 0: + # Everything is pre-trivia if we're at the start of the file. + return (tokens, []) + + if mode == parser.TriviaMode.NewLine: + # This is the first newline; it belongs with the pre-trivia. + return (tokens[index:], tokens[:index]) + + # If we never found a new line then it's all post-trivia. + return ([], tokens) + + @dataclasses.dataclass class Matcher: table: parser.ParseTable @@ -438,29 +462,10 @@ class Matcher: case parser.Error(): raise Exception("How did I get a parse error here??") - def slice_pre_post_trivia(self, trivia_tokens: list[runtime.TokenValue], src: str) -> tuple[ - list[tuple[parser.TriviaMode, runtime.TokenValue]], - list[tuple[parser.TriviaMode, runtime.TokenValue]], - ]: - tokens = [ - (self.trivia_mode.get(token.kind, parser.TriviaMode.Blank), token) - for token in trivia_tokens - ] - - for index, (mode, token) in enumerate(tokens): - if token.start == 0: - # Everything is pre-trivia if we're at the start of the file. - return (tokens, []) - - if mode == parser.TriviaMode.NewLine: - # This is the first newline; it belongs with the post-trivia. - return (tokens[index + 1 :], tokens[: index + 1]) - - # If we never found a new line then it's all post-trivia. - return ([], tokens) - def apply_pre_trivia(self, trivia_tokens: list[runtime.TokenValue], src: str) -> Document: - pre_trivia, _ = self.slice_pre_post_trivia(trivia_tokens, src) + pre_trivia, _ = slice_pre_post_trivia(self.trivia_mode, trivia_tokens) + # print(f"PRE:\n{pre_trivia}") + if len(pre_trivia) == 0: return None @@ -469,6 +474,7 @@ class Matcher: trivia_doc = None new_line_count = 0 for mode, token in pre_trivia: + # print(f"PRE {mode:25} {token.kind:30} ({new_line_count})") match mode: case parser.TriviaMode.LineComment: trivia_doc = cons( @@ -488,7 +494,6 @@ class Matcher: trivia_doc = cons( trivia_doc, ForceBreak(False), - ForceBreak(False), ) case _: @@ -497,12 +502,14 @@ class Matcher: return trivia_doc def apply_post_trivia(self, trivia_tokens: list[runtime.TokenValue], src: str) -> Document: - _, post_trivia = self.slice_pre_post_trivia(trivia_tokens, src) - if len(post_trivia) == 0: - return None + if len(trivia_tokens) > 0 and trivia_tokens[-1].end == len(src): + return self.apply_eof_trivia(trivia_tokens, src) + + _, post_trivia = slice_pre_post_trivia(self.trivia_mode, trivia_tokens) trivia_doc = None for mode, token in post_trivia: + # print(f"POST {mode:25} {token.kind:30}") match mode: case parser.TriviaMode.Blank: pass @@ -525,11 +532,42 @@ class Matcher: case _: typing.assert_never(mode) - if len(trivia_tokens) > 0 and trivia_tokens[-1].end == len(src): - # As a special case, if we're post trivia at the end of the file - # then we also need to be pre-trivia too, for the hypthetical EOF - # token that we never see. - trivia_doc = cons(trivia_doc, self.apply_pre_trivia(trivia_tokens, src)) + return trivia_doc + + def apply_eof_trivia(self, trivia_tokens: list[runtime.TokenValue], src: str) -> Document: + # EOF trivia has weird rules, namely, it's like pre and post joined together but. + tokens = [ + (self.trivia_mode.get(token.kind, parser.TriviaMode.Blank), token) + for token in trivia_tokens + ] + + at_start = True + newline_count = 0 + trivia_doc = None + for mode, token in tokens: + match mode: + case parser.TriviaMode.Blank: + pass + + case parser.TriviaMode.NewLine: + at_start = False + newline_count += 1 + if newline_count <= 2: + trivia_doc = cons(trivia_doc, ForceBreak(False)) + + case parser.TriviaMode.LineComment: + # Because this is post-trivia, we know there's something + # to our left, and we can force the space. + trivia_doc = cons( + trivia_doc, + Literal(" ") if at_start else None, + Literal(src[token.start : token.end]), + ) + newline_count = 0 + at_start = False + + case _: + typing.assert_never(mode) return trivia_doc diff --git a/tests/test_wadler.py b/tests/test_wadler.py index ff52ef2..8b43496 100644 --- a/tests/test_wadler.py +++ b/tests/test_wadler.py @@ -2,6 +2,7 @@ import typing from parser.parser import ( Grammar, + ParseTable, Re, Terminal, rule, @@ -117,7 +118,7 @@ def flatten_document(doc: wadler.Document, src: str) -> list: case wadler.NewLine(replace): return [f""] case wadler.ForceBreak(): - return [""] + return [f""] case wadler.Indent(): return [[f"", flatten_document(doc.doc, src)]] case wadler.Literal(text): @@ -204,6 +205,10 @@ def test_convert_tree_to_document(): ] +def _output(txt: str) -> str: + return txt.strip().replace("*SPACE*", " ").replace("*NEWLINE*", "\n") + + def test_layout_basic(): text = '{"a": true, "b":[1,2,3], "c":[1,2,3,4,5,6,7]}' tokens = runtime.GenericTokenStream(text, JSON_LEXER) @@ -214,15 +219,14 @@ def test_layout_basic(): printer = wadler.Printer(JSON) result = printer.format_tree(tree, text, 50).apply_to_source(text) - assert ( - result - == """ + assert result == _output( + """ { "a": true, "b": [1, 2, 3], "c": [1, 2, 3, 4, 5, 6, 7] } -""".strip() +""" ) @@ -277,9 +281,8 @@ def test_forced_break(): printer = wadler.Printer(g) result = printer.format_tree(tree, text, 200).apply_to_source(text) - assert ( - result - == """ + assert result == _output( + """ ( (ok ok) ( @@ -290,5 +293,159 @@ def test_forced_break(): ) (ok ok ok ok) ) - """.strip() + """ ) + + +def test_maintaining_line_breaks(): + g = TG() + g_lexer = g.compile_lexer() + g_parser = runtime.Parser(g.build_table()) + + text = """((ok ok) +; Don't break here. +(ok) + +; ^ Do keep this break though. +(ok) + + + +; ^ This should only be one break. +(ok))""" + + tree, errors = g_parser.parse(runtime.GenericTokenStream(text, g_lexer)) + assert errors == [] + assert tree is not None + + printer = wadler.Printer(g) + result = printer.format_tree(tree, text, 200).apply_to_source(text) + + assert result == _output( + """ +( + (ok ok) + ; Don't break here. + (ok) +*SPACE* + ; ^ Do keep this break though. + (ok) +*SPACE* + ; ^ This should only be one break. + (ok) +) + """ + ) + + +def test_trailing_trivia(): + g = TG() + g_lexer = g.compile_lexer() + g_parser = runtime.Parser(g.build_table()) + + text = """((ok ok)); Don't lose this! + +; Or this! + """ + + tree, errors = g_parser.parse(runtime.GenericTokenStream(text, g_lexer)) + assert errors == [] + assert tree is not None + + printer = wadler.Printer(g) + result = printer.format_tree(tree, text, 200).apply_to_source(text) + + assert result == _output( + """ +((ok ok)) ; Don't lose this! + +; Or this!*NEWLINE* +""" + ) + + +def test_trailing_trivia_two(): + g = TG() + g_lexer = g.compile_lexer() + g_parser = runtime.Parser(g.build_table()) + + text = """((ok ok)) + +; Or this! + """ + + tree, errors = g_parser.parse(runtime.GenericTokenStream(text, g_lexer)) + assert errors == [] + assert tree is not None + + printer = wadler.Printer(g) + result = printer.format_tree(tree, text, 200).apply_to_source(text) + + assert result == _output( + """ +((ok ok)) + +; Or this!*NEWLINE* +""" + ) + + +def test_trailing_trivia_split(): + g = TG() + g_lexer = g.compile_lexer() + g_parser = runtime.Parser(g.build_table()) + + text = """((ok ok)); Don't lose this! + +; Or this! + """ + + tree, errors = g_parser.parse(runtime.GenericTokenStream(text, g_lexer)) + assert errors == [] + assert tree is not None + + def rightmost(t: runtime.Tree | runtime.TokenValue) -> runtime.TokenValue | None: + if isinstance(t, runtime.TokenValue): + return t + + for child in reversed(t.children): + result = rightmost(child) + if result is not None: + return result + + return None + + token = rightmost(tree) + assert token is not None + + TRIVIA_MODES = { + "BLANKS": TriviaMode.Blank, + "LINE_BREAK": TriviaMode.NewLine, + "COMMENT": TriviaMode.LineComment, + } + + pre_trivia, post_trivia = wadler.slice_pre_post_trivia(TRIVIA_MODES, token.post_trivia) + for mode, t in pre_trivia: + print(f"{mode:25} {t.kind:10} {repr(text[t.start:t.end])}") + print("-----") + for mode, t in post_trivia: + print(f"{mode:25} {t.kind:10} {repr(text[t.start:t.end])}") + + trivia_doc = wadler.Matcher( + ParseTable([], [], set()), + {}, + {}, + TRIVIA_MODES, + ).apply_post_trivia( + token.post_trivia, + text, + ) + + assert flatten_document(trivia_doc, text) == [ + " ", + "; Don't lose this!", + "", + "", + "; Or this!", + "", + ]