-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement optional assignments and proper parameter extraction for if…
… exprs
- Loading branch information
Showing
12 changed files
with
326 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,122 @@ | ||
def _discover_variables(ast): | ||
from typing import Set | ||
|
||
|
||
class ParameterSet: | ||
def __init__( | ||
self, | ||
required: Set[str] = None, | ||
optional: Set[str] = None, | ||
assigned: Set[str] = None, | ||
): | ||
if required is None: | ||
required = set() | ||
if optional is None: | ||
optional = set() | ||
if assigned is None: | ||
assigned = set() | ||
self.required = required | ||
self.optional = optional | ||
self.assigned = assigned | ||
|
||
def __eq__(self, other): | ||
return ( | ||
self.required == other.required | ||
and self.optional == other.optional | ||
and self.assigned == other.assigned | ||
) | ||
|
||
def then(self, other): | ||
"""Implements sequential composition of ParameterSets: | ||
params([expr1, expr2]) = params(expr1).then(params(expr2)) | ||
""" | ||
return ParameterSet( | ||
# required | ||
self.required.union(other.required.difference(self.assigned)), | ||
# optional | ||
self.optional.union(other.optional) | ||
.difference(self.required) | ||
.difference(other.required), | ||
# assigned | ||
self.assigned.union(other.assigned), | ||
) | ||
|
||
def alternative(self, other): | ||
"""Implements parallel composition of ParameterSets, that is used | ||
for :if: either of the alternatives can be executed, so we have to be | ||
conservative. | ||
""" | ||
return ParameterSet( | ||
# required: all of the required variables are required, | ||
# we don't know which branch will be chosen | ||
self.required.union(other.required), | ||
# optional: if something is optional in just one of the branches, | ||
# it is optional in the whole expression, but if it required in | ||
# either, it is required in the whole | ||
self.optional.union(other.optional) | ||
.difference(self.required) | ||
.difference(other.required), | ||
# assigned: something must be assigned in both branches for us to | ||
# be sure it is assigned | ||
self.assigned.intersection(other.assigned), | ||
) | ||
|
||
def assign_var(self, name): | ||
"""handles [:name=value]""" | ||
self.assigned.add(name) | ||
|
||
def assign_var_optional(self, name): | ||
"""handles [:name?=value]""" | ||
if name not in self.assigned: | ||
self.optional.add(name) | ||
self.assigned.add(name) | ||
|
||
def use_var(self, name): | ||
if name not in self.assigned: | ||
self.required.add(name) | ||
|
||
|
||
def extract_parameter_set(ast): | ||
# TODO: special handling of ROLE, MODEL variables | ||
res = ParameterSet() | ||
|
||
if isinstance(ast, list): | ||
for node in ast: | ||
yield from _discover_variables(node) | ||
elif isinstance(ast, dict): | ||
if "type" in ast: | ||
# TODO: evaluate both :if branches in parallel, to cover this case: | ||
# [:if foo :then [:bar=baz] :else [:bar]] | ||
# -- [:bar] should be unbound here, because it is unbound in the | ||
# first branch | ||
if ast["type"] == "comment": | ||
return | ||
elif ast["type"] == "var": | ||
if ast["name"] != "MODEL": | ||
yield {"type": "var", "name": ast["name"]} | ||
elif ast["type"] == "assign": | ||
yield {"type": "assign", "name": ast["name"]} | ||
for key in ast: | ||
yield from _discover_variables(ast[key]) | ||
|
||
|
||
def extract_variables(ast): | ||
variables = set() | ||
assigned = set() | ||
for item in _discover_variables(ast): | ||
match item: | ||
case {"name": name, "type": "var"}: | ||
if name not in assigned: | ||
variables.add(name) | ||
case {"name": name, "type": "assign"}: | ||
assigned.add(name) | ||
return variables | ||
res = res.then(extract_parameter_set(node)) | ||
elif isinstance(ast, dict) and "type" in ast: | ||
if ast["type"] == "text": | ||
pass | ||
elif ast["type"] == "metaprompt": | ||
for expr in ast["exprs"]: | ||
res = res.then(extract_parameter_set(expr)) | ||
elif ast["type"] == "var": | ||
res.use_var(ast["name"]) | ||
elif ast["type"] == "use": | ||
for _, expr in ast["exprs"]: | ||
res = res.then(extract_parameter_set(expr)) | ||
elif ast["type"] == "assign": | ||
if ast["required"]: | ||
res.assign_var(ast["name"]) | ||
else: | ||
res.assign_var_optional(ast["name"]) | ||
elif ast["type"] == "meta": | ||
for expr in ast["exprs"]: | ||
res = res.then(extract_parameter_set(expr)) | ||
elif ast["type"] == "exprs": | ||
for expr in ast["exprs"]: | ||
extract_parameter_set(expr, assigned) | ||
elif ast["type"] == "if_then_else": | ||
res = res.then(extract_parameter_set(ast["condition"])).then( | ||
extract_parameter_set(ast["then"]).alternative( | ||
extract_parameter_set(ast["else"]) | ||
) | ||
) | ||
else: | ||
raise ValueError( | ||
"extract_parameter_set: unknown AST expression: " + str(ast) | ||
) | ||
else: | ||
raise ValueError( | ||
"extract_parameter_set: unknown AST expression: " + str(ast) | ||
) | ||
|
||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ token literal names: | |
null | ||
'[' | ||
']' | ||
'=' | ||
null | ||
null | ||
'#' | ||
null | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,6 @@ ELSE_KW=10 | |
VAR_NAME=11 | ||
'['=1 | ||
']'=2 | ||
'='=3 | ||
'#'=5 | ||
':if'=8 | ||
':then'=9 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.