diff --git a/examples/sql.py b/examples/sql.py index a5afb9f..a28adcc 100644 --- a/examples/sql.py +++ b/examples/sql.py @@ -315,29 +315,29 @@ def alter_table_stmt(): return ( ALTER + TABLE - + opt(schema_name + DOT) - + table_name + + opt(name + DOT) + + name + alt( - RENAME + alt((TO + table_name), (COLUMN + column_name + TO + column_name)), + RENAME + alt((TO + name), (COLUMN + name + TO + name)), (ADD + opt(COLUMN) + column_def), - (DROP + opt(COLUMN) + column_name), + (DROP + opt(COLUMN) + name), ) ) @rule def analyze_stmt(): - return ANALYZE + opt(alt(schema_name, opt(schema_name + DOT) + table_or_index_name)) + return ANALYZE + opt(opt(name + DOT) + name) @rule def attach_stmt(): - return ATTACH + opt(DATABASE) + expr + AS + schema_name + return ATTACH + opt(DATABASE) + expr + AS + name @rule def begin_stmt(): - return BEGIN + opt(DEFERRED | IMMEDIATE | EXCLUSIVE) + opt(TRANSACTION + opt(transaction_name)) + return BEGIN + opt(DEFERRED | IMMEDIATE | EXCLUSIVE) + opt(TRANSACTION + opt(name)) @rule @@ -347,17 +347,17 @@ def commit_stmt(): @rule def rollback_stmt(): - return ROLLBACK + opt(TRANSACTION) + opt(TO + opt(SAVEPOINT) + savepoint_name) + return ROLLBACK + opt(TRANSACTION) + opt(TO + opt(SAVEPOINT) + name) @rule def savepoint_stmt(): - return SAVEPOINT + savepoint_name + return SAVEPOINT + name @rule def release_stmt(): - return RELEASE + opt(SAVEPOINT) + savepoint_name + return RELEASE + opt(SAVEPOINT) + name def comma_list(*rules: Rule) -> Rule: @@ -373,10 +373,10 @@ def create_index_stmt(): opt(UNIQUE), INDEX, opt(IF + NOT + EXISTS), - opt(schema_name + DOT), - index_name, + opt(name + DOT), + name, ON, - table_name, + name, LPAREN, comma_list(indexed_column), RPAREN, @@ -386,7 +386,7 @@ def create_index_stmt(): @rule def indexed_column(): - return (column_name | expr) + opt(COLLATE + collation_name) + opt(asc_desc) + return (name | expr) + opt(COLLATE + name) + opt(asc_desc) @rule @@ -396,8 +396,8 @@ def create_table_stmt(): opt(TEMP | TEMPORARY), TABLE, opt(IF, NOT, EXISTS), - opt(schema_name, DOT), - table_name, + opt(name, DOT), + name, alt( seq( LPAREN, @@ -413,7 +413,7 @@ def create_table_stmt(): @rule def column_def(): - return column_name + opt(type_name) + zero_or_more(column_constraint) + return name + opt(type_name) + zero_or_more(column_constraint) @rule @@ -429,7 +429,7 @@ def column_constraint(): seq(PRIMARY, KEY, opt(asc_desc), opt(conflict_clause), opt(AUTOINCREMENT)), seq(opt(NOT), (NULL | UNIQUE), opt(conflict_clause)), seq(DEFAULT, signed_number | literal_value | seq(LPAREN, expr, RPAREN)), - seq(COLLATE, collation_name), + seq(COLLATE, name), foreign_key_clause, seq(opt(GENERATED, ALWAYS), AS, LPAREN, expr, RPAREN, opt(STORED | VIRTUAL)), ), @@ -458,7 +458,7 @@ def table_constraint(): FOREIGN, KEY, LPAREN, - comma_list(column_name), + comma_list(name), RPAREN, foreign_key_clause, ), @@ -470,8 +470,8 @@ def table_constraint(): def foreign_key_clause(): return seq( REFERENCES, - foreign_table, - opt(LPAREN, comma_list(column_name), RPAREN), + name, + opt(LPAREN, comma_list(name), RPAREN), zero_or_more( alt( seq( @@ -503,12 +503,12 @@ def create_trigger_stmt(): opt(TEMP | TEMPORARY), TRIGGER, opt(IF, NOT, EXISTS), - opt(schema_name, DOT), - trigger_name, + opt(name, DOT), + name, opt(BEFORE | AFTER | (INSTEAD + OF)), - (DELETE | INSERT | (UPDATE + opt(OF, comma_list(column_name)))), + (DELETE | INSERT | (UPDATE + opt(OF, comma_list(name)))), ON, - table_name, + name, opt(FOR, EACH, ROW), opt(WHEN, expr), BEGIN, @@ -524,9 +524,9 @@ def create_view_stmt(): opt(TEMP | TEMPORARY), VIEW, opt(IF, NOT, EXISTS), - opt(schema_name, DOT), - view_name, - opt(LPAREN, comma_list(column_name), RPAREN), + opt(name, DOT), + name, + opt(LPAREN, comma_list(name), RPAREN), AS, select_stmt, ) @@ -539,10 +539,10 @@ def create_virtual_table_stmt(): VIRTUAL, TABLE, opt(IF, NOT, EXISTS), - opt(schema_name, DOT), - table_name, + opt(name, DOT), + name, USING, - module_name, + name, opt(LPAREN, comma_list(module_argument), RPAREN), ) @@ -558,7 +558,7 @@ def with_clause(): @rule def cte_table_name(): - return table_name + opt(LPAREN, comma_list(column_name), RPAREN) + return name + opt(LPAREN, comma_list(name), RPAREN) @rule @@ -578,8 +578,8 @@ def recursive_cte(): @rule def common_table_expression(): return seq( - table_name, - opt(LPAREN, comma_list(column_name), RPAREN), + name, + opt(LPAREN, comma_list(name), RPAREN), AS, LPAREN, select_stmt, @@ -617,7 +617,7 @@ def delete_stmt_limited(): @rule def detach_stmt(): - return DETACH + opt(DATABASE) + schema_name + return DETACH + opt(DATABASE) + name @rule @@ -626,8 +626,8 @@ def drop_stmt(): DROP, (INDEX | TABLE | TRIGGER | VIEW), opt(IF, EXISTS), - opt(schema_name, DOT), - any_name, + opt(name, DOT), + name, ) @@ -647,7 +647,7 @@ def expr(): return alt( literal_value, BIND_PARAMETER, - opt(opt(schema_name, DOT), table_name, DOT) + column_name, + opt(opt(name, DOT), name, DOT) + name, unary_operator + expr, expr + PIPE2 + expr, expr + (STAR | SLASH | PERCENT) + expr, @@ -674,7 +674,7 @@ def expr(): expr + AND + expr, expr + OR + expr, seq( - function_name, + name, LPAREN, opt((opt(DISTINCT) + comma_list(expr)) | STAR), RPAREN, @@ -683,7 +683,7 @@ def expr(): ), LPAREN + comma_list(expr) + RPAREN, CAST + LPAREN + expr + AS + type_name + RPAREN, - expr + COLLATE + collation_name, + expr + COLLATE + name, expr + opt(NOT) + (LIKE | GLOB | REGEXP | MATCH) + expr + opt(ESCAPE, expr), expr + (ISNULL | NOTNULL | seq(NOT, NULL)), expr + IS + opt(NOT) + expr, @@ -694,10 +694,10 @@ def expr(): IN, alt( LPAREN + opt(select_stmt | comma_list(expr)) + RPAREN, - opt(schema_name, DOT) + table_name, + opt(name, DOT) + name, seq( - opt(schema_name, DOT), - table_function_name, + opt(name, DOT), + name, LPAREN, opt(comma_list(expr)), RPAREN, @@ -748,10 +748,10 @@ def insert_stmt(): opt(with_clause), INSERT | REPLACE | seq(INSERT, OR, REPLACE | ROLLBACK | ABORT | FAIL | IGNORE), INTO, - opt(schema_name, DOT), - table_name, - opt(AS, table_alias), - opt(LPAREN, comma_list(column_name), RPAREN), + opt(name, DOT), + name, + opt(AS, name), + opt(LPAREN, comma_list(name), RPAREN), (((values_clause | select_stmt) + opt(upsert_clause)) | seq(DEFAULT, VALUES)), opt(returning_clause), ) @@ -774,7 +774,7 @@ def upsert_clause(): seq( UPDATE, SET, - comma_list((column_name | column_name_list), EQUAL, expr), + comma_list((name | column_name_list), EQUAL, expr), opt(WHERE, expr), ), ), @@ -785,8 +785,8 @@ def upsert_clause(): def pragma_stmt(): return seq( PRAGMA, - opt(schema_name, DOT), - pragma_name, + opt(name, DOT), + name, opt((EQUAL + pragma_value) | (LPAREN + pragma_value + RPAREN)), ) @@ -798,7 +798,7 @@ def pragma_value(): @rule def reindex_stmt(): - return REINDEX + opt(collation_name | (opt(schema_name, DOT) + (table_name | index_name))) + return REINDEX + opt(name | (opt(name, DOT) + (name | name))) @rule @@ -827,7 +827,7 @@ def select_core(): opt(FROM, comma_list(table_or_subquery) | join_clause), opt(WHERE, expr), opt(GROUP, BY, comma_list(expr), opt(HAVING, expr)), - opt(WINDOW, comma_list(window_name, AS, window_defn)), + opt(WINDOW, comma_list(name, AS, window_defn)), ), values_clause, ) @@ -858,27 +858,27 @@ def compound_select_stmt(): def table_or_subquery(): return alt( seq( - opt(schema_name, DOT), - table_name, - opt(opt(AS), table_alias), - opt(seq(INDEXED, BY, index_name) | (NOT + INDEXED)), + opt(name, DOT), + name, + opt(opt(AS), name), + opt(seq(INDEXED, BY, name) | (NOT + INDEXED)), ), seq( - opt(schema_name, DOT), - table_function_name, + opt(name, DOT), + name, LPAREN, comma_list(expr), RPAREN, - opt(AS, table_alias), + opt(AS, name), ), seq(LPAREN, comma_list(table_or_subquery) | join_clause, RPAREN), - seq(LPAREN, select_stmt, RPAREN, opt(opt(AS), table_alias)), + seq(LPAREN, select_stmt, RPAREN, opt(opt(AS), name)), ) @rule def result_column(): - return STAR | seq(table_name, DOT, STAR) | seq(expr, opt(opt(AS), column_alias)) + return STAR | seq(name, DOT, STAR) | seq(expr, opt(opt(AS), column_alias)) @rule @@ -893,7 +893,7 @@ def join_operator(): def join_constraint(): return alt( ON + expr, - USING + LPAREN + comma_list(column_name) + RPAREN, + USING + LPAREN + comma_list(name) + RPAREN, ) @@ -910,7 +910,7 @@ def update_stmt(): opt(OR, ROLLBACK | ABORT | REPLACE | FAIL | IGNORE), qualified_table_name, SET, - comma_list(column_name | column_name_list, EQUAL, expr), + comma_list(name | column_name_list, EQUAL, expr), opt(FROM, comma_list(table_or_subquery) | join_clause), opt(WHERE, expr), opt(returning_clause), @@ -919,7 +919,7 @@ def update_stmt(): @rule def column_name_list(): - return LPAREN + comma_list(column_name) + RPAREN + return LPAREN + comma_list(name) + RPAREN @rule @@ -930,7 +930,7 @@ def update_stmt_limited(): opt(OR, ROLLBACK | ABORT | REPLACE | FAIL | IGNORE), qualified_table_name, SET, - comma_list(column_name | column_name_list, EQUAL, expr), + comma_list(name | column_name_list, EQUAL, expr), opt(WHERE, expr), opt(returning_clause), opt(opt(order_by_stmt), limit_stmt), @@ -940,16 +940,16 @@ def update_stmt_limited(): @rule def qualified_table_name(): return seq( - opt(schema_name, DOT), - table_name, - opt(AS, alias), - opt(INDEXED + BY + index_name | NOT + INDEXED), + opt(name, DOT), + name, + opt(AS, name), + opt(INDEXED + BY + name | NOT + INDEXED), ) @rule def vacuum_stmt(): - return VACUUM + opt(schema_name) + opt(INTO, filename) + return VACUUM + opt(name) + opt(INTO, name) @rule @@ -961,7 +961,7 @@ def filter_clause(): def window_defn(): return seq( LPAREN, - opt(base_window_name), + opt(name), opt(PARTITION, BY, comma_list(expr)), ORDER, BY, @@ -976,10 +976,10 @@ def over_clause(): return seq( OVER, alt( - window_name, + name, seq( LPAREN, - opt(base_window_name), + opt(name), opt(PARTITION, BY, comma_list(expr)), opt(ORDER, BY, comma_list(ordering_term)), opt(frame_spec), @@ -1004,13 +1004,13 @@ def frame_clause(): @rule def simple_function_invocation(): - return seq(simple_func, LPAREN, comma_list(expr) | STAR, RPAREN) + return seq(name, LPAREN, comma_list(expr) | STAR, RPAREN) @rule def aggregate_function_invocation(): return seq( - aggregate_func, + name, LPAREN, opt(opt(DISTINCT), comma_list(expr) | STAR), RPAREN, @@ -1027,7 +1027,7 @@ def window_function_invocation(): RPAREN, opt(filter_clause), OVER, - window_defn | window_name, + window_defn | name, ) @@ -1053,7 +1053,7 @@ LAST = Terminal("LAST", "last") @rule def ordering_term(): - return seq(expr, opt(COLLATE, collation_name), opt(asc_desc), opt(NULLS, FIRST | LAST)) + return seq(expr, opt(COLLATE, name), opt(asc_desc), opt(NULLS, FIRST | LAST)) @rule @@ -1364,122 +1364,7 @@ def keyword(): @rule def name(): - return any_name - - -@rule -def function_name(): - return any_name - - -@rule -def schema_name(): - return any_name - - -@rule -def table_name(): - return any_name - - -@rule -def table_or_index_name(): - return any_name - - -@rule -def column_name(): - return any_name - - -@rule -def collation_name(): - return any_name - - -@rule -def foreign_table(): - return any_name - - -@rule -def index_name(): - return any_name - - -@rule -def trigger_name(): - return any_name - - -@rule -def view_name(): - return any_name - - -@rule -def module_name(): - return any_name - - -@rule -def pragma_name(): - return any_name - - -@rule -def savepoint_name(): - return any_name - - -@rule -def table_alias(): - return any_name - - -@rule -def transaction_name(): - return any_name - - -@rule -def window_name(): - return any_name - - -@rule -def alias(): - return any_name - - -@rule -def filename(): - return any_name - - -@rule -def base_window_name(): - return any_name - - -@rule -def simple_func(): - return any_name - - -@rule -def aggregate_func(): - return any_name - - -@rule -def table_function_name(): - return any_name - - -@rule -def any_name(): - return IDENTIFIER | keyword | STRING_LITERAL | seq(LPAREN, any_name, RPAREN) + return IDENTIFIER | keyword | STRING_LITERAL | seq(LPAREN, name, RPAREN) SQL = Grammar( @@ -1488,7 +1373,6 @@ SQL = Grammar( (Assoc.LEFT, [OR]), (Assoc.LEFT, [AND]), (Assoc.LEFT, [NOT]), - (Assoc.LEFT, []), (Assoc.LEFT, [PLUS, MINUS]), (Assoc.LEFT, [STAR, SLASH]), # TODO: Unary minus @@ -1497,13 +1381,184 @@ SQL = Grammar( name="SQL", ) -if __name__ == "__main__": - import cProfile - print("Starting...") - with cProfile.Profile() as pr: - try: - SQL.build_table() - finally: - pr.dump_stats("sql.pprof") - print("Wrote output to sql.pprof") +def emit_yacc(path: str, grammar: Grammar): + lines = [] + token_names = [t.name for t in grammar.terminals()] + token_names.sort() + + trivia = {t.name for t in grammar.trivia_terminals()} + + buf = "" + for tn in token_names: + if tn in trivia: + continue + + if len(buf) > 0: + buf += " " + buf += tn + if len(buf) >= 73: + lines.append(f"%token {buf}") + buf = "" + if len(buf) > 0: + lines.append(f"%token {buf}") + lines.append("") + + prec = grammar.precedence() + if len(prec) > 0: + for assoc, rules in prec: + match assoc: + case Assoc.LEFT: + line = "%left " + case Assoc.RIGHT: + line = "%right " + case Assoc.NONE: + line = "%nonassoc" + case _: + typing.assert_never(assoc) + + rns = " ".join([rule.name for rule in rules]) + lines.append(f"{line} {rns}") + lines.append("") + + lines.append(f"%start {grammar.start.name}") + lines.append("") + + lines.append("%%") + for nt in grammar.non_terminals(): + for rule in nt.body: + prod = " ".join([s.name for s in rule]) + lines.append(f"{nt.name}: {prod};") + lines.append("") + lines.append("%%") + + with open(path, "w", encoding="utf8") as file: + file.writelines([f"{l}\n" for l in lines]) + + +def emit_lex(path: str, grammar: Grammar): + def to_js_string(s: str) -> str: + result = json.dumps(s)[1:-1] + # JSON escapes double-quotes but we don't need to in our context. + result = result.replace('\\"', '"') + return result + + def to_lex_regex(re: parser.Re) -> str: + # NOTE: In general it's bad to introduce parenthesis into regular + # expressions where they're not required because they also create + # capture groups, but I think it doesn't apply to tree-sitter + # regular expressions (and it doesn't mean anything to me either.) + if isinstance(re, parser.ReSeq): + final = [] + queue = [] + queue.append(re) + while len(queue) > 0: + part = queue.pop() + if isinstance(part, parser.ReSeq): + queue.append(part.right) + queue.append(part.left) + else: + final.append(part) + + s = "".join([to_lex_regex(p) for p in final]) + if len(final) > 1: + s = f"({s})" + return s + + elif isinstance(re, parser.ReAlt): + final = [] + queue = [] + queue.append(re) + while len(queue) > 0: + part = queue.pop() + if isinstance(part, parser.ReAlt): + queue.append(part.right) + queue.append(part.left) + else: + final.append(part) + + s = "|".join([to_lex_regex(p) for p in final]) + if len(final) > 1: + s = f"({s})" + return s + + elif isinstance(re, parser.ReQuestion): + s = to_lex_regex(re.child) + return f"({s})?" + + elif isinstance(re, parser.RePlus): + s = to_lex_regex(re.child) + return f"({s})+" + + elif isinstance(re, parser.ReStar): + s = to_lex_regex(re.child) + return f"({s})*" + + elif isinstance(re, parser.ReSet): + if ( + len(re.values) == 1 + and re.values[0].lower == 0 + and re.values[0].upper == parser.UNICODE_MAX_CP + ): + return "." + + inverted = re.inversion + if inverted: + re = re.invert() + + parts = [] + for value in re.values: + if len(value) == 1: + parts.append(to_js_string(chr(value.lower))) + else: + parts.append( + "{}-{}".format( + to_js_string(chr(value.lower)), + to_js_string(chr(value.upper - 1)), + ) + ) + + s = "".join(parts) + if inverted: + s = "^" + s + if len(s) > 1: + # The only time this isn't a "set" is if this is a set of one + # range that is one character long, in which case it's better + # represented as a literal. + s = f"[{s}]" + # else: + # s = s.replace("'", "\\'") + # s = f"'{s}'" + return s + + raise Exception(f"Regex node {re} not supported for tree-sitter") + + lines = ["%%"] + trivia = {t.name for t in grammar.trivia_terminals()} + + for terminal in grammar.terminals(): + if isinstance(terminal.pattern, str): + pattern = terminal.pattern + else: + pattern = to_lex_regex(terminal.pattern) + + name = ";" if terminal.name in trivia else f'"{terminal.name}"' + lines.append(f"{pattern} {name}") + + with open(path, "w", encoding="utf8") as file: + file.writelines([f"{l}\n" for l in lines]) + + +if __name__ == "__main__": + # import cProfile + + # print("Starting...") + # with cProfile.Profile() as pr: + # try: + # SQL.build_table() + # finally: + # pr.dump_stats("sql.pprof") + # print("Wrote output to sql.pprof") + + emit_yacc("sql.y", SQL) + emit_lex("sql.l", SQL)