diff --git a/examples/sql.py b/examples/sql.py index a28adcc..a5afb9f 100644 --- a/examples/sql.py +++ b/examples/sql.py @@ -315,29 +315,29 @@ def alter_table_stmt(): return ( ALTER + TABLE - + opt(name + DOT) - + name + + opt(schema_name + DOT) + + table_name + alt( - RENAME + alt((TO + name), (COLUMN + name + TO + name)), + RENAME + alt((TO + table_name), (COLUMN + column_name + TO + column_name)), (ADD + opt(COLUMN) + column_def), - (DROP + opt(COLUMN) + name), + (DROP + opt(COLUMN) + column_name), ) ) @rule def analyze_stmt(): - return ANALYZE + opt(opt(name + DOT) + name) + return ANALYZE + opt(alt(schema_name, opt(schema_name + DOT) + table_or_index_name)) @rule def attach_stmt(): - return ATTACH + opt(DATABASE) + expr + AS + name + return ATTACH + opt(DATABASE) + expr + AS + schema_name @rule def begin_stmt(): - return BEGIN + opt(DEFERRED | IMMEDIATE | EXCLUSIVE) + opt(TRANSACTION + opt(name)) + return BEGIN + opt(DEFERRED | IMMEDIATE | EXCLUSIVE) + opt(TRANSACTION + opt(transaction_name)) @rule @@ -347,17 +347,17 @@ def commit_stmt(): @rule def rollback_stmt(): - return ROLLBACK + opt(TRANSACTION) + opt(TO + opt(SAVEPOINT) + name) + return ROLLBACK + opt(TRANSACTION) + opt(TO + opt(SAVEPOINT) + savepoint_name) @rule def savepoint_stmt(): - return SAVEPOINT + name + return SAVEPOINT + savepoint_name @rule def release_stmt(): - return RELEASE + opt(SAVEPOINT) + name + return RELEASE + opt(SAVEPOINT) + savepoint_name def comma_list(*rules: Rule) -> Rule: @@ -373,10 +373,10 @@ def create_index_stmt(): opt(UNIQUE), INDEX, opt(IF + NOT + EXISTS), - opt(name + DOT), - name, + opt(schema_name + DOT), + index_name, ON, - name, + table_name, LPAREN, comma_list(indexed_column), RPAREN, @@ -386,7 +386,7 @@ def create_index_stmt(): @rule def indexed_column(): - return (name | expr) + opt(COLLATE + name) + opt(asc_desc) + return (column_name | expr) + opt(COLLATE + collation_name) + opt(asc_desc) @rule @@ -396,8 +396,8 @@ def create_table_stmt(): opt(TEMP | TEMPORARY), TABLE, opt(IF, NOT, EXISTS), - opt(name, DOT), - name, + opt(schema_name, DOT), + table_name, alt( seq( LPAREN, @@ -413,7 +413,7 @@ def create_table_stmt(): @rule def column_def(): - return name + opt(type_name) + zero_or_more(column_constraint) + return column_name + opt(type_name) + zero_or_more(column_constraint) @rule @@ -429,7 +429,7 @@ def column_constraint(): seq(PRIMARY, KEY, opt(asc_desc), opt(conflict_clause), opt(AUTOINCREMENT)), seq(opt(NOT), (NULL | UNIQUE), opt(conflict_clause)), seq(DEFAULT, signed_number | literal_value | seq(LPAREN, expr, RPAREN)), - seq(COLLATE, name), + seq(COLLATE, collation_name), foreign_key_clause, seq(opt(GENERATED, ALWAYS), AS, LPAREN, expr, RPAREN, opt(STORED | VIRTUAL)), ), @@ -458,7 +458,7 @@ def table_constraint(): FOREIGN, KEY, LPAREN, - comma_list(name), + comma_list(column_name), RPAREN, foreign_key_clause, ), @@ -470,8 +470,8 @@ def table_constraint(): def foreign_key_clause(): return seq( REFERENCES, - name, - opt(LPAREN, comma_list(name), RPAREN), + foreign_table, + opt(LPAREN, comma_list(column_name), RPAREN), zero_or_more( alt( seq( @@ -503,12 +503,12 @@ def create_trigger_stmt(): opt(TEMP | TEMPORARY), TRIGGER, opt(IF, NOT, EXISTS), - opt(name, DOT), - name, + opt(schema_name, DOT), + trigger_name, opt(BEFORE | AFTER | (INSTEAD + OF)), - (DELETE | INSERT | (UPDATE + opt(OF, comma_list(name)))), + (DELETE | INSERT | (UPDATE + opt(OF, comma_list(column_name)))), ON, - name, + table_name, opt(FOR, EACH, ROW), opt(WHEN, expr), BEGIN, @@ -524,9 +524,9 @@ def create_view_stmt(): opt(TEMP | TEMPORARY), VIEW, opt(IF, NOT, EXISTS), - opt(name, DOT), - name, - opt(LPAREN, comma_list(name), RPAREN), + opt(schema_name, DOT), + view_name, + opt(LPAREN, comma_list(column_name), RPAREN), AS, select_stmt, ) @@ -539,10 +539,10 @@ def create_virtual_table_stmt(): VIRTUAL, TABLE, opt(IF, NOT, EXISTS), - opt(name, DOT), - name, + opt(schema_name, DOT), + table_name, USING, - name, + module_name, opt(LPAREN, comma_list(module_argument), RPAREN), ) @@ -558,7 +558,7 @@ def with_clause(): @rule def cte_table_name(): - return name + opt(LPAREN, comma_list(name), RPAREN) + return table_name + opt(LPAREN, comma_list(column_name), RPAREN) @rule @@ -578,8 +578,8 @@ def recursive_cte(): @rule def common_table_expression(): return seq( - name, - opt(LPAREN, comma_list(name), RPAREN), + table_name, + opt(LPAREN, comma_list(column_name), RPAREN), AS, LPAREN, select_stmt, @@ -617,7 +617,7 @@ def delete_stmt_limited(): @rule def detach_stmt(): - return DETACH + opt(DATABASE) + name + return DETACH + opt(DATABASE) + schema_name @rule @@ -626,8 +626,8 @@ def drop_stmt(): DROP, (INDEX | TABLE | TRIGGER | VIEW), opt(IF, EXISTS), - opt(name, DOT), - name, + opt(schema_name, DOT), + any_name, ) @@ -647,7 +647,7 @@ def expr(): return alt( literal_value, BIND_PARAMETER, - opt(opt(name, DOT), name, DOT) + name, + opt(opt(schema_name, DOT), table_name, DOT) + column_name, unary_operator + expr, expr + PIPE2 + expr, expr + (STAR | SLASH | PERCENT) + expr, @@ -674,7 +674,7 @@ def expr(): expr + AND + expr, expr + OR + expr, seq( - name, + function_name, LPAREN, opt((opt(DISTINCT) + comma_list(expr)) | STAR), RPAREN, @@ -683,7 +683,7 @@ def expr(): ), LPAREN + comma_list(expr) + RPAREN, CAST + LPAREN + expr + AS + type_name + RPAREN, - expr + COLLATE + name, + expr + COLLATE + collation_name, expr + opt(NOT) + (LIKE | GLOB | REGEXP | MATCH) + expr + opt(ESCAPE, expr), expr + (ISNULL | NOTNULL | seq(NOT, NULL)), expr + IS + opt(NOT) + expr, @@ -694,10 +694,10 @@ def expr(): IN, alt( LPAREN + opt(select_stmt | comma_list(expr)) + RPAREN, - opt(name, DOT) + name, + opt(schema_name, DOT) + table_name, seq( - opt(name, DOT), - name, + opt(schema_name, DOT), + table_function_name, LPAREN, opt(comma_list(expr)), RPAREN, @@ -748,10 +748,10 @@ def insert_stmt(): opt(with_clause), INSERT | REPLACE | seq(INSERT, OR, REPLACE | ROLLBACK | ABORT | FAIL | IGNORE), INTO, - opt(name, DOT), - name, - opt(AS, name), - opt(LPAREN, comma_list(name), RPAREN), + opt(schema_name, DOT), + table_name, + opt(AS, table_alias), + opt(LPAREN, comma_list(column_name), RPAREN), (((values_clause | select_stmt) + opt(upsert_clause)) | seq(DEFAULT, VALUES)), opt(returning_clause), ) @@ -774,7 +774,7 @@ def upsert_clause(): seq( UPDATE, SET, - comma_list((name | column_name_list), EQUAL, expr), + comma_list((column_name | column_name_list), EQUAL, expr), opt(WHERE, expr), ), ), @@ -785,8 +785,8 @@ def upsert_clause(): def pragma_stmt(): return seq( PRAGMA, - opt(name, DOT), - name, + opt(schema_name, DOT), + pragma_name, opt((EQUAL + pragma_value) | (LPAREN + pragma_value + RPAREN)), ) @@ -798,7 +798,7 @@ def pragma_value(): @rule def reindex_stmt(): - return REINDEX + opt(name | (opt(name, DOT) + (name | name))) + return REINDEX + opt(collation_name | (opt(schema_name, DOT) + (table_name | index_name))) @rule @@ -827,7 +827,7 @@ def select_core(): opt(FROM, comma_list(table_or_subquery) | join_clause), opt(WHERE, expr), opt(GROUP, BY, comma_list(expr), opt(HAVING, expr)), - opt(WINDOW, comma_list(name, AS, window_defn)), + opt(WINDOW, comma_list(window_name, AS, window_defn)), ), values_clause, ) @@ -858,27 +858,27 @@ def compound_select_stmt(): def table_or_subquery(): return alt( seq( - opt(name, DOT), - name, - opt(opt(AS), name), - opt(seq(INDEXED, BY, name) | (NOT + INDEXED)), + opt(schema_name, DOT), + table_name, + opt(opt(AS), table_alias), + opt(seq(INDEXED, BY, index_name) | (NOT + INDEXED)), ), seq( - opt(name, DOT), - name, + opt(schema_name, DOT), + table_function_name, LPAREN, comma_list(expr), RPAREN, - opt(AS, name), + opt(AS, table_alias), ), seq(LPAREN, comma_list(table_or_subquery) | join_clause, RPAREN), - seq(LPAREN, select_stmt, RPAREN, opt(opt(AS), name)), + seq(LPAREN, select_stmt, RPAREN, opt(opt(AS), table_alias)), ) @rule def result_column(): - return STAR | seq(name, DOT, STAR) | seq(expr, opt(opt(AS), column_alias)) + return STAR | seq(table_name, DOT, STAR) | seq(expr, opt(opt(AS), column_alias)) @rule @@ -893,7 +893,7 @@ def join_operator(): def join_constraint(): return alt( ON + expr, - USING + LPAREN + comma_list(name) + RPAREN, + USING + LPAREN + comma_list(column_name) + RPAREN, ) @@ -910,7 +910,7 @@ def update_stmt(): opt(OR, ROLLBACK | ABORT | REPLACE | FAIL | IGNORE), qualified_table_name, SET, - comma_list(name | column_name_list, EQUAL, expr), + comma_list(column_name | column_name_list, EQUAL, expr), opt(FROM, comma_list(table_or_subquery) | join_clause), opt(WHERE, expr), opt(returning_clause), @@ -919,7 +919,7 @@ def update_stmt(): @rule def column_name_list(): - return LPAREN + comma_list(name) + RPAREN + return LPAREN + comma_list(column_name) + RPAREN @rule @@ -930,7 +930,7 @@ def update_stmt_limited(): opt(OR, ROLLBACK | ABORT | REPLACE | FAIL | IGNORE), qualified_table_name, SET, - comma_list(name | column_name_list, EQUAL, expr), + comma_list(column_name | column_name_list, EQUAL, expr), opt(WHERE, expr), opt(returning_clause), opt(opt(order_by_stmt), limit_stmt), @@ -940,16 +940,16 @@ def update_stmt_limited(): @rule def qualified_table_name(): return seq( - opt(name, DOT), - name, - opt(AS, name), - opt(INDEXED + BY + name | NOT + INDEXED), + opt(schema_name, DOT), + table_name, + opt(AS, alias), + opt(INDEXED + BY + index_name | NOT + INDEXED), ) @rule def vacuum_stmt(): - return VACUUM + opt(name) + opt(INTO, name) + return VACUUM + opt(schema_name) + opt(INTO, filename) @rule @@ -961,7 +961,7 @@ def filter_clause(): def window_defn(): return seq( LPAREN, - opt(name), + opt(base_window_name), opt(PARTITION, BY, comma_list(expr)), ORDER, BY, @@ -976,10 +976,10 @@ def over_clause(): return seq( OVER, alt( - name, + window_name, seq( LPAREN, - opt(name), + opt(base_window_name), opt(PARTITION, BY, comma_list(expr)), opt(ORDER, BY, comma_list(ordering_term)), opt(frame_spec), @@ -1004,13 +1004,13 @@ def frame_clause(): @rule def simple_function_invocation(): - return seq(name, LPAREN, comma_list(expr) | STAR, RPAREN) + return seq(simple_func, LPAREN, comma_list(expr) | STAR, RPAREN) @rule def aggregate_function_invocation(): return seq( - name, + aggregate_func, LPAREN, opt(opt(DISTINCT), comma_list(expr) | STAR), RPAREN, @@ -1027,7 +1027,7 @@ def window_function_invocation(): RPAREN, opt(filter_clause), OVER, - window_defn | name, + window_defn | window_name, ) @@ -1053,7 +1053,7 @@ LAST = Terminal("LAST", "last") @rule def ordering_term(): - return seq(expr, opt(COLLATE, name), opt(asc_desc), opt(NULLS, FIRST | LAST)) + return seq(expr, opt(COLLATE, collation_name), opt(asc_desc), opt(NULLS, FIRST | LAST)) @rule @@ -1364,7 +1364,122 @@ def keyword(): @rule def name(): - return IDENTIFIER | keyword | STRING_LITERAL | seq(LPAREN, name, RPAREN) + return any_name + + +@rule +def function_name(): + return any_name + + +@rule +def schema_name(): + return any_name + + +@rule +def table_name(): + return any_name + + +@rule +def table_or_index_name(): + return any_name + + +@rule +def column_name(): + return any_name + + +@rule +def collation_name(): + return any_name + + +@rule +def foreign_table(): + return any_name + + +@rule +def index_name(): + return any_name + + +@rule +def trigger_name(): + return any_name + + +@rule +def view_name(): + return any_name + + +@rule +def module_name(): + return any_name + + +@rule +def pragma_name(): + return any_name + + +@rule +def savepoint_name(): + return any_name + + +@rule +def table_alias(): + return any_name + + +@rule +def transaction_name(): + return any_name + + +@rule +def window_name(): + return any_name + + +@rule +def alias(): + return any_name + + +@rule +def filename(): + return any_name + + +@rule +def base_window_name(): + return any_name + + +@rule +def simple_func(): + return any_name + + +@rule +def aggregate_func(): + return any_name + + +@rule +def table_function_name(): + return any_name + + +@rule +def any_name(): + return IDENTIFIER | keyword | STRING_LITERAL | seq(LPAREN, any_name, RPAREN) SQL = Grammar( @@ -1373,6 +1488,7 @@ SQL = Grammar( (Assoc.LEFT, [OR]), (Assoc.LEFT, [AND]), (Assoc.LEFT, [NOT]), + (Assoc.LEFT, []), (Assoc.LEFT, [PLUS, MINUS]), (Assoc.LEFT, [STAR, SLASH]), # TODO: Unary minus @@ -1381,184 +1497,13 @@ SQL = Grammar( name="SQL", ) - -def emit_yacc(path: str, grammar: Grammar): - lines = [] - token_names = [t.name for t in grammar.terminals()] - token_names.sort() - - trivia = {t.name for t in grammar.trivia_terminals()} - - buf = "" - for tn in token_names: - if tn in trivia: - continue - - if len(buf) > 0: - buf += " " - buf += tn - if len(buf) >= 73: - lines.append(f"%token {buf}") - buf = "" - if len(buf) > 0: - lines.append(f"%token {buf}") - lines.append("") - - prec = grammar.precedence() - if len(prec) > 0: - for assoc, rules in prec: - match assoc: - case Assoc.LEFT: - line = "%left " - case Assoc.RIGHT: - line = "%right " - case Assoc.NONE: - line = "%nonassoc" - case _: - typing.assert_never(assoc) - - rns = " ".join([rule.name for rule in rules]) - lines.append(f"{line} {rns}") - lines.append("") - - lines.append(f"%start {grammar.start.name}") - lines.append("") - - lines.append("%%") - for nt in grammar.non_terminals(): - for rule in nt.body: - prod = " ".join([s.name for s in rule]) - lines.append(f"{nt.name}: {prod};") - lines.append("") - lines.append("%%") - - with open(path, "w", encoding="utf8") as file: - file.writelines([f"{l}\n" for l in lines]) - - -def emit_lex(path: str, grammar: Grammar): - def to_js_string(s: str) -> str: - result = json.dumps(s)[1:-1] - # JSON escapes double-quotes but we don't need to in our context. - result = result.replace('\\"', '"') - return result - - def to_lex_regex(re: parser.Re) -> str: - # NOTE: In general it's bad to introduce parenthesis into regular - # expressions where they're not required because they also create - # capture groups, but I think it doesn't apply to tree-sitter - # regular expressions (and it doesn't mean anything to me either.) - if isinstance(re, parser.ReSeq): - final = [] - queue = [] - queue.append(re) - while len(queue) > 0: - part = queue.pop() - if isinstance(part, parser.ReSeq): - queue.append(part.right) - queue.append(part.left) - else: - final.append(part) - - s = "".join([to_lex_regex(p) for p in final]) - if len(final) > 1: - s = f"({s})" - return s - - elif isinstance(re, parser.ReAlt): - final = [] - queue = [] - queue.append(re) - while len(queue) > 0: - part = queue.pop() - if isinstance(part, parser.ReAlt): - queue.append(part.right) - queue.append(part.left) - else: - final.append(part) - - s = "|".join([to_lex_regex(p) for p in final]) - if len(final) > 1: - s = f"({s})" - return s - - elif isinstance(re, parser.ReQuestion): - s = to_lex_regex(re.child) - return f"({s})?" - - elif isinstance(re, parser.RePlus): - s = to_lex_regex(re.child) - return f"({s})+" - - elif isinstance(re, parser.ReStar): - s = to_lex_regex(re.child) - return f"({s})*" - - elif isinstance(re, parser.ReSet): - if ( - len(re.values) == 1 - and re.values[0].lower == 0 - and re.values[0].upper == parser.UNICODE_MAX_CP - ): - return "." - - inverted = re.inversion - if inverted: - re = re.invert() - - parts = [] - for value in re.values: - if len(value) == 1: - parts.append(to_js_string(chr(value.lower))) - else: - parts.append( - "{}-{}".format( - to_js_string(chr(value.lower)), - to_js_string(chr(value.upper - 1)), - ) - ) - - s = "".join(parts) - if inverted: - s = "^" + s - if len(s) > 1: - # The only time this isn't a "set" is if this is a set of one - # range that is one character long, in which case it's better - # represented as a literal. - s = f"[{s}]" - # else: - # s = s.replace("'", "\\'") - # s = f"'{s}'" - return s - - raise Exception(f"Regex node {re} not supported for tree-sitter") - - lines = ["%%"] - trivia = {t.name for t in grammar.trivia_terminals()} - - for terminal in grammar.terminals(): - if isinstance(terminal.pattern, str): - pattern = terminal.pattern - else: - pattern = to_lex_regex(terminal.pattern) - - name = ";" if terminal.name in trivia else f'"{terminal.name}"' - lines.append(f"{pattern} {name}") - - with open(path, "w", encoding="utf8") as file: - file.writelines([f"{l}\n" for l in lines]) - - if __name__ == "__main__": - # import cProfile + import cProfile - # print("Starting...") - # with cProfile.Profile() as pr: - # try: - # SQL.build_table() - # finally: - # pr.dump_stats("sql.pprof") - # print("Wrote output to sql.pprof") - - emit_yacc("sql.y", SQL) - emit_lex("sql.l", SQL) + print("Starting...") + with cProfile.Profile() as pr: + try: + SQL.build_table() + finally: + pr.dump_stats("sql.pprof") + print("Wrote output to sql.pprof") diff --git a/parser/parser.py b/parser/parser.py index 7b5c18e..cc71f82 100644 --- a/parser/parser.py +++ b/parser/parser.py @@ -2978,7 +2978,6 @@ class Grammar: _nonterminals: dict[str, NonTerminal] _trivia: list[Terminal] _precedence: dict[str, typing.Tuple[Assoc, int]] - _preclist: PrecedenceList def __init__( self, @@ -2995,7 +2994,6 @@ class Grammar: if precedence is None: precedence = [] assert precedence is not None - self._preclist = precedence if trivia is None: trivia = [] @@ -3029,9 +3027,6 @@ class Grammar: def get_precedence(self, name: str) -> None | tuple[Assoc, int]: return self._precedence.get(name) - def precedence(self) -> PrecedenceList: - return self._preclist - def desugar(self) -> typing.Tuple[list[typing.Tuple[str, list[str]]], set[str]]: """Convert the rules into a flat list of productions.