[sql] hacking, emit lex and yacc for comparison

This commit is contained in:
John Doty 2025-01-04 08:57:58 -08:00
parent fa7514dc21
commit 8135899abf

View file

@ -315,29 +315,29 @@ def alter_table_stmt():
return (
ALTER
+ TABLE
+ opt(schema_name + DOT)
+ table_name
+ opt(name + DOT)
+ name
+ alt(
RENAME + alt((TO + table_name), (COLUMN + column_name + TO + column_name)),
RENAME + alt((TO + name), (COLUMN + name + TO + name)),
(ADD + opt(COLUMN) + column_def),
(DROP + opt(COLUMN) + column_name),
(DROP + opt(COLUMN) + name),
)
)
@rule
def analyze_stmt():
return ANALYZE + opt(alt(schema_name, opt(schema_name + DOT) + table_or_index_name))
return ANALYZE + opt(opt(name + DOT) + name)
@rule
def attach_stmt():
return ATTACH + opt(DATABASE) + expr + AS + schema_name
return ATTACH + opt(DATABASE) + expr + AS + name
@rule
def begin_stmt():
return BEGIN + opt(DEFERRED | IMMEDIATE | EXCLUSIVE) + opt(TRANSACTION + opt(transaction_name))
return BEGIN + opt(DEFERRED | IMMEDIATE | EXCLUSIVE) + opt(TRANSACTION + opt(name))
@rule
@ -347,17 +347,17 @@ def commit_stmt():
@rule
def rollback_stmt():
return ROLLBACK + opt(TRANSACTION) + opt(TO + opt(SAVEPOINT) + savepoint_name)
return ROLLBACK + opt(TRANSACTION) + opt(TO + opt(SAVEPOINT) + name)
@rule
def savepoint_stmt():
return SAVEPOINT + savepoint_name
return SAVEPOINT + name
@rule
def release_stmt():
return RELEASE + opt(SAVEPOINT) + savepoint_name
return RELEASE + opt(SAVEPOINT) + name
def comma_list(*rules: Rule) -> Rule:
@ -373,10 +373,10 @@ def create_index_stmt():
opt(UNIQUE),
INDEX,
opt(IF + NOT + EXISTS),
opt(schema_name + DOT),
index_name,
opt(name + DOT),
name,
ON,
table_name,
name,
LPAREN,
comma_list(indexed_column),
RPAREN,
@ -386,7 +386,7 @@ def create_index_stmt():
@rule
def indexed_column():
return (column_name | expr) + opt(COLLATE + collation_name) + opt(asc_desc)
return (name | expr) + opt(COLLATE + name) + opt(asc_desc)
@rule
@ -396,8 +396,8 @@ def create_table_stmt():
opt(TEMP | TEMPORARY),
TABLE,
opt(IF, NOT, EXISTS),
opt(schema_name, DOT),
table_name,
opt(name, DOT),
name,
alt(
seq(
LPAREN,
@ -413,7 +413,7 @@ def create_table_stmt():
@rule
def column_def():
return column_name + opt(type_name) + zero_or_more(column_constraint)
return name + opt(type_name) + zero_or_more(column_constraint)
@rule
@ -429,7 +429,7 @@ def column_constraint():
seq(PRIMARY, KEY, opt(asc_desc), opt(conflict_clause), opt(AUTOINCREMENT)),
seq(opt(NOT), (NULL | UNIQUE), opt(conflict_clause)),
seq(DEFAULT, signed_number | literal_value | seq(LPAREN, expr, RPAREN)),
seq(COLLATE, collation_name),
seq(COLLATE, name),
foreign_key_clause,
seq(opt(GENERATED, ALWAYS), AS, LPAREN, expr, RPAREN, opt(STORED | VIRTUAL)),
),
@ -458,7 +458,7 @@ def table_constraint():
FOREIGN,
KEY,
LPAREN,
comma_list(column_name),
comma_list(name),
RPAREN,
foreign_key_clause,
),
@ -470,8 +470,8 @@ def table_constraint():
def foreign_key_clause():
return seq(
REFERENCES,
foreign_table,
opt(LPAREN, comma_list(column_name), RPAREN),
name,
opt(LPAREN, comma_list(name), RPAREN),
zero_or_more(
alt(
seq(
@ -503,12 +503,12 @@ def create_trigger_stmt():
opt(TEMP | TEMPORARY),
TRIGGER,
opt(IF, NOT, EXISTS),
opt(schema_name, DOT),
trigger_name,
opt(name, DOT),
name,
opt(BEFORE | AFTER | (INSTEAD + OF)),
(DELETE | INSERT | (UPDATE + opt(OF, comma_list(column_name)))),
(DELETE | INSERT | (UPDATE + opt(OF, comma_list(name)))),
ON,
table_name,
name,
opt(FOR, EACH, ROW),
opt(WHEN, expr),
BEGIN,
@ -524,9 +524,9 @@ def create_view_stmt():
opt(TEMP | TEMPORARY),
VIEW,
opt(IF, NOT, EXISTS),
opt(schema_name, DOT),
view_name,
opt(LPAREN, comma_list(column_name), RPAREN),
opt(name, DOT),
name,
opt(LPAREN, comma_list(name), RPAREN),
AS,
select_stmt,
)
@ -539,10 +539,10 @@ def create_virtual_table_stmt():
VIRTUAL,
TABLE,
opt(IF, NOT, EXISTS),
opt(schema_name, DOT),
table_name,
opt(name, DOT),
name,
USING,
module_name,
name,
opt(LPAREN, comma_list(module_argument), RPAREN),
)
@ -558,7 +558,7 @@ def with_clause():
@rule
def cte_table_name():
return table_name + opt(LPAREN, comma_list(column_name), RPAREN)
return name + opt(LPAREN, comma_list(name), RPAREN)
@rule
@ -578,8 +578,8 @@ def recursive_cte():
@rule
def common_table_expression():
return seq(
table_name,
opt(LPAREN, comma_list(column_name), RPAREN),
name,
opt(LPAREN, comma_list(name), RPAREN),
AS,
LPAREN,
select_stmt,
@ -617,7 +617,7 @@ def delete_stmt_limited():
@rule
def detach_stmt():
return DETACH + opt(DATABASE) + schema_name
return DETACH + opt(DATABASE) + name
@rule
@ -626,8 +626,8 @@ def drop_stmt():
DROP,
(INDEX | TABLE | TRIGGER | VIEW),
opt(IF, EXISTS),
opt(schema_name, DOT),
any_name,
opt(name, DOT),
name,
)
@ -647,7 +647,7 @@ def expr():
return alt(
literal_value,
BIND_PARAMETER,
opt(opt(schema_name, DOT), table_name, DOT) + column_name,
opt(opt(name, DOT), name, DOT) + name,
unary_operator + expr,
expr + PIPE2 + expr,
expr + (STAR | SLASH | PERCENT) + expr,
@ -674,7 +674,7 @@ def expr():
expr + AND + expr,
expr + OR + expr,
seq(
function_name,
name,
LPAREN,
opt((opt(DISTINCT) + comma_list(expr)) | STAR),
RPAREN,
@ -683,7 +683,7 @@ def expr():
),
LPAREN + comma_list(expr) + RPAREN,
CAST + LPAREN + expr + AS + type_name + RPAREN,
expr + COLLATE + collation_name,
expr + COLLATE + name,
expr + opt(NOT) + (LIKE | GLOB | REGEXP | MATCH) + expr + opt(ESCAPE, expr),
expr + (ISNULL | NOTNULL | seq(NOT, NULL)),
expr + IS + opt(NOT) + expr,
@ -694,10 +694,10 @@ def expr():
IN,
alt(
LPAREN + opt(select_stmt | comma_list(expr)) + RPAREN,
opt(schema_name, DOT) + table_name,
opt(name, DOT) + name,
seq(
opt(schema_name, DOT),
table_function_name,
opt(name, DOT),
name,
LPAREN,
opt(comma_list(expr)),
RPAREN,
@ -748,10 +748,10 @@ def insert_stmt():
opt(with_clause),
INSERT | REPLACE | seq(INSERT, OR, REPLACE | ROLLBACK | ABORT | FAIL | IGNORE),
INTO,
opt(schema_name, DOT),
table_name,
opt(AS, table_alias),
opt(LPAREN, comma_list(column_name), RPAREN),
opt(name, DOT),
name,
opt(AS, name),
opt(LPAREN, comma_list(name), RPAREN),
(((values_clause | select_stmt) + opt(upsert_clause)) | seq(DEFAULT, VALUES)),
opt(returning_clause),
)
@ -774,7 +774,7 @@ def upsert_clause():
seq(
UPDATE,
SET,
comma_list((column_name | column_name_list), EQUAL, expr),
comma_list((name | column_name_list), EQUAL, expr),
opt(WHERE, expr),
),
),
@ -785,8 +785,8 @@ def upsert_clause():
def pragma_stmt():
return seq(
PRAGMA,
opt(schema_name, DOT),
pragma_name,
opt(name, DOT),
name,
opt((EQUAL + pragma_value) | (LPAREN + pragma_value + RPAREN)),
)
@ -798,7 +798,7 @@ def pragma_value():
@rule
def reindex_stmt():
return REINDEX + opt(collation_name | (opt(schema_name, DOT) + (table_name | index_name)))
return REINDEX + opt(name | (opt(name, DOT) + (name | name)))
@rule
@ -827,7 +827,7 @@ def select_core():
opt(FROM, comma_list(table_or_subquery) | join_clause),
opt(WHERE, expr),
opt(GROUP, BY, comma_list(expr), opt(HAVING, expr)),
opt(WINDOW, comma_list(window_name, AS, window_defn)),
opt(WINDOW, comma_list(name, AS, window_defn)),
),
values_clause,
)
@ -858,27 +858,27 @@ def compound_select_stmt():
def table_or_subquery():
return alt(
seq(
opt(schema_name, DOT),
table_name,
opt(opt(AS), table_alias),
opt(seq(INDEXED, BY, index_name) | (NOT + INDEXED)),
opt(name, DOT),
name,
opt(opt(AS), name),
opt(seq(INDEXED, BY, name) | (NOT + INDEXED)),
),
seq(
opt(schema_name, DOT),
table_function_name,
opt(name, DOT),
name,
LPAREN,
comma_list(expr),
RPAREN,
opt(AS, table_alias),
opt(AS, name),
),
seq(LPAREN, comma_list(table_or_subquery) | join_clause, RPAREN),
seq(LPAREN, select_stmt, RPAREN, opt(opt(AS), table_alias)),
seq(LPAREN, select_stmt, RPAREN, opt(opt(AS), name)),
)
@rule
def result_column():
return STAR | seq(table_name, DOT, STAR) | seq(expr, opt(opt(AS), column_alias))
return STAR | seq(name, DOT, STAR) | seq(expr, opt(opt(AS), column_alias))
@rule
@ -893,7 +893,7 @@ def join_operator():
def join_constraint():
return alt(
ON + expr,
USING + LPAREN + comma_list(column_name) + RPAREN,
USING + LPAREN + comma_list(name) + RPAREN,
)
@ -910,7 +910,7 @@ def update_stmt():
opt(OR, ROLLBACK | ABORT | REPLACE | FAIL | IGNORE),
qualified_table_name,
SET,
comma_list(column_name | column_name_list, EQUAL, expr),
comma_list(name | column_name_list, EQUAL, expr),
opt(FROM, comma_list(table_or_subquery) | join_clause),
opt(WHERE, expr),
opt(returning_clause),
@ -919,7 +919,7 @@ def update_stmt():
@rule
def column_name_list():
return LPAREN + comma_list(column_name) + RPAREN
return LPAREN + comma_list(name) + RPAREN
@rule
@ -930,7 +930,7 @@ def update_stmt_limited():
opt(OR, ROLLBACK | ABORT | REPLACE | FAIL | IGNORE),
qualified_table_name,
SET,
comma_list(column_name | column_name_list, EQUAL, expr),
comma_list(name | column_name_list, EQUAL, expr),
opt(WHERE, expr),
opt(returning_clause),
opt(opt(order_by_stmt), limit_stmt),
@ -940,16 +940,16 @@ def update_stmt_limited():
@rule
def qualified_table_name():
return seq(
opt(schema_name, DOT),
table_name,
opt(AS, alias),
opt(INDEXED + BY + index_name | NOT + INDEXED),
opt(name, DOT),
name,
opt(AS, name),
opt(INDEXED + BY + name | NOT + INDEXED),
)
@rule
def vacuum_stmt():
return VACUUM + opt(schema_name) + opt(INTO, filename)
return VACUUM + opt(name) + opt(INTO, name)
@rule
@ -961,7 +961,7 @@ def filter_clause():
def window_defn():
return seq(
LPAREN,
opt(base_window_name),
opt(name),
opt(PARTITION, BY, comma_list(expr)),
ORDER,
BY,
@ -976,10 +976,10 @@ def over_clause():
return seq(
OVER,
alt(
window_name,
name,
seq(
LPAREN,
opt(base_window_name),
opt(name),
opt(PARTITION, BY, comma_list(expr)),
opt(ORDER, BY, comma_list(ordering_term)),
opt(frame_spec),
@ -1004,13 +1004,13 @@ def frame_clause():
@rule
def simple_function_invocation():
return seq(simple_func, LPAREN, comma_list(expr) | STAR, RPAREN)
return seq(name, LPAREN, comma_list(expr) | STAR, RPAREN)
@rule
def aggregate_function_invocation():
return seq(
aggregate_func,
name,
LPAREN,
opt(opt(DISTINCT), comma_list(expr) | STAR),
RPAREN,
@ -1027,7 +1027,7 @@ def window_function_invocation():
RPAREN,
opt(filter_clause),
OVER,
window_defn | window_name,
window_defn | name,
)
@ -1053,7 +1053,7 @@ LAST = Terminal("LAST", "last")
@rule
def ordering_term():
return seq(expr, opt(COLLATE, collation_name), opt(asc_desc), opt(NULLS, FIRST | LAST))
return seq(expr, opt(COLLATE, name), opt(asc_desc), opt(NULLS, FIRST | LAST))
@rule
@ -1364,122 +1364,7 @@ def keyword():
@rule
def name():
return any_name
@rule
def function_name():
return any_name
@rule
def schema_name():
return any_name
@rule
def table_name():
return any_name
@rule
def table_or_index_name():
return any_name
@rule
def column_name():
return any_name
@rule
def collation_name():
return any_name
@rule
def foreign_table():
return any_name
@rule
def index_name():
return any_name
@rule
def trigger_name():
return any_name
@rule
def view_name():
return any_name
@rule
def module_name():
return any_name
@rule
def pragma_name():
return any_name
@rule
def savepoint_name():
return any_name
@rule
def table_alias():
return any_name
@rule
def transaction_name():
return any_name
@rule
def window_name():
return any_name
@rule
def alias():
return any_name
@rule
def filename():
return any_name
@rule
def base_window_name():
return any_name
@rule
def simple_func():
return any_name
@rule
def aggregate_func():
return any_name
@rule
def table_function_name():
return any_name
@rule
def any_name():
return IDENTIFIER | keyword | STRING_LITERAL | seq(LPAREN, any_name, RPAREN)
return IDENTIFIER | keyword | STRING_LITERAL | seq(LPAREN, name, RPAREN)
SQL = Grammar(
@ -1488,7 +1373,6 @@ SQL = Grammar(
(Assoc.LEFT, [OR]),
(Assoc.LEFT, [AND]),
(Assoc.LEFT, [NOT]),
(Assoc.LEFT, []),
(Assoc.LEFT, [PLUS, MINUS]),
(Assoc.LEFT, [STAR, SLASH]),
# TODO: Unary minus
@ -1497,13 +1381,184 @@ SQL = Grammar(
name="SQL",
)
if __name__ == "__main__":
import cProfile
print("Starting...")
with cProfile.Profile() as pr:
try:
SQL.build_table()
finally:
pr.dump_stats("sql.pprof")
print("Wrote output to sql.pprof")
def emit_yacc(path: str, grammar: Grammar):
lines = []
token_names = [t.name for t in grammar.terminals()]
token_names.sort()
trivia = {t.name for t in grammar.trivia_terminals()}
buf = ""
for tn in token_names:
if tn in trivia:
continue
if len(buf) > 0:
buf += " "
buf += tn
if len(buf) >= 73:
lines.append(f"%token {buf}")
buf = ""
if len(buf) > 0:
lines.append(f"%token {buf}")
lines.append("")
prec = grammar.precedence()
if len(prec) > 0:
for assoc, rules in prec:
match assoc:
case Assoc.LEFT:
line = "%left "
case Assoc.RIGHT:
line = "%right "
case Assoc.NONE:
line = "%nonassoc"
case _:
typing.assert_never(assoc)
rns = " ".join([rule.name for rule in rules])
lines.append(f"{line} {rns}")
lines.append("")
lines.append(f"%start {grammar.start.name}")
lines.append("")
lines.append("%%")
for nt in grammar.non_terminals():
for rule in nt.body:
prod = " ".join([s.name for s in rule])
lines.append(f"{nt.name}: {prod};")
lines.append("")
lines.append("%%")
with open(path, "w", encoding="utf8") as file:
file.writelines([f"{l}\n" for l in lines])
def emit_lex(path: str, grammar: Grammar):
    """Write the terminals of `grammar` to `path` as a lex-style rule list.

    The file starts with a ``%%`` line, followed by one line per terminal of
    the form ``<pattern> "<NAME>"``. Trivia terminals are written as
    ``<pattern> ;`` instead, i.e. with no token name attached.

    String patterns are written verbatim; regular-expression patterns are
    rendered by walking the ``parser.Re*`` node tree.

    Raises:
        NotImplementedError: if a pattern contains a regex node type this
            emitter does not know how to render.
    """

    def to_js_string(s: str) -> str:
        # Borrow JSON's escaping for control characters and backslashes,
        # dropping the surrounding quotes.
        # NOTE(review): json.dumps also turns non-ASCII characters into
        # \uXXXX escapes -- confirm the consumer of this file accepts them.
        result = json.dumps(s)[1:-1]

        # JSON escapes double-quotes but we don't need to in our context.
        result = result.replace('\\"', '"')

        return result

    def flatten(node, kind) -> list:
        # Collapse a left/right tree of `kind` nodes into its leaves, in
        # left-to-right order, without recursing on the spine.
        leaves = []
        stack = [node]
        while stack:
            part = stack.pop()
            if isinstance(part, kind):
                # Push right first so that left is popped (and emitted) first.
                stack.append(part.right)
                stack.append(part.left)
            else:
                leaves.append(part)
        return leaves

    def join_group(node, kind, separator: str) -> str:
        # Render the flattened operands of a sequence or alternation and
        # parenthesize the result so it binds as a unit.
        parts = flatten(node, kind)
        s = separator.join(to_lex_regex(p) for p in parts)
        if len(parts) > 1:
            s = f"({s})"
        return s

    def to_lex_regex(re: parser.Re) -> str:
        # NOTE: Parentheses are added liberally here, including where they
        #       are not strictly required; in this output they are used for
        #       grouping only.
        node = re  # avoid reassigning a name that shadows the stdlib module

        if isinstance(node, parser.ReSeq):
            return join_group(node, parser.ReSeq, "")

        elif isinstance(node, parser.ReAlt):
            return join_group(node, parser.ReAlt, "|")

        elif isinstance(node, parser.ReQuestion):
            return f"({to_lex_regex(node.child)})?"

        elif isinstance(node, parser.RePlus):
            return f"({to_lex_regex(node.child)})+"

        elif isinstance(node, parser.ReStar):
            return f"({to_lex_regex(node.child)})*"

        elif isinstance(node, parser.ReSet):
            # A single range covering every code point is written as ".".
            # NOTE(review): in classic lex "." does not match a newline --
            # confirm that is acceptable for the consumer of this file.
            if (
                len(node.values) == 1
                and node.values[0].lower == 0
                and node.values[0].upper == parser.UNICODE_MAX_CP
            ):
                return "."

            # For an inverted set, list the excluded ranges and prefix "^".
            inverted = node.inversion
            if inverted:
                node = node.invert()

            parts = []
            for value in node.values:
                if len(value) == 1:
                    parts.append(to_js_string(chr(value.lower)))
                else:
                    # `upper - 1` is written as the last character of the
                    # range, i.e. `upper` is treated as exclusive.
                    parts.append(
                        "{}-{}".format(
                            to_js_string(chr(value.lower)),
                            to_js_string(chr(value.upper - 1)),
                        )
                    )

            s = "".join(parts)
            if inverted:
                s = "^" + s

            if len(s) > 1:
                # The only time this isn't a "set" is if this is a set of one
                # range that is one character long, in which case it's better
                # represented as a literal.
                s = f"[{s}]"
            # NOTE(review): a one-character literal is emitted unquoted, so a
            # regex metacharacter (e.g. "*") would be read as an operator --
            # confirm whether it should be quoted for the consumer.
            return s

        raise NotImplementedError(f"Regex node {node} not supported for lex")

    lines = ["%%"]

    trivia = {t.name for t in grammar.trivia_terminals()}
    for terminal in grammar.terminals():
        if isinstance(terminal.pattern, str):
            pattern = terminal.pattern
        else:
            pattern = to_lex_regex(terminal.pattern)

        # Trivia gets no token name, only ";".
        name = ";" if terminal.name in trivia else f'"{terminal.name}"'
        lines.append(f"{pattern} {name}")

    with open(path, "w", encoding="utf8") as file:
        file.writelines(f"{line}\n" for line in lines)
if __name__ == "__main__":
    # Script entry point. The profiling run of SQL.build_table() below is
    # kept, commented out, so it can be re-enabled by hand; as written the
    # script only emits the grammar in yacc and lex form.
    # import cProfile
    # print("Starting...")
    # with cProfile.Profile() as pr:
    # try:
    # SQL.build_table()
    # finally:
    # pr.dump_stats("sql.pprof")
    # print("Wrote output to sql.pprof")
    # Write the SQL grammar to sql.y (productions) and sql.l (terminals) in
    # the current working directory.
    emit_yacc("sql.y", SQL)
    emit_lex("sql.l", SQL)