From fb6da7daa348ad2f17c59c23480da8ca5302c12e Mon Sep 17 00:00:00 2001 From: Vladimir Kalnitsky Date: Wed, 20 Nov 2024 00:53:26 +0400 Subject: [PATCH] Implement optional assignments and proper parameter extraction for if exprs --- grammar/MetaPrompt.g4 | 2 +- python/src/eval.py | 29 ++--- python/src/loader.py | 150 ++++++++++++++++++----- python/src/parse_metaprompt.py | 2 + python/src/parser/MetaPrompt.interp | 2 +- python/src/parser/MetaPrompt.tokens | 1 - python/src/parser/MetaPromptLexer.interp | 4 +- python/src/parser/MetaPromptLexer.py | 76 ++++++------ python/src/parser/MetaPromptLexer.tokens | 1 - python/src/parser/MetaPromptParser.py | 5 +- python/tests/test_loader.py | 121 +++++++++++++++++- python/tests/test_parser.py | 35 ++++-- 12 files changed, 326 insertions(+), 102 deletions(-) diff --git a/grammar/MetaPrompt.g4 b/grammar/MetaPrompt.g4 index 75d81c9..8649edc 100644 --- a/grammar/MetaPrompt.g4 +++ b/grammar/MetaPrompt.g4 @@ -35,7 +35,7 @@ text: CHAR+ ; LB : '['; RB : ']'; -EQ_KW : '=' ; +EQ_KW : '=' | '?=' ; META_PROMPT : [a-zA-Z_]?[a-zA-Z0-9_]* '$' ; COMMENT_KW : '#' ; CHAR : ( ESCAPED | .); diff --git a/python/src/eval.py b/python/src/eval.py index 045fbe5..d9ec1b0 100644 --- a/python/src/eval.py +++ b/python/src/eval.py @@ -2,7 +2,7 @@ from env import Env from runtime import BaseRuntime from typing import AsyncGenerator, List -from loader import extract_variables +from loader import extract_parameter_set from eval_utils.assignment import Assignment from eval_utils.chat_history import serialize_chat_history @@ -86,7 +86,7 @@ async def _eval_ast(ast): parameters = ast["parameters"] module_name = ast["module_name"] loaded_ast = runtime.load_module(module_name) - required_variables = extract_variables(loaded_ast) + required_variables = extract_parameter_set(loaded_ast).required for required in required_variables: if required not in parameters: raise ImportError( @@ -117,18 +117,19 @@ async def _eval_ast(ast): elif ast["type"] == "assign": var_name = 
ast["name"] value = (await _collect_exprs(ast["exprs"])).strip() - if var_name == "STATUS": - runtime.set_status(value) - elif var_name == "ROLE": - if value not in ALLOWED_ROLES: - raise ValueError( - "ROLE variable must be one of " - + "".join([f"'{role}', " for role in ALLOWED_ROLES]) - + ", you specified: " - + value - ) - yield Assignment("ROLE", value) - env.set(var_name, value) + if ast["required"] or env.get(var_name) is None: + if var_name == "STATUS": + runtime.set_status(value) + elif var_name == "ROLE": + if value not in ALLOWED_ROLES: + raise ValueError( + "ROLE variable must be one of " + + "".join([f"'{role}', " for role in ALLOWED_ROLES]) + + ", you specified: " + + value + ) + yield Assignment("ROLE", value) + env.set(var_name, value) elif ast["type"] == "meta": # Load chat history chat_id = ast["chat"] if "chat" in ast else None diff --git a/python/src/loader.py b/python/src/loader.py index e2d9863..570ab41 100644 --- a/python/src/loader.py +++ b/python/src/loader.py @@ -1,32 +1,122 @@ -def _discover_variables(ast): +from typing import Set + + +class ParameterSet: + def __init__( + self, + required: Set[str] = None, + optional: Set[str] = None, + assigned: Set[str] = None, + ): + if required is None: + required = set() + if optional is None: + optional = set() + if assigned is None: + assigned = set() + self.required = required + self.optional = optional + self.assigned = assigned + + def __eq__(self, other): + return ( + self.required == other.required + and self.optional == other.optional + and self.assigned == other.assigned + ) + + def then(self, other): + """Implements sequential composition of ParameterSets: + params([expr1, expr2]) = params(expr1).then(params(expr2)) + """ + return ParameterSet( + # required + self.required.union(other.required.difference(self.assigned)), + # optional + self.optional.union(other.optional) + .difference(self.required) + .difference(other.required), + # assigned + self.assigned.union(other.assigned), + ) + + 
def alternative(self, other): + """Implements parallel composition of ParameterSets, that is used + for :if: either of the alternatives can be executed, so we have to be + conservative. + """ + return ParameterSet( + # required: all of the required variables are required, + # we don't know which branch will be chosen + self.required.union(other.required), + # optional: if something is optional in just one of the branches, + # it is optional in the whole expression, but if it required in + # either, it is required in the whole + self.optional.union(other.optional) + .difference(self.required) + .difference(other.required), + # assigned: something must be assigned in both branches for us to + # be sure it is assigned + self.assigned.intersection(other.assigned), + ) + + def assign_var(self, name): + """handles [:name=value]""" + self.assigned.add(name) + + def assign_var_optional(self, name): + """handles [:name?=value]""" + if name not in self.assigned: + self.optional.add(name) + self.assigned.add(name) + + def use_var(self, name): + if name not in self.assigned: + self.required.add(name) + + +def extract_parameter_set(ast): + # TODO: special handling of ROLE, MODEL variables + res = ParameterSet() + if isinstance(ast, list): for node in ast: - yield from _discover_variables(node) - elif isinstance(ast, dict): - if "type" in ast: - # TODO: evaluate both :if branches in parallel, to cover this case: - # [:if foo :then [:bar=baz] :else [:bar]] - # -- [:bar] should be unbound here, because it is unbound in the - # first branch - if ast["type"] == "comment": - return - elif ast["type"] == "var": - if ast["name"] != "MODEL": - yield {"type": "var", "name": ast["name"]} - elif ast["type"] == "assign": - yield {"type": "assign", "name": ast["name"]} - for key in ast: - yield from _discover_variables(ast[key]) - - -def extract_variables(ast): - variables = set() - assigned = set() - for item in _discover_variables(ast): - match item: - case {"name": name, "type": "var"}: - 
if name not in assigned: - variables.add(name) - case {"name": name, "type": "assign"}: - assigned.add(name) - return variables + res = res.then(extract_parameter_set(node)) + elif isinstance(ast, dict) and "type" in ast: + if ast["type"] == "text": + pass + elif ast["type"] == "metaprompt": + for expr in ast["exprs"]: + res = res.then(extract_parameter_set(expr)) + elif ast["type"] == "var": + res.use_var(ast["name"]) + elif ast["type"] == "use": + for _, expr in ast["parameters"].items(): + res = res.then(extract_parameter_set(expr)) + elif ast["type"] == "assign": + if ast["required"]: + res.assign_var(ast["name"]) + else: + res.assign_var_optional(ast["name"]) + elif ast["type"] == "meta": + for expr in ast["exprs"]: + res = res.then(extract_parameter_set(expr)) + elif ast["type"] == "exprs": + for expr in ast["exprs"]: + res = res.then(extract_parameter_set(expr)) + elif ast["type"] == "if_then_else": + res = res.then(extract_parameter_set(ast["condition"])).then( + extract_parameter_set(ast["then"]).alternative( + extract_parameter_set(ast["else"]) + ) + ) + else: + raise ValueError( + "extract_parameter_set: unknown AST expression: " + str(ast) + ) + else: + raise ValueError( + "extract_parameter_set: unknown AST expression: " + str(ast) + ) + + return res diff --git a/python/src/parse_metaprompt.py b/python/src/parse_metaprompt.py index c9f9143..a0d0b95 100644 --- a/python/src/parse_metaprompt.py +++ b/python/src/parse_metaprompt.py @@ -161,8 +161,10 @@ def visitMeta_body(self, ctx: MetaPromptParser.Meta_bodyContext): for expr in ctx.exprs(): expr_items = self.visit(expr) exprs.extend(expr_items) + required = ctx.EQ_KW().getText() == "=" # or "?=" return { "type": "assign", + "required": required, "name": var_name, "exprs": _join_text_pieces(exprs), } diff --git a/python/src/parser/MetaPrompt.interp b/python/src/parser/MetaPrompt.interp index dca33fd..a26f9d3 100644 --- a/python/src/parser/MetaPrompt.interp +++ b/python/src/parser/MetaPrompt.interp @@ -2,7 +2,7 @@ token 
literal names: null '[' ']' -'=' +null null '#' null diff --git a/python/src/parser/MetaPrompt.tokens b/python/src/parser/MetaPrompt.tokens index d7c421c..c96cb56 100644 --- a/python/src/parser/MetaPrompt.tokens +++ b/python/src/parser/MetaPrompt.tokens @@ -11,7 +11,6 @@ ELSE_KW=10 VAR_NAME=11 '['=1 ']'=2 -'='=3 '#'=5 ':if'=8 ':then'=9 diff --git a/python/src/parser/MetaPromptLexer.interp b/python/src/parser/MetaPromptLexer.interp index 761d039..eba2a4e 100644 --- a/python/src/parser/MetaPromptLexer.interp +++ b/python/src/parser/MetaPromptLexer.interp @@ -2,7 +2,7 @@ token literal names: null '[' ']' -'=' +null null '#' null @@ -51,4 +51,4 @@ mode names: DEFAULT_MODE atn: -[4, 0, 11, 110, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 3, 3, 39, 8, 3, 1, 3, 5, 3, 42, 8, 3, 10, 3, 12, 3, 45, 9, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 53, 8, 5, 1, 6, 1, 6, 1, 6, 1, 7, 1, 7, 3, 7, 60, 8, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 4, 9, 70, 8, 9, 11, 9, 12, 9, 71, 1, 9, 4, 9, 75, 8, 9, 11, 9, 12, 9, 76, 1, 9, 5, 9, 80, 8, 9, 10, 9, 12, 9, 83, 9, 9, 1, 10, 1, 10, 1, 11, 1, 11, 1, 11, 1, 11, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 14, 1, 14, 1, 14, 5, 14, 106, 8, 14, 10, 14, 12, 14, 109, 9, 14, 0, 0, 15, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 0, 15, 0, 17, 0, 19, 7, 21, 0, 23, 8, 25, 9, 27, 10, 29, 11, 1, 0, 4, 3, 0, 65, 90, 95, 95, 97, 122, 4, 0, 48, 57, 65, 90, 95, 95, 97, 122, 4, 0, 45, 57, 65, 90, 95, 95, 97, 122, 2, 0, 10, 10, 32, 32, 113, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 23, 1, 0, 0, 0, 0, 25, 1, 0, 0, 0, 0, 27, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 1, 31, 1, 0, 0, 0, 3, 33, 1, 0, 0, 0, 5, 35, 1, 0, 0, 0, 7, 38, 1, 0, 0, 0, 9, 
48, 1, 0, 0, 0, 11, 52, 1, 0, 0, 0, 13, 54, 1, 0, 0, 0, 15, 59, 1, 0, 0, 0, 17, 61, 1, 0, 0, 0, 19, 63, 1, 0, 0, 0, 21, 84, 1, 0, 0, 0, 23, 86, 1, 0, 0, 0, 25, 90, 1, 0, 0, 0, 27, 96, 1, 0, 0, 0, 29, 102, 1, 0, 0, 0, 31, 32, 5, 91, 0, 0, 32, 2, 1, 0, 0, 0, 33, 34, 5, 93, 0, 0, 34, 4, 1, 0, 0, 0, 35, 36, 5, 61, 0, 0, 36, 6, 1, 0, 0, 0, 37, 39, 7, 0, 0, 0, 38, 37, 1, 0, 0, 0, 38, 39, 1, 0, 0, 0, 39, 43, 1, 0, 0, 0, 40, 42, 7, 1, 0, 0, 41, 40, 1, 0, 0, 0, 42, 45, 1, 0, 0, 0, 43, 41, 1, 0, 0, 0, 43, 44, 1, 0, 0, 0, 44, 46, 1, 0, 0, 0, 45, 43, 1, 0, 0, 0, 46, 47, 5, 36, 0, 0, 47, 8, 1, 0, 0, 0, 48, 49, 5, 35, 0, 0, 49, 10, 1, 0, 0, 0, 50, 53, 3, 13, 6, 0, 51, 53, 9, 0, 0, 0, 52, 50, 1, 0, 0, 0, 52, 51, 1, 0, 0, 0, 53, 12, 1, 0, 0, 0, 54, 55, 3, 17, 8, 0, 55, 56, 3, 15, 7, 0, 56, 14, 1, 0, 0, 0, 57, 60, 3, 1, 0, 0, 58, 60, 3, 17, 8, 0, 59, 57, 1, 0, 0, 0, 59, 58, 1, 0, 0, 0, 60, 16, 1, 0, 0, 0, 61, 62, 5, 92, 0, 0, 62, 18, 1, 0, 0, 0, 63, 64, 5, 58, 0, 0, 64, 65, 5, 117, 0, 0, 65, 66, 5, 115, 0, 0, 66, 67, 5, 101, 0, 0, 67, 69, 1, 0, 0, 0, 68, 70, 3, 21, 10, 0, 69, 68, 1, 0, 0, 0, 70, 71, 1, 0, 0, 0, 71, 69, 1, 0, 0, 0, 71, 72, 1, 0, 0, 0, 72, 74, 1, 0, 0, 0, 73, 75, 7, 2, 0, 0, 74, 73, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 74, 1, 0, 0, 0, 76, 77, 1, 0, 0, 0, 77, 81, 1, 0, 0, 0, 78, 80, 3, 21, 10, 0, 79, 78, 1, 0, 0, 0, 80, 83, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 20, 1, 0, 0, 0, 83, 81, 1, 0, 0, 0, 84, 85, 7, 3, 0, 0, 85, 22, 1, 0, 0, 0, 86, 87, 5, 58, 0, 0, 87, 88, 5, 105, 0, 0, 88, 89, 5, 102, 0, 0, 89, 24, 1, 0, 0, 0, 90, 91, 5, 58, 0, 0, 91, 92, 5, 116, 0, 0, 92, 93, 5, 104, 0, 0, 93, 94, 5, 101, 0, 0, 94, 95, 5, 110, 0, 0, 95, 26, 1, 0, 0, 0, 96, 97, 5, 58, 0, 0, 97, 98, 5, 101, 0, 0, 98, 99, 5, 108, 0, 0, 99, 100, 5, 115, 0, 0, 100, 101, 5, 101, 0, 0, 101, 28, 1, 0, 0, 0, 102, 103, 5, 58, 0, 0, 103, 107, 7, 0, 0, 0, 104, 106, 7, 1, 0, 0, 105, 104, 1, 0, 0, 0, 106, 109, 1, 0, 0, 0, 107, 105, 1, 0, 0, 0, 107, 108, 1, 0, 0, 0, 108, 30, 1, 0, 0, 0, 
109, 107, 1, 0, 0, 0, 9, 0, 38, 43, 52, 59, 71, 76, 81, 107, 0] \ No newline at end of file +[4, 0, 11, 113, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 3, 2, 39, 8, 2, 1, 3, 3, 3, 42, 8, 3, 1, 3, 5, 3, 45, 8, 3, 10, 3, 12, 3, 48, 9, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 56, 8, 5, 1, 6, 1, 6, 1, 6, 1, 7, 1, 7, 3, 7, 63, 8, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 4, 9, 73, 8, 9, 11, 9, 12, 9, 74, 1, 9, 4, 9, 78, 8, 9, 11, 9, 12, 9, 79, 1, 9, 5, 9, 83, 8, 9, 10, 9, 12, 9, 86, 9, 9, 1, 10, 1, 10, 1, 11, 1, 11, 1, 11, 1, 11, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 14, 1, 14, 1, 14, 5, 14, 109, 8, 14, 10, 14, 12, 14, 112, 9, 14, 0, 0, 15, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 0, 15, 0, 17, 0, 19, 7, 21, 0, 23, 8, 25, 9, 27, 10, 29, 11, 1, 0, 4, 3, 0, 65, 90, 95, 95, 97, 122, 4, 0, 48, 57, 65, 90, 95, 95, 97, 122, 4, 0, 45, 57, 65, 90, 95, 95, 97, 122, 2, 0, 10, 10, 32, 32, 117, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 23, 1, 0, 0, 0, 0, 25, 1, 0, 0, 0, 0, 27, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 1, 31, 1, 0, 0, 0, 3, 33, 1, 0, 0, 0, 5, 38, 1, 0, 0, 0, 7, 41, 1, 0, 0, 0, 9, 51, 1, 0, 0, 0, 11, 55, 1, 0, 0, 0, 13, 57, 1, 0, 0, 0, 15, 62, 1, 0, 0, 0, 17, 64, 1, 0, 0, 0, 19, 66, 1, 0, 0, 0, 21, 87, 1, 0, 0, 0, 23, 89, 1, 0, 0, 0, 25, 93, 1, 0, 0, 0, 27, 99, 1, 0, 0, 0, 29, 105, 1, 0, 0, 0, 31, 32, 5, 91, 0, 0, 32, 2, 1, 0, 0, 0, 33, 34, 5, 93, 0, 0, 34, 4, 1, 0, 0, 0, 35, 39, 5, 61, 0, 0, 36, 37, 5, 63, 0, 0, 37, 39, 5, 61, 0, 0, 38, 35, 1, 0, 0, 0, 38, 36, 1, 0, 0, 0, 39, 6, 1, 0, 0, 0, 40, 42, 7, 0, 0, 0, 41, 40, 1, 0, 0, 0, 41, 42, 1, 0, 0, 0, 42, 46, 1, 0, 0, 0, 43, 45, 7, 1, 0, 0, 44, 43, 1, 0, 0, 0, 45, 48, 1, 0, 0, 0, 46, 44, 
1, 0, 0, 0, 46, 47, 1, 0, 0, 0, 47, 49, 1, 0, 0, 0, 48, 46, 1, 0, 0, 0, 49, 50, 5, 36, 0, 0, 50, 8, 1, 0, 0, 0, 51, 52, 5, 35, 0, 0, 52, 10, 1, 0, 0, 0, 53, 56, 3, 13, 6, 0, 54, 56, 9, 0, 0, 0, 55, 53, 1, 0, 0, 0, 55, 54, 1, 0, 0, 0, 56, 12, 1, 0, 0, 0, 57, 58, 3, 17, 8, 0, 58, 59, 3, 15, 7, 0, 59, 14, 1, 0, 0, 0, 60, 63, 3, 1, 0, 0, 61, 63, 3, 17, 8, 0, 62, 60, 1, 0, 0, 0, 62, 61, 1, 0, 0, 0, 63, 16, 1, 0, 0, 0, 64, 65, 5, 92, 0, 0, 65, 18, 1, 0, 0, 0, 66, 67, 5, 58, 0, 0, 67, 68, 5, 117, 0, 0, 68, 69, 5, 115, 0, 0, 69, 70, 5, 101, 0, 0, 70, 72, 1, 0, 0, 0, 71, 73, 3, 21, 10, 0, 72, 71, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 72, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 77, 1, 0, 0, 0, 76, 78, 7, 2, 0, 0, 77, 76, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 84, 1, 0, 0, 0, 81, 83, 3, 21, 10, 0, 82, 81, 1, 0, 0, 0, 83, 86, 1, 0, 0, 0, 84, 82, 1, 0, 0, 0, 84, 85, 1, 0, 0, 0, 85, 20, 1, 0, 0, 0, 86, 84, 1, 0, 0, 0, 87, 88, 7, 3, 0, 0, 88, 22, 1, 0, 0, 0, 89, 90, 5, 58, 0, 0, 90, 91, 5, 105, 0, 0, 91, 92, 5, 102, 0, 0, 92, 24, 1, 0, 0, 0, 93, 94, 5, 58, 0, 0, 94, 95, 5, 116, 0, 0, 95, 96, 5, 104, 0, 0, 96, 97, 5, 101, 0, 0, 97, 98, 5, 110, 0, 0, 98, 26, 1, 0, 0, 0, 99, 100, 5, 58, 0, 0, 100, 101, 5, 101, 0, 0, 101, 102, 5, 108, 0, 0, 102, 103, 5, 115, 0, 0, 103, 104, 5, 101, 0, 0, 104, 28, 1, 0, 0, 0, 105, 106, 5, 58, 0, 0, 106, 110, 7, 0, 0, 0, 107, 109, 7, 1, 0, 0, 108, 107, 1, 0, 0, 0, 109, 112, 1, 0, 0, 0, 110, 108, 1, 0, 0, 0, 110, 111, 1, 0, 0, 0, 111, 30, 1, 0, 0, 0, 112, 110, 1, 0, 0, 0, 10, 0, 38, 41, 46, 55, 62, 74, 79, 84, 110, 0] \ No newline at end of file diff --git a/python/src/parser/MetaPromptLexer.py b/python/src/parser/MetaPromptLexer.py index 41b6d59..a616b6b 100644 --- a/python/src/parser/MetaPromptLexer.py +++ b/python/src/parser/MetaPromptLexer.py @@ -10,43 +10,45 @@ def serializedATN(): return [ - 4,0,11,110,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 
4,0,11,113,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, - 13,7,13,2,14,7,14,1,0,1,0,1,1,1,1,1,2,1,2,1,3,3,3,39,8,3,1,3,5,3, - 42,8,3,10,3,12,3,45,9,3,1,3,1,3,1,4,1,4,1,5,1,5,3,5,53,8,5,1,6,1, - 6,1,6,1,7,1,7,3,7,60,8,7,1,8,1,8,1,9,1,9,1,9,1,9,1,9,1,9,4,9,70, - 8,9,11,9,12,9,71,1,9,4,9,75,8,9,11,9,12,9,76,1,9,5,9,80,8,9,10,9, - 12,9,83,9,9,1,10,1,10,1,11,1,11,1,11,1,11,1,12,1,12,1,12,1,12,1, - 12,1,12,1,13,1,13,1,13,1,13,1,13,1,13,1,14,1,14,1,14,5,14,106,8, - 14,10,14,12,14,109,9,14,0,0,15,1,1,3,2,5,3,7,4,9,5,11,6,13,0,15, - 0,17,0,19,7,21,0,23,8,25,9,27,10,29,11,1,0,4,3,0,65,90,95,95,97, - 122,4,0,48,57,65,90,95,95,97,122,4,0,45,57,65,90,95,95,97,122,2, - 0,10,10,32,32,113,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0,7,1,0,0, - 0,0,9,1,0,0,0,0,11,1,0,0,0,0,19,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0, - 0,0,27,1,0,0,0,0,29,1,0,0,0,1,31,1,0,0,0,3,33,1,0,0,0,5,35,1,0,0, - 0,7,38,1,0,0,0,9,48,1,0,0,0,11,52,1,0,0,0,13,54,1,0,0,0,15,59,1, - 0,0,0,17,61,1,0,0,0,19,63,1,0,0,0,21,84,1,0,0,0,23,86,1,0,0,0,25, - 90,1,0,0,0,27,96,1,0,0,0,29,102,1,0,0,0,31,32,5,91,0,0,32,2,1,0, - 0,0,33,34,5,93,0,0,34,4,1,0,0,0,35,36,5,61,0,0,36,6,1,0,0,0,37,39, - 7,0,0,0,38,37,1,0,0,0,38,39,1,0,0,0,39,43,1,0,0,0,40,42,7,1,0,0, - 41,40,1,0,0,0,42,45,1,0,0,0,43,41,1,0,0,0,43,44,1,0,0,0,44,46,1, - 0,0,0,45,43,1,0,0,0,46,47,5,36,0,0,47,8,1,0,0,0,48,49,5,35,0,0,49, - 10,1,0,0,0,50,53,3,13,6,0,51,53,9,0,0,0,52,50,1,0,0,0,52,51,1,0, - 0,0,53,12,1,0,0,0,54,55,3,17,8,0,55,56,3,15,7,0,56,14,1,0,0,0,57, - 60,3,1,0,0,58,60,3,17,8,0,59,57,1,0,0,0,59,58,1,0,0,0,60,16,1,0, - 0,0,61,62,5,92,0,0,62,18,1,0,0,0,63,64,5,58,0,0,64,65,5,117,0,0, - 65,66,5,115,0,0,66,67,5,101,0,0,67,69,1,0,0,0,68,70,3,21,10,0,69, - 68,1,0,0,0,70,71,1,0,0,0,71,69,1,0,0,0,71,72,1,0,0,0,72,74,1,0,0, - 0,73,75,7,2,0,0,74,73,1,0,0,0,75,76,1,0,0,0,76,74,1,0,0,0,76,77, - 1,0,0,0,77,81,1,0,0,0,78,80,3,21,10,0,79,78,1,0,0,0,80,83,1,0,0, - 
0,81,79,1,0,0,0,81,82,1,0,0,0,82,20,1,0,0,0,83,81,1,0,0,0,84,85, - 7,3,0,0,85,22,1,0,0,0,86,87,5,58,0,0,87,88,5,105,0,0,88,89,5,102, - 0,0,89,24,1,0,0,0,90,91,5,58,0,0,91,92,5,116,0,0,92,93,5,104,0,0, - 93,94,5,101,0,0,94,95,5,110,0,0,95,26,1,0,0,0,96,97,5,58,0,0,97, - 98,5,101,0,0,98,99,5,108,0,0,99,100,5,115,0,0,100,101,5,101,0,0, - 101,28,1,0,0,0,102,103,5,58,0,0,103,107,7,0,0,0,104,106,7,1,0,0, - 105,104,1,0,0,0,106,109,1,0,0,0,107,105,1,0,0,0,107,108,1,0,0,0, - 108,30,1,0,0,0,109,107,1,0,0,0,9,0,38,43,52,59,71,76,81,107,0 + 13,7,13,2,14,7,14,1,0,1,0,1,1,1,1,1,2,1,2,1,2,3,2,39,8,2,1,3,3,3, + 42,8,3,1,3,5,3,45,8,3,10,3,12,3,48,9,3,1,3,1,3,1,4,1,4,1,5,1,5,3, + 5,56,8,5,1,6,1,6,1,6,1,7,1,7,3,7,63,8,7,1,8,1,8,1,9,1,9,1,9,1,9, + 1,9,1,9,4,9,73,8,9,11,9,12,9,74,1,9,4,9,78,8,9,11,9,12,9,79,1,9, + 5,9,83,8,9,10,9,12,9,86,9,9,1,10,1,10,1,11,1,11,1,11,1,11,1,12,1, + 12,1,12,1,12,1,12,1,12,1,13,1,13,1,13,1,13,1,13,1,13,1,14,1,14,1, + 14,5,14,109,8,14,10,14,12,14,112,9,14,0,0,15,1,1,3,2,5,3,7,4,9,5, + 11,6,13,0,15,0,17,0,19,7,21,0,23,8,25,9,27,10,29,11,1,0,4,3,0,65, + 90,95,95,97,122,4,0,48,57,65,90,95,95,97,122,4,0,45,57,65,90,95, + 95,97,122,2,0,10,10,32,32,117,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0, + 0,0,7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,19,1,0,0,0,0,23,1,0,0,0, + 0,25,1,0,0,0,0,27,1,0,0,0,0,29,1,0,0,0,1,31,1,0,0,0,3,33,1,0,0,0, + 5,38,1,0,0,0,7,41,1,0,0,0,9,51,1,0,0,0,11,55,1,0,0,0,13,57,1,0,0, + 0,15,62,1,0,0,0,17,64,1,0,0,0,19,66,1,0,0,0,21,87,1,0,0,0,23,89, + 1,0,0,0,25,93,1,0,0,0,27,99,1,0,0,0,29,105,1,0,0,0,31,32,5,91,0, + 0,32,2,1,0,0,0,33,34,5,93,0,0,34,4,1,0,0,0,35,39,5,61,0,0,36,37, + 5,63,0,0,37,39,5,61,0,0,38,35,1,0,0,0,38,36,1,0,0,0,39,6,1,0,0,0, + 40,42,7,0,0,0,41,40,1,0,0,0,41,42,1,0,0,0,42,46,1,0,0,0,43,45,7, + 1,0,0,44,43,1,0,0,0,45,48,1,0,0,0,46,44,1,0,0,0,46,47,1,0,0,0,47, + 49,1,0,0,0,48,46,1,0,0,0,49,50,5,36,0,0,50,8,1,0,0,0,51,52,5,35, + 0,0,52,10,1,0,0,0,53,56,3,13,6,0,54,56,9,0,0,0,55,53,1,0,0,0,55, + 
54,1,0,0,0,56,12,1,0,0,0,57,58,3,17,8,0,58,59,3,15,7,0,59,14,1,0, + 0,0,60,63,3,1,0,0,61,63,3,17,8,0,62,60,1,0,0,0,62,61,1,0,0,0,63, + 16,1,0,0,0,64,65,5,92,0,0,65,18,1,0,0,0,66,67,5,58,0,0,67,68,5,117, + 0,0,68,69,5,115,0,0,69,70,5,101,0,0,70,72,1,0,0,0,71,73,3,21,10, + 0,72,71,1,0,0,0,73,74,1,0,0,0,74,72,1,0,0,0,74,75,1,0,0,0,75,77, + 1,0,0,0,76,78,7,2,0,0,77,76,1,0,0,0,78,79,1,0,0,0,79,77,1,0,0,0, + 79,80,1,0,0,0,80,84,1,0,0,0,81,83,3,21,10,0,82,81,1,0,0,0,83,86, + 1,0,0,0,84,82,1,0,0,0,84,85,1,0,0,0,85,20,1,0,0,0,86,84,1,0,0,0, + 87,88,7,3,0,0,88,22,1,0,0,0,89,90,5,58,0,0,90,91,5,105,0,0,91,92, + 5,102,0,0,92,24,1,0,0,0,93,94,5,58,0,0,94,95,5,116,0,0,95,96,5,104, + 0,0,96,97,5,101,0,0,97,98,5,110,0,0,98,26,1,0,0,0,99,100,5,58,0, + 0,100,101,5,101,0,0,101,102,5,108,0,0,102,103,5,115,0,0,103,104, + 5,101,0,0,104,28,1,0,0,0,105,106,5,58,0,0,106,110,7,0,0,0,107,109, + 7,1,0,0,108,107,1,0,0,0,109,112,1,0,0,0,110,108,1,0,0,0,110,111, + 1,0,0,0,111,30,1,0,0,0,112,110,1,0,0,0,10,0,38,41,46,55,62,74,79, + 84,110,0 ] class MetaPromptLexer(Lexer): @@ -72,7 +74,7 @@ class MetaPromptLexer(Lexer): modeNames = [ "DEFAULT_MODE" ] literalNames = [ "", - "'['", "']'", "'='", "'#'", "':if'", "':then'", "':else'" ] + "'['", "']'", "'#'", "':if'", "':then'", "':else'" ] symbolicNames = [ "", "LB", "RB", "EQ_KW", "META_PROMPT", "COMMENT_KW", "CHAR", "USE", diff --git a/python/src/parser/MetaPromptLexer.tokens b/python/src/parser/MetaPromptLexer.tokens index d7c421c..c96cb56 100644 --- a/python/src/parser/MetaPromptLexer.tokens +++ b/python/src/parser/MetaPromptLexer.tokens @@ -11,7 +11,6 @@ ELSE_KW=10 VAR_NAME=11 '['=1 ']'=2 -'='=3 '#'=5 ':if'=8 ':then'=9 diff --git a/python/src/parser/MetaPromptParser.py b/python/src/parser/MetaPromptParser.py index 52f7703..3262638 100644 --- a/python/src/parser/MetaPromptParser.py +++ b/python/src/parser/MetaPromptParser.py @@ -48,8 +48,9 @@ class MetaPromptParser ( Parser ): sharedContextCache = PredictionContextCache() - literalNames = [ "", 
"'['", "']'", "'='", "", "'#'", - "", "", "':if'", "':then'", "':else'" ] + literalNames = [ "", "'['", "']'", "", "", + "'#'", "", "", "':if'", "':then'", + "':else'" ] symbolicNames = [ "", "LB", "RB", "EQ_KW", "META_PROMPT", "COMMENT_KW", "CHAR", "USE", "IF_KW", "THEN_KW", "ELSE_KW", "VAR_NAME" ] diff --git a/python/tests/test_loader.py b/python/tests/test_loader.py index 4689f55..5b25905 100644 --- a/python/tests/test_loader.py +++ b/python/tests/test_loader.py @@ -1,7 +1,11 @@ -from loader import extract_variables +from loader import extract_parameter_set, ParameterSet from parse_metaprompt import parse_metaprompt, extract_tokens +def extract_variables(ast): + return extract_parameter_set(ast).required + + def test_extractor_1(): prompt = """ [:foo] @@ -35,3 +39,118 @@ def test_extractor_assign_3(): [:foo][:foo=baz] - first used, then assigned """ assert extract_variables(parse_metaprompt(prompt)) == set(["foo"]) + + +def test_extractor_assign_4(): + prompt = """ + [:foo][:foo=baz] - first used, then assigned + """ + assert extract_variables(parse_metaprompt(prompt)) == set(["foo"]) + + +def test_extractor_assign_5(): + prompt = """ + [:foo][:foo=baz] - first used, then assigned + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set(["foo"]) + assert res.optional == set() + assert res.assigned == set(["foo"]) + + +def test_extractor_assign_6(): + prompt = """ + [:foo?=default] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set() + assert res.optional == set(["foo"]) + assert res.assigned == set(["foo"]) + + +def test_extractor_if(): + prompt = """ + [:if [:foo=bar] :then [:foo] :else [:foo]] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set() + assert res.optional == set() + assert res.assigned == set(["foo"]) + + +def test_extractor_if_single_branch_assign(): + prompt = """ + [:if ... 
:then [:foo=bar] :else [:foo]] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set(["foo"]) + assert res.optional == set() + assert res.assigned == set() + + +def test_extractor_if_both_branches_assign(): + prompt = """ + [:if ... :then [:foo=bar] :else [:foo=baz]] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set() + assert res.optional == set() + assert res.assigned == set(["foo"]) + + +def test_extractor_if_both_branches_require(): + prompt = """ + [:if ... :then [:foo] :else [:foo]] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set(["foo"]) + assert res.optional == set() + assert res.assigned == set() + + +def test_extractor_if_single_branch_requires(): + prompt = """ + [:if ... :then [:foo] :else ...] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set(["foo"]) + assert res.optional == set() + assert res.assigned == set() + + prompt = """ + [:if ... :then ... :else [:foo]] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set(["foo"]) + assert res.optional == set() + assert res.assigned == set() + + +def test_extractor_if_single_branch_optionally_assigns(): + prompt = """ + [:if ... :then [:foo?=bar] :else ...] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set() + assert res.optional == set(["foo"]) + assert res.assigned == set() + + +def test_extractor_if_both_branches_assign_var(): + prompt = """ + [:if ... :then [:foo=bar] :else [:foo=bar]][:foo] + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set() + assert res.optional == set() + assert res.assigned == set(["foo"]) + + +def test_extractor_if_single_branch_assigns_var(): + prompt = """ + [:if ... 
:then [:foo=bar] :else ...][:foo] <- required + """ + res = extract_parameter_set(parse_metaprompt(prompt)) + assert res.required == set(["foo"]) + assert res.optional == set() + assert res.assigned == set() diff --git a/python/tests/test_parser.py b/python/tests/test_parser.py index 7f9f032..401c0f3 100644 --- a/python/tests/test_parser.py +++ b/python/tests/test_parser.py @@ -37,6 +37,15 @@ def use(module_name, parameters): return {"type": "use", "module_name": module_name, "parameters": parameters} +def assign(name, exprs, required=True): + return { + "type": "assign", + "required": required, + "name": name, + "exprs": exprs, + } + + def test_empty(): result = parse_metaprompt("") assert result["exprs"] == [] @@ -219,31 +228,33 @@ def test_meta_dollar2(): def test_assign(): result = parse_metaprompt("[:foo=bar]") + assert result["exprs"] == [assign("foo", [{"type": "text", "text": "bar"}])] + + +def test_assign_optional(): + result = parse_metaprompt("[:foo?=bar]") assert result["exprs"] == [ - { - "type": "assign", - "name": "foo", - "exprs": [{"type": "text", "text": "bar"}], - } + assign("foo", [{"type": "text", "text": "bar"}], required=False) ] def test_assign_trailing_bracket(): result = parse_metaprompt("[:foo=bar]]") assert result["exprs"] == [ - { - "type": "assign", - "name": "foo", - "exprs": [{"type": "text", "text": "bar"}], - }, + assign("foo", [{"type": "text", "text": "bar"}]), t("]"), ] -def test_assign_trailing_bracket(): +def test_assign_normal(): result = parse_metaprompt("[:foo=[$ hi ]]") assert result["exprs"] == [ - {"type": "assign", "name": "foo", "exprs": [meta([t(" hi ")])]} + { + "type": "assign", + "required": True, + "name": "foo", + "exprs": [meta([t(" hi ")])], + } ]