Skip to content

Commit

Permalink
Implement optional assignments and proper parameter extraction for if…
Browse files Browse the repository at this point in the history
… exprs
  • Loading branch information
klntsky committed Nov 19, 2024
1 parent abfd268 commit fb6da7d
Show file tree
Hide file tree
Showing 12 changed files with 326 additions and 102 deletions.
2 changes: 1 addition & 1 deletion grammar/MetaPrompt.g4
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ text: CHAR+ ;

LB : '[';
RB : ']';
EQ_KW : '=' ;
EQ_KW : '=' | '?=' ;
META_PROMPT : [a-zA-Z_]?[a-zA-Z0-9_]* '$' ;
COMMENT_KW : '#' ;
CHAR : ( ESCAPED | .);
Expand Down
29 changes: 15 additions & 14 deletions python/src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from env import Env
from runtime import BaseRuntime
from typing import AsyncGenerator, List
from loader import extract_variables
from loader import extract_parameter_set
from eval_utils.assignment import Assignment
from eval_utils.chat_history import serialize_chat_history

Expand Down Expand Up @@ -86,7 +86,7 @@ async def _eval_ast(ast):
parameters = ast["parameters"]
module_name = ast["module_name"]
loaded_ast = runtime.load_module(module_name)
required_variables = extract_variables(loaded_ast)
required_variables = extract_parameter_set(loaded_ast).required
for required in required_variables:
if required not in parameters:
raise ImportError(
Expand Down Expand Up @@ -117,18 +117,19 @@ async def _eval_ast(ast):
elif ast["type"] == "assign":
var_name = ast["name"]
value = (await _collect_exprs(ast["exprs"])).strip()
if var_name == "STATUS":
runtime.set_status(value)
elif var_name == "ROLE":
if value not in ALLOWED_ROLES:
raise ValueError(
"ROLE variable must be one of "
+ "".join([f"'{role}', " for role in ALLOWED_ROLES])
+ ", you specified: "
+ value
)
yield Assignment("ROLE", value)
env.set(var_name, value)
if ast["required"] or env.get(var_name) is None:
if var_name == "STATUS":
runtime.set_status(value)
elif var_name == "ROLE":
if value not in ALLOWED_ROLES:
raise ValueError(
"ROLE variable must be one of "
+ "".join([f"'{role}', " for role in ALLOWED_ROLES])
+ ", you specified: "
+ value
)
yield Assignment("ROLE", value)
env.set(var_name, value)
elif ast["type"] == "meta":
# Load chat history
chat_id = ast["chat"] if "chat" in ast else None
Expand Down
150 changes: 120 additions & 30 deletions python/src/loader.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,122 @@
def _discover_variables(ast):
from typing import Set


class ParameterSet:
def __init__(
self,
required: Set[str] = None,
optional: Set[str] = None,
assigned: Set[str] = None,
):
if required is None:
required = set()
if optional is None:
optional = set()
if assigned is None:
assigned = set()
self.required = required
self.optional = optional
self.assigned = assigned

def __eq__(self, other):
return (
self.required == other.required
and self.optional == other.optional
and self.assigned == other.assigned
)

def then(self, other):
"""Implements sequential composition of ParameterSets:
params([expr1, expr2]) = params(expr1).then(params(expr2))
"""
return ParameterSet(
# required
self.required.union(other.required.difference(self.assigned)),
# optional
self.optional.union(other.optional)
.difference(self.required)
.difference(other.required),
# assigned
self.assigned.union(other.assigned),
)

def alternative(self, other):
"""Implements parallel composition of ParameterSets, that is used
for :if: either of the alternatives can be executed, so we have to be
conservative.
"""
return ParameterSet(
# required: all of the required variables are required,
# we don't know which branch will be chosen
self.required.union(other.required),
# optional: if something is optional in just one of the branches,
# it is optional in the whole expression, but if it required in
# either, it is required in the whole
self.optional.union(other.optional)
.difference(self.required)
.difference(other.required),
# assigned: something must be assigned in both branches for us to
# be sure it is assigned
self.assigned.intersection(other.assigned),
)

def assign_var(self, name):
"""handles [:name=value]"""
self.assigned.add(name)

def assign_var_optional(self, name):
"""handles [:name?=value]"""
if name not in self.assigned:
self.optional.add(name)
self.assigned.add(name)

def use_var(self, name):
if name not in self.assigned:
self.required.add(name)


def extract_parameter_set(ast):
# TODO: special handling of ROLE, MODEL variables
res = ParameterSet()

if isinstance(ast, list):
for node in ast:
yield from _discover_variables(node)
elif isinstance(ast, dict):
if "type" in ast:
# TODO: evaluate both :if branches in parallel, to cover this case:
# [:if foo :then [:bar=baz] :else [:bar]]
# -- [:bar] should be unbound here, because it is unbound in the
# first branch
if ast["type"] == "comment":
return
elif ast["type"] == "var":
if ast["name"] != "MODEL":
yield {"type": "var", "name": ast["name"]}
elif ast["type"] == "assign":
yield {"type": "assign", "name": ast["name"]}
for key in ast:
yield from _discover_variables(ast[key])


def extract_variables(ast):
variables = set()
assigned = set()
for item in _discover_variables(ast):
match item:
case {"name": name, "type": "var"}:
if name not in assigned:
variables.add(name)
case {"name": name, "type": "assign"}:
assigned.add(name)
return variables
res = res.then(extract_parameter_set(node))
elif isinstance(ast, dict) and "type" in ast:
if ast["type"] == "text":
pass
elif ast["type"] == "metaprompt":
for expr in ast["exprs"]:
res = res.then(extract_parameter_set(expr))
elif ast["type"] == "var":
res.use_var(ast["name"])
elif ast["type"] == "use":
for _, expr in ast["exprs"]:
res = res.then(extract_parameter_set(expr))
elif ast["type"] == "assign":
if ast["required"]:
res.assign_var(ast["name"])
else:
res.assign_var_optional(ast["name"])
elif ast["type"] == "meta":
for expr in ast["exprs"]:
res = res.then(extract_parameter_set(expr))
elif ast["type"] == "exprs":
for expr in ast["exprs"]:
extract_parameter_set(expr, assigned)
elif ast["type"] == "if_then_else":
res = res.then(extract_parameter_set(ast["condition"])).then(
extract_parameter_set(ast["then"]).alternative(
extract_parameter_set(ast["else"])
)
)
else:
raise ValueError(
"extract_parameter_set: unknown AST expression: " + str(ast)
)
else:
raise ValueError(
"extract_parameter_set: unknown AST expression: " + str(ast)
)

return res
2 changes: 2 additions & 0 deletions python/src/parse_metaprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,10 @@ def visitMeta_body(self, ctx: MetaPromptParser.Meta_bodyContext):
for expr in ctx.exprs():
expr_items = self.visit(expr)
exprs.extend(expr_items)
required = ctx.EQ_KW().getText() == "=" # or "?="
return {
"type": "assign",
"required": required,
"name": var_name,
"exprs": _join_text_pieces(exprs),
}
Expand Down
2 changes: 1 addition & 1 deletion python/src/parser/MetaPrompt.interp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ token literal names:
null
'['
']'
'='
null
null
'#'
null
Expand Down
1 change: 0 additions & 1 deletion python/src/parser/MetaPrompt.tokens
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ ELSE_KW=10
VAR_NAME=11
'['=1
']'=2
'='=3
'#'=5
':if'=8
':then'=9
Expand Down
4 changes: 2 additions & 2 deletions python/src/parser/MetaPromptLexer.interp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ token literal names:
null
'['
']'
'='
null
null
'#'
null
Expand Down Expand Up @@ -51,4 +51,4 @@ mode names:
DEFAULT_MODE

atn:
[4, 0, 11, 110, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 3, 3, 39, 8, 3, 1, 3, 5, 3, 42, 8, 3, 10, 3, 12, 3, 45, 9, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 53, 8, 5, 1, 6, 1, 6, 1, 6, 1, 7, 1, 7, 3, 7, 60, 8, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 4, 9, 70, 8, 9, 11, 9, 12, 9, 71, 1, 9, 4, 9, 75, 8, 9, 11, 9, 12, 9, 76, 1, 9, 5, 9, 80, 8, 9, 10, 9, 12, 9, 83, 9, 9, 1, 10, 1, 10, 1, 11, 1, 11, 1, 11, 1, 11, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 14, 1, 14, 1, 14, 5, 14, 106, 8, 14, 10, 14, 12, 14, 109, 9, 14, 0, 0, 15, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 0, 15, 0, 17, 0, 19, 7, 21, 0, 23, 8, 25, 9, 27, 10, 29, 11, 1, 0, 4, 3, 0, 65, 90, 95, 95, 97, 122, 4, 0, 48, 57, 65, 90, 95, 95, 97, 122, 4, 0, 45, 57, 65, 90, 95, 95, 97, 122, 2, 0, 10, 10, 32, 32, 113, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 23, 1, 0, 0, 0, 0, 25, 1, 0, 0, 0, 0, 27, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 1, 31, 1, 0, 0, 0, 3, 33, 1, 0, 0, 0, 5, 35, 1, 0, 0, 0, 7, 38, 1, 0, 0, 0, 9, 48, 1, 0, 0, 0, 11, 52, 1, 0, 0, 0, 13, 54, 1, 0, 0, 0, 15, 59, 1, 0, 0, 0, 17, 61, 1, 0, 0, 0, 19, 63, 1, 0, 0, 0, 21, 84, 1, 0, 0, 0, 23, 86, 1, 0, 0, 0, 25, 90, 1, 0, 0, 0, 27, 96, 1, 0, 0, 0, 29, 102, 1, 0, 0, 0, 31, 32, 5, 91, 0, 0, 32, 2, 1, 0, 0, 0, 33, 34, 5, 93, 0, 0, 34, 4, 1, 0, 0, 0, 35, 36, 5, 61, 0, 0, 36, 6, 1, 0, 0, 0, 37, 39, 7, 0, 0, 0, 38, 37, 1, 0, 0, 0, 38, 39, 1, 0, 0, 0, 39, 43, 1, 0, 0, 0, 40, 42, 7, 1, 0, 0, 41, 40, 1, 0, 0, 0, 42, 45, 1, 0, 0, 0, 43, 41, 1, 0, 0, 0, 43, 44, 1, 0, 0, 0, 44, 46, 1, 0, 0, 0, 45, 43, 1, 0, 0, 0, 46, 47, 5, 36, 0, 0, 47, 8, 1, 0, 0, 0, 48, 49, 5, 35, 0, 0, 49, 10, 1, 0, 0, 0, 50, 53, 3, 13, 6, 0, 51, 53, 9, 0, 0, 0, 52, 50, 1, 0, 0, 0, 52, 51, 1, 0, 0, 0, 53, 12, 1, 0, 0, 0, 54, 55, 3, 17, 8, 0, 55, 56, 3, 15, 7, 0, 56, 14, 1, 0, 0, 0, 57, 60, 3, 1, 0, 0, 58, 60, 3, 17, 8, 0, 59, 57, 1, 0, 0, 0, 59, 58, 1, 0, 0, 0, 60, 16, 1, 0, 0, 0, 61, 62, 5, 92, 0, 0, 62, 18, 1, 0, 0, 0, 63, 64, 5, 58, 0, 0, 64, 65, 5, 117, 0, 0, 65, 66, 5, 115, 0, 0, 66, 67, 5, 101, 0, 0, 67, 69, 1, 0, 0, 0, 68, 70, 3, 21, 10, 0, 69, 68, 1, 0, 0, 0, 70, 71, 1, 0, 0, 0, 71, 69, 1, 0, 0, 0, 71, 72, 1, 0, 0, 0, 72, 74, 1, 0, 0, 0, 73, 75, 7, 2, 0, 0, 74, 73, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 74, 1, 0, 0, 0, 76, 77, 1, 0, 0, 0, 77, 81, 1, 0, 0, 0, 78, 80, 3, 21, 10, 0, 79, 78, 1, 0, 0, 0, 80, 83, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 20, 1, 0, 0, 0, 83, 81, 1, 0, 0, 0, 84, 85, 7, 3, 0, 0, 85, 22, 1, 0, 0, 0, 86, 87, 5, 58, 0, 0, 87, 88, 5, 105, 0, 0, 88, 89, 5, 102, 0, 0, 89, 24, 1, 0, 0, 0, 90, 91, 5, 58, 0, 0, 91, 92, 5, 116, 0, 0, 92, 93, 5, 104, 0, 0, 93, 94, 5, 101, 0, 0, 94, 95, 5, 110, 0, 0, 95, 26, 1, 0, 0, 0, 96, 97, 5, 58, 0, 0, 97, 98, 5, 101, 0, 0, 98, 99, 5, 108, 0, 0, 99, 100, 5, 115, 0, 0, 100, 101, 5, 101, 0, 0, 101, 28, 1, 0, 0, 0, 102, 103, 5, 58, 0, 0, 103, 107, 7, 0, 0, 0, 104, 106, 7, 1, 0, 0, 105, 104, 1, 0, 0, 0, 106, 109, 1, 0, 0, 0, 107, 105, 1, 0, 0, 0, 107, 108, 1, 0, 0, 0, 108, 30, 1, 0, 0, 0, 109, 107, 1, 0, 0, 0, 9, 0, 38, 43, 52, 59, 71, 76, 81, 107, 0]
[4, 0, 11, 113, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 3, 2, 39, 8, 2, 1, 3, 3, 3, 42, 8, 3, 1, 3, 5, 3, 45, 8, 3, 10, 3, 12, 3, 48, 9, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 56, 8, 5, 1, 6, 1, 6, 1, 6, 1, 7, 1, 7, 3, 7, 63, 8, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 1, 9, 4, 9, 73, 8, 9, 11, 9, 12, 9, 74, 1, 9, 4, 9, 78, 8, 9, 11, 9, 12, 9, 79, 1, 9, 5, 9, 83, 8, 9, 10, 9, 12, 9, 86, 9, 9, 1, 10, 1, 10, 1, 11, 1, 11, 1, 11, 1, 11, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 12, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 14, 1, 14, 1, 14, 5, 14, 109, 8, 14, 10, 14, 12, 14, 112, 9, 14, 0, 0, 15, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 0, 15, 0, 17, 0, 19, 7, 21, 0, 23, 8, 25, 9, 27, 10, 29, 11, 1, 0, 4, 3, 0, 65, 90, 95, 95, 97, 122, 4, 0, 48, 57, 65, 90, 95, 95, 97, 122, 4, 0, 45, 57, 65, 90, 95, 95, 97, 122, 2, 0, 10, 10, 32, 32, 117, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 23, 1, 0, 0, 0, 0, 25, 1, 0, 0, 0, 0, 27, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 1, 31, 1, 0, 0, 0, 3, 33, 1, 0, 0, 0, 5, 38, 1, 0, 0, 0, 7, 41, 1, 0, 0, 0, 9, 51, 1, 0, 0, 0, 11, 55, 1, 0, 0, 0, 13, 57, 1, 0, 0, 0, 15, 62, 1, 0, 0, 0, 17, 64, 1, 0, 0, 0, 19, 66, 1, 0, 0, 0, 21, 87, 1, 0, 0, 0, 23, 89, 1, 0, 0, 0, 25, 93, 1, 0, 0, 0, 27, 99, 1, 0, 0, 0, 29, 105, 1, 0, 0, 0, 31, 32, 5, 91, 0, 0, 32, 2, 1, 0, 0, 0, 33, 34, 5, 93, 0, 0, 34, 4, 1, 0, 0, 0, 35, 39, 5, 61, 0, 0, 36, 37, 5, 63, 0, 0, 37, 39, 5, 61, 0, 0, 38, 35, 1, 0, 0, 0, 38, 36, 1, 0, 0, 0, 39, 6, 1, 0, 0, 0, 40, 42, 7, 0, 0, 0, 41, 40, 1, 0, 0, 0, 41, 42, 1, 0, 0, 0, 42, 46, 1, 0, 0, 0, 43, 45, 7, 1, 0, 0, 44, 43, 1, 0, 0, 0, 45, 48, 1, 0, 0, 0, 46, 44, 1, 0, 0, 0, 46, 47, 1, 0, 0, 0, 47, 49, 1, 0, 0, 0, 48, 46, 1, 0, 0, 0, 49, 50, 5, 36, 0, 0, 50, 8, 1, 0, 0, 0, 51, 52, 5, 35, 0, 0, 52, 10, 1, 0, 0, 0, 53, 56, 3, 13, 6, 0, 54, 56, 9, 0, 0, 0, 55, 53, 1, 0, 0, 0, 55, 54, 1, 0, 0, 0, 56, 12, 1, 0, 0, 0, 57, 58, 3, 17, 8, 0, 58, 59, 3, 15, 7, 0, 59, 14, 1, 0, 0, 0, 60, 63, 3, 1, 0, 0, 61, 63, 3, 17, 8, 0, 62, 60, 1, 0, 0, 0, 62, 61, 1, 0, 0, 0, 63, 16, 1, 0, 0, 0, 64, 65, 5, 92, 0, 0, 65, 18, 1, 0, 0, 0, 66, 67, 5, 58, 0, 0, 67, 68, 5, 117, 0, 0, 68, 69, 5, 115, 0, 0, 69, 70, 5, 101, 0, 0, 70, 72, 1, 0, 0, 0, 71, 73, 3, 21, 10, 0, 72, 71, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 72, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 77, 1, 0, 0, 0, 76, 78, 7, 2, 0, 0, 77, 76, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 84, 1, 0, 0, 0, 81, 83, 3, 21, 10, 0, 82, 81, 1, 0, 0, 0, 83, 86, 1, 0, 0, 0, 84, 82, 1, 0, 0, 0, 84, 85, 1, 0, 0, 0, 85, 20, 1, 0, 0, 0, 86, 84, 1, 0, 0, 0, 87, 88, 7, 3, 0, 0, 88, 22, 1, 0, 0, 0, 89, 90, 5, 58, 0, 0, 90, 91, 5, 105, 0, 0, 91, 92, 5, 102, 0, 0, 92, 24, 1, 0, 0, 0, 93, 94, 5, 58, 0, 0, 94, 95, 5, 116, 0, 0, 95, 96, 5, 104, 0, 0, 96, 97, 5, 101, 0, 0, 97, 98, 5, 110, 0, 0, 98, 26, 1, 0, 0, 0, 99, 100, 5, 58, 0, 0, 100, 101, 5, 101, 0, 0, 101, 102, 5, 108, 0, 0, 102, 103, 5, 115, 0, 0, 103, 104, 5, 101, 0, 0, 104, 28, 1, 0, 0, 0, 105, 106, 5, 58, 0, 0, 106, 110, 7, 0, 0, 0, 107, 109, 7, 1, 0, 0, 108, 107, 1, 0, 0, 0, 109, 112, 1, 0, 0, 0, 110, 108, 1, 0, 0, 0, 110, 111, 1, 0, 0, 0, 111, 30, 1, 0, 0, 0, 112, 110, 1, 0, 0, 0, 10, 0, 38, 41, 46, 55, 62, 74, 79, 84, 110, 0]
Loading

0 comments on commit fb6da7d

Please sign in to comment.