-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rework runtime system: untie it from config and env
- Loading branch information
Showing
12 changed files
with
202 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,5 +30,4 @@ pytest==8.3.3 | |
pytest-asyncio==0.24.0 | ||
python-dotenv==1.0.1 | ||
sniffio==1.3.1 | ||
tqdm==4.67.0 | ||
typing_extensions==4.12.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,13 @@ | ||
class Env: | ||
class Env(dict): | ||
|
||
def __init__(self, env={}, parent=None): | ||
self.env = env | ||
self.parent = parent | ||
def __init__(self, *args, **kwargs): | ||
self.env = dict(*args, **kwargs) | ||
|
||
def set(self, variable, value): | ||
self.env[variable] = value | ||
|
||
def get(self, variable): | ||
if variable in self.env: | ||
return self.env[variable] | ||
elif self.parent is not None: | ||
self.parent.lookup(variable) | ||
else: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,128 +1,150 @@ | ||
from config import Config | ||
from env import Env | ||
from runtime import Runtime | ||
from runtime import BaseRuntime | ||
from typing import AsyncGenerator | ||
from loader import extract_variables | ||
|
||
|
||
async def _eval_exprs(exprs, runtime): | ||
"""A helper for eval_ast""" | ||
for expr in exprs: | ||
async for chunk in eval_ast(expr, runtime): | ||
yield chunk | ||
|
||
|
||
async def _collect_exprs(exprs, runtime): | ||
res = "" | ||
for expr in exprs: | ||
async for chunk in eval_ast(expr, runtime): | ||
res += chunk | ||
return res | ||
|
||
|
||
IF_PROMPT = """Please determine if the following statement is true. | ||
Do not write any other output, answer just "true" or "false". | ||
The statement: | ||
""" | ||
|
||
|
||
async def eval_ast(ast, runtime): | ||
if isinstance(ast, list): | ||
# TODO: is this case needed? | ||
async for expr in _eval_exprs(ast, runtime): | ||
yield expr | ||
elif ast["type"] == "text": | ||
yield ast["text"] | ||
elif ast["type"] == "metaprompt": | ||
async for expr in _eval_exprs(ast["exprs"], runtime): | ||
yield expr | ||
elif ast["type"] == "var": | ||
value = runtime.env.get(ast["name"]) | ||
if value is None: | ||
raise ValueError(f"Failed to look up: {ast['name']}") | ||
else: | ||
yield value | ||
elif ast["type"] == "use": | ||
parameters = ast["parameters"] | ||
module_name = ast["module_name"] | ||
loaded_ast = runtime.load_module(module_name) | ||
required_variables = extract_variables(loaded_ast) | ||
for required in required_variables: | ||
if required not in parameters: | ||
raise ImportError( | ||
f"Module {module_name} requires {required} as a parameter, but it was not provided" | ||
) | ||
old_env = runtime.env | ||
# TODO: persist some variables? | ||
evaluated_parameters = {} | ||
for parameter in parameters: | ||
evaluated_parameters[parameter] = await _collect_exprs( | ||
parameters[parameter], runtime | ||
) | ||
if "MODEL" not in evaluated_parameters: | ||
evaluated_parameters["MODEL"] = old_env.get("MODEL") | ||
runtime.env = Env(evaluated_parameters) | ||
async for expr in eval_ast(loaded_ast, runtime): | ||
yield expr | ||
runtime.env = old_env | ||
elif ast["type"] == "assign": | ||
var_name = ast["name"] | ||
value = await _collect_exprs(ast["exprs"], runtime) | ||
runtime.set_variable(var_name, value) | ||
elif ast["type"] == "meta": | ||
chunks = [] | ||
for expr in ast["exprs"]: | ||
async for chunk in eval_ast(expr, runtime): | ||
chunks.append(chunk) | ||
prompt = "".join(chunks) | ||
async for chunk in runtime.stream_invoke(prompt): | ||
async def eval_ast(ast, config, runtime): | ||
env = Env(**config.parameters) | ||
env.set("MODEL", config.model.strip()) | ||
|
||
async def _eval_exprs(exprs): | ||
"""A helper for eval_ast""" | ||
for expr in exprs: | ||
async for chunk in _eval_ast(expr): | ||
yield chunk | ||
|
||
async def _collect_exprs(exprs): | ||
"""_eval_ast, but returns everything as text""" | ||
res = "" | ||
for expr in exprs: | ||
async for chunk in _eval_ast(expr): | ||
res += chunk | ||
return res | ||
|
||
def get_current_model_provider(): | ||
nonlocal env | ||
model_name = env.get("MODEL").strip() # can't be empty | ||
provider = config.providers.get(model_name) | ||
if provider is None: | ||
raise ValueError(f"Model not available: {model_name}") | ||
return provider | ||
|
||
async def stream_invoke(prompt: str) -> AsyncGenerator[str, None]: | ||
provider = get_current_model_provider() | ||
async for chunk in provider.ainvoke(prompt, "user"): | ||
yield chunk | ||
elif ast["type"] == "exprs": | ||
for expr in ast["exprs"]: | ||
async for chunk in eval_ast(expr, runtime): | ||
|
||
async def invoke(self, prompt: str) -> str: | ||
res = "" | ||
async for chunk in self.stream_invoke(prompt): | ||
res += chunk | ||
return res | ||
|
||
async def _eval_ast(ast): | ||
nonlocal env, runtime | ||
if isinstance(ast, list): | ||
# TODO: is this case needed? | ||
async for chunk in _eval_exprs(ast): | ||
yield chunk | ||
elif ast["type"] == "text": | ||
yield ast["text"] | ||
elif ast["type"] == "metaprompt": | ||
async for chunk in _eval_exprs(ast["exprs"]): | ||
yield chunk | ||
elif ast["type"] == "if_then_else": | ||
# evaluate the conditional | ||
condition_chunks = [] | ||
async for chunk in eval_ast(ast["condition"], runtime): | ||
condition_chunks.append(chunk) | ||
condition = "".join(condition_chunks) | ||
prompt_result = "" | ||
MAX_RETRIES = 3 | ||
retries = 0 | ||
prompt = IF_PROMPT + condition | ||
while prompt_result != "true" and prompt_result != "false": | ||
if retries >= MAX_RETRIES: | ||
raise ValueError( | ||
"Failed to answer :if prompt: " | ||
+ prompt | ||
+ "\nOutput: " | ||
+ prompt_result | ||
elif ast["type"] == "var": | ||
value = env.get(ast["name"]) | ||
if value is None: | ||
raise ValueError(f"Failed to look up: {ast['name']}") | ||
else: | ||
yield value | ||
elif ast["type"] == "use": | ||
parameters = ast["parameters"] | ||
module_name = ast["module_name"] | ||
loaded_ast = runtime.load_module(module_name) | ||
required_variables = extract_variables(loaded_ast) | ||
for required in required_variables: | ||
if required not in parameters: | ||
raise ImportError( | ||
f"Module {module_name} requires {required} as a parameter, but it was not provided" | ||
) | ||
# TODO: persist some variables? | ||
evaluated_parameters = {} | ||
for parameter in parameters: | ||
evaluated_parameters[parameter] = await _collect_exprs( | ||
parameters[parameter] | ||
) | ||
prompt_result = await runtime.invoke(prompt) | ||
prompt_result = prompt_result.strip() | ||
retries += 1 | ||
if prompt_result == "true": | ||
async for chunk in eval_ast(ast["then"], runtime): | ||
old_env = env | ||
if "MODEL" not in evaluated_parameters: | ||
evaluated_parameters["MODEL"] = old_env.get("MODEL") | ||
env = Env(evaluated_parameters) | ||
async for chunk in _eval_ast(loaded_ast): | ||
yield chunk | ||
else: # false | ||
async for chunk in eval_ast(ast["else"], runtime): | ||
env = old_env | ||
elif ast["type"] == "assign": | ||
var_name = ast["name"] | ||
value = (await _collect_exprs(ast["exprs"])).strip() | ||
if var_name == "STATUS": | ||
runtime.set_status(value) | ||
env.set(var_name, value) | ||
elif ast["type"] == "meta": | ||
chunks = [] | ||
for expr in ast["exprs"]: | ||
async for chunk in _eval_ast(expr): | ||
chunks.append(chunk) | ||
prompt = "".join(chunks) | ||
async for chunk in stream_invoke(prompt): | ||
yield chunk | ||
else: | ||
raise ValueError("Runtime AST evaluation error: " + str(ast)) | ||
elif ast["type"] == "exprs": | ||
for expr in ast["exprs"]: | ||
async for chunk in _eval_ast(expr): | ||
yield chunk | ||
elif ast["type"] == "if_then_else": | ||
# evaluate the conditional | ||
condition = await _collect_exprs(ast["condition"]) | ||
prompt_result = "" | ||
MAX_RETRIES = 3 | ||
retries = 0 | ||
prompt = IF_PROMPT + condition | ||
while prompt_result != "true" and prompt_result != "false": | ||
if retries >= MAX_RETRIES: | ||
raise ValueError( | ||
"Failed to answer :if prompt: " | ||
+ prompt | ||
+ "\nOutput: " | ||
+ prompt_result | ||
) | ||
prompt_result = (await invoke(prompt)).strip() | ||
retries += 1 | ||
if prompt_result == "true": | ||
async for chunk in _eval_ast(ast["then"]): | ||
yield chunk | ||
else: # false | ||
async for chunk in _eval_ast(ast["else"]): | ||
yield chunk | ||
else: | ||
raise ValueError("Runtime AST evaluation error: " + str(ast)) | ||
|
||
async for chunk in _eval_ast(ast): | ||
yield chunk | ||
|
||
|
||
async def stream_eval_metaprompt( | ||
metaprompt, config: Config | ||
metaprompt, config: Config, runtime: BaseRuntime | ||
) -> AsyncGenerator[str, None]: | ||
env = Env(env=config.parameters) | ||
runtime = Runtime(config, env) | ||
async for chunk in eval_ast(metaprompt, runtime): | ||
yield chunk | ||
|
||
|
||
async def eval_metaprompt(metaprompt, config: Config): | ||
async def eval_metaprompt(metaprompt, config: Config, runtime: BaseRuntime): | ||
res = "" | ||
async for chunk in stream_eval_metaprompt(metaprompt, config): | ||
async for chunk in stream_eval_metaprompt(metaprompt, config, runtime): | ||
res += chunk | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.