Skip to content

Commit

Permalink
refactor: shifting memory mechanism further down the call stack
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alherrera authored and sfc-gh-twhite committed Dec 6, 2024
1 parent 99d1241 commit dedcbf6
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions agent_gateway/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
max_retries: int = 2,
planner_llm: str = "mistral-large2",
agent_llm: str = "mistral-large2",
memory: bool = True,
planner_example_prompt: str = SNOWFLAKE_PLANNER_PROMPT,
planner_example_prompt_replan: Optional[str] = None,
planner_stop: Optional[list[str]] = [END_OF_PLAN],
Expand All @@ -155,6 +156,7 @@ def __init__(
max_retries: Maximum number of replans to do. Defaults to 2.
planner_llm: Name of Snowflake Cortex LLM to use for planning.
agent_llm: Name of Snowflake Cortex LLM to use for planning.
memory: Boolean to turn on memory mechanism or not. Defaults to False.
planner_example_prompt: Example prompt for planning. Defaults to SNOWFLAKE_PLANNER_PROMPT.
planner_example_prompt_replan: Example prompt for replanning.
Assign this if you want to use different example prompt for replanning.
Expand Down Expand Up @@ -192,7 +194,9 @@ def __init__(
self.max_retries = max_retries

# basic memory
self.memory = []
self.memory = memory
if self.memory:
self.memory_context = []

# callbacks
self.planner_callback = None
Expand Down Expand Up @@ -271,6 +275,7 @@ def _extract_replan_message(self, raw_answer):
"rephrasing your request or validate that the provided tools contain "
"sufficient information."
)

def _generate_context_for_replanner(
self, tasks: Mapping[int, Task], fusion_thought: str
) -> str:
Expand Down Expand Up @@ -343,9 +348,6 @@ def __call__(self, input: str):
result = []
error = []

if len(self.memory) >= 1:
input = f"My previous question/answer was: {self.memory[0]}\n. If needed, use that context and this {input} to answer my question. Otherwise just give me an answer to: {input} "

thread = threading.Thread(target=self.run_async, args=(input, result, error))
thread.start()
thread.join()
Expand All @@ -356,10 +358,6 @@ def __call__(self, input: str):
if not result:
raise AgentGatewayError("Unable to retrieve response. Result is empty.")

max_memory = 3 # TODO consider exposing this to users
if len(self.memory) <= max_memory:
self.memory.append({"Question:": input, "Answer": result[0]})

return result[0]

def handle_exception(self, loop, context):
Expand Down Expand Up @@ -396,12 +394,17 @@ def run_async(self, input, result, error):
async def acall(
self,
input: str,
# inputs: Dict[str, Any]
) -> Dict[str, Any]:
contexts = []
fusion_thought = ""
agent_scratchpad = ""
inputs = {"input": input}

if self.memory:
input_with_mem = f"My previous question/answer was: {self.memory_context}\n. If needed, use that context and this {input} to answer my question. Otherwise just give me an answer to: {input} "
inputs = {"input": input_with_mem}
else:
inputs = {"input": input}

for i in range(self.max_retries):
is_first_iter = i == 0
is_final_iter = i == self.max_retries - 1
Expand Down Expand Up @@ -464,4 +467,9 @@ async def acall(
formatted_contexts = self._format_contexts(contexts)
inputs["context"] = formatted_contexts

max_memory = 3 # TODO consider exposing this to users

if len(self.memory_context) <= max_memory:
self.memory_context.append({"Question:": input, "Answer": answer})

return answer

0 comments on commit dedcbf6

Please sign in to comment.