diff --git a/parser/parser.py b/parser/parser.py
index a6ff2d0..862d53a 100644
--- a/parser/parser.py
+++ b/parser/parser.py
@@ -1807,14 +1807,23 @@ def seq(*args: Rule) -> Rule:
 
 
 def opt(*args: Rule) -> Rule:
+    """Mark a sequence of rules as optional: match `seq(*args)` or `Nothing`."""
     return AlternativeRule(seq(*args), Nothing)
 
 
 def mark(rule: Rule, **kwargs) -> Rule:
+    """Wrap `rule` in a `MetadataRule` that carries `kwargs` as its metadata."""
     return MetadataRule(rule, kwargs)
 
 
-def one_or_more(r: Rule) -> Rule:
+def one_or_more(*args: Rule) -> Rule:
+    """Generate a rule that matches one or more repetitions of the specified
+    rules, taken together as a sequence (`seq(*args)`).
+
+    The resulting list is transparent, i.e., in the parse tree all of the members
+    of the list will be in-line with the parent. If you want to name the list,
+    create a named nonterminal to contain it.
+    """
     global _CURRENT_DEFINITION
     global _CURRENT_GEN_INDEX
 
@@ -1822,7 +1831,7 @@ def one_or_more(r: Rule) -> Rule:
     def impl() -> Rule:
         nonlocal tail
         assert(tail is not None)
-        return opt(tail) + r
+        return opt(tail) + seq(*args)
 
     tail = NonTerminal(
         fn=impl,
@@ -1833,8 +1842,15 @@ def one_or_more(r: Rule) -> Rule:
     return tail
 
 
-def zero_or_more(r:Rule) -> Rule:
-    return opt(one_or_more(r))
+def zero_or_more(*args: Rule) -> Rule:
+    """Generate a rule that matches zero or more repetitions of the specified
+    rules, taken together as a sequence (`opt(one_or_more(*args))`).
+
+    The resulting list is transparent, i.e., in the parse tree all of the members
+    of the list will be in-line with the parent. If you want to name the list,
+    create a named nonterminal to contain it.
+    """
+    return opt(one_or_more(*args))
 
 @typing.overload
 def rule(f: typing.Callable, /) -> NonTerminal: ...
@@ -2155,11 +2171,25 @@ class Re:
     def question(self) -> "Re":
         return ReQuestion(self)
 
-    def __or__(self, value: "Re", /) -> "Re":
-        return ReAlt(self, value)
+    def __or__(self, value: "Re | Terminal", /) -> "Re":
+        if isinstance(value, Re):  # already an Re: use it unchanged
+            other = value
+        elif isinstance(value.pattern, Re):  # non-Re operands must expose .pattern (annotated as Terminal)
+            other = value.pattern
+        else:  # any other .pattern value is converted with Re.literal
+            other = Re.literal(value.pattern)
 
-    def __add__(self, value: "Re") -> "Re":
-        return ReSeq(self, value)
+        return ReAlt(self, other)  # combine self with the coerced operand
+
+    def __add__(self, value: "Re | Terminal") -> "Re":
+        if isinstance(value, Re):  # already an Re: use it unchanged
+            other = value
+        elif isinstance(value.pattern, Re):  # non-Re operands must expose .pattern (annotated as Terminal)
+            other = value.pattern
+        else:  # any other .pattern value is converted with Re.literal
+            other = Re.literal(value.pattern)
+
+        return ReSeq(self, other)  # combine self with the coerced operand
 
 
 UNICODE_MAX_CP = 1114112