Skip to content

Commit

Permalink
Merge pull request #15 from fetchai/feat/start-session-executing-func…
Browse files Browse the repository at this point in the history
…tion

session: execute function
  • Loading branch information
qati authored Sep 18, 2024
2 parents 631cf1e + a08f118 commit 229a0c0
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 13 deletions.
37 changes: 28 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,42 @@ session = await ai_engine.create_session(function_group=public_group.uuid)
```python
await session.start(objective)
```



#### Querying new messages

You might want to query new messages regularly ...



You might want to query new messages regularly ...
```python

while True:
messages: list[ApiBaseMessage] = await session.get_messages()
# throttling
sleep(3)
sleep(3)
```


#### Execution a function on demand.
This is the first message that should be sent to the AI Engine for execution the function/s of your choice.
The main difference in here it is the AI Engine won't search, therefore decide for you, what is the apt function to fulfill your needs.

It contains the list of function-ids you want to execute and a function group (for secondary function picks).

Currently only supported by Next Generation personality.
Don't use this if you already sent 'start' message.

```python
# init the AI Engine client
from ai_engine_sdk import AiEngine
ai_engine: AiEngine = AiEngine(api_key)
# Create (do not start) a Session
session = await ai_engine.create_session(function_group=function_group.uuid)

# Execute function. You will receive no response.
await session.execute_function(function_ids=[function_uuid], objective="", context="")

# In order to get some feedback, gather the messages as regular.
while True:
messages: list[ApiBaseMessage] = await session.get_messages()
# throttling
sleep(3)
```
#### Checking the type of the new message

There are 5 different types of messages which are generated by the AI Engine and the SDK implements methods for checking the type of the respective new <code>Message</code>:
Expand Down
11 changes: 10 additions & 1 deletion ai_engine_sdk/api_models/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class ApiMessagePayloadTypes(str, Enum):
START = "start"
USER_JSON = "user_json"
USER_MESSAGE = "user_message"

EXECUTE_FUNCTIONS = "execute_functions"

class ApiMessagePayload(BaseModel):
session_id: str
Expand Down Expand Up @@ -55,6 +55,15 @@ class ApiUserMessageMessage(ApiMessagePayload):
user_message: str


class ApiUserMessageExecuteFunctions(ApiMessagePayload):
type: Literal[ApiMessagePayloadTypes.EXECUTE_FUNCTIONS] = ApiMessagePayloadTypes.EXECUTE_FUNCTIONS

functions: list[str]
objective: str
context: str



# -----------

# class ApiNewSessionResponse(BaseModel):
Expand Down
14 changes: 12 additions & 2 deletions ai_engine_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .api_models.api_models import (
ApiNewSessionRequest,
is_api_context_json,
ApiStartMessage, ApiMessagePayload, ApiUserJsonMessage, ApiUserMessageMessage
ApiStartMessage, ApiMessagePayload, ApiUserJsonMessage, ApiUserMessageMessage, ApiUserMessageExecuteFunctions
)
from .api_models.parsing_utils import get_indexed_task_options_from_raw_api_response
from .llm_models import (
Expand Down Expand Up @@ -352,12 +352,22 @@ async def delete(self):
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}"
)

async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
await self._submit_message(
payload=ApiUserMessageExecuteFunctions.model_validate({
"functions": function_ids,
"objective": objective,
"context": context or "",
'session_id': self.session_id,
})
)

class AiEngine:
def __init__(self, api_key: str, options: Optional[dict] = None):
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_key = api_key


####
# Function groups
####
Expand Down Expand Up @@ -464,7 +474,7 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
if "functions" in raw_response:
list(
map(
lambda function_name: FunctionGroupFunctions.parse_obj({"name": function_name}),
lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}),
raw_response["functions"]
)
)
Expand Down
88 changes: 88 additions & 0 deletions examples/execute_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import argparse
import asyncio
import os
from pprint import pprint

from faker.utils.decorators import lowercase

from ai_engine_sdk import AiEngine, FunctionGroup, ApiBaseMessage
from ai_engine_sdk.client import Session
from tests.conftest import function_groups


async def main(
target_environment: str,
agentverse_api_key: str,
function_uuid: str,
function_group_uuid: str
):
# Request from cli args.
options = {}
if target_environment:
options = {"api_base_url": target_environment}

ai_engine = AiEngine(api_key=agentverse_api_key, options=options)

session: Session = await ai_engine.create_session(function_group=function_group_uuid)
await session.execute_function(function_ids=[function_uuid], objective="", context="")

try:
empty_count = 0
session_ended = False

print("Waiting for execution:")
while empty_count < 100:
messages: list[ApiBaseMessage] = await session.get_messages()
if messages:
pprint(messages)
if any((msg.type.lower() == "stop" for msg in messages)):
print("DONE")
break
if len(messages) % 10 == 0:
print("Wait...")
if len(messages) == 0:
empty_count += 1
else:
empty_count = 0


except Exception as ex:
pprint(ex)
raise

if __name__ == '__main__':
from dotenv import load_dotenv
load_dotenv()
api_key = os.getenv("AV_API_KEY", "")

parser = argparse.ArgumentParser()
parser.add_argument(
"-e",
"--target_environment",
type=str,
required=False,
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
)
parser.add_argument(
"-fg",
"--function_group_uuid",
type=str,
required=True,
)
parser.add_argument(
"-f",
"--function_uuid",
type=str,
required=True,
)
args = parser.parse_args()

result = asyncio.run(
main(
agentverse_api_key=api_key,
target_environment=args.target_environment,
function_group_uuid=args.function_group_uuid,
function_uuid=args.function_uuid
)
)
pprint(result)
63 changes: 63 additions & 0 deletions examples/get_function_from_function_group_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import argparse
import asyncio
import os
from pprint import pprint

from ai_engine_sdk import FunctionGroup, AiEngine
from tests.integration.test_ai_engine_client import api_key


async def main(
function_group_name: str,
agentverse_api_key: str,
target_environment: str | None = None,
):
# Request from cli args.
options = {}
if target_environment:
options = {"api_base_url": target_environment}

ai_engine: AiEngine = AiEngine(api_key=agentverse_api_key, options=options)
function_groups: list[FunctionGroup] = await ai_engine.get_function_groups()

target_function_group = next((g for g in function_groups if g.name == function_group_name), None)
if target_function_group is None:
raise Exception(f'Could not find "{target_function_group}" function group.')

return await ai_engine.get_functions_by_function_group(function_group_id=target_function_group.uuid)



if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
api_key = os.getenv("AV_API_KEY", "")

# Parse CLI arguments
parser = argparse.ArgumentParser()

parser.add_argument(
"-e",
"--target_environment",
type=str,
required=False,
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
)
parser.add_argument(
"-fgn",
"--fg_name",
type=str,
required=True,
)
args = parser.parse_args()

target_environment = args.target_environment

res = asyncio.run(
main(
agentverse_api_key=api_key,
function_group_name=args.fg_name,
target_environment=args.target_environment
)
)
pprint(res)
File renamed without changes.
15 changes: 14 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,17 @@ async def function_groups(ai_engine_client) -> list[FunctionGroup]:
# session: Session = await ai_engine_client.create_session(
# function_group=function_groups, opts={"model": "next-gen"}
# )
# return session
# return session


@pytest.fixture(scope="session")
def valid_public_function_uuid() -> str:
# TODO: Do it programmatically (when test fails bc of it will be good moment)
# 'Cornerstone Software' from Public fg and staging
return "312712ae-eb70-42f7-bb5a-ad21ce6d73c3"


@pytest.fixture(scope="session")
def public_function_group() -> FunctionGroup:
# TODO: Do it programmatically (when test fails bc of it will be good moment)
return FunctionGroup(uuid="e504eabb-4bc7-458d-aa8c-7c3748f8952c", name="Public", isPrivate=False)
10 changes: 10 additions & 0 deletions tests/integration/test_ai_engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ async def test_create_session(self, ai_engine_client: AiEngine):
# await ai_engine_client.delete_function_group()


@pytest.mark.asyncio
async def test_execute_function(self, ai_engine_client: AiEngine, public_function_group: FunctionGroup, valid_public_function_uuid: str):
session: Session = await ai_engine_client.create_session(function_group=public_function_group.uuid)
result = await session.execute_function(
function_ids=[valid_public_function_uuid],
objective="Test software",
context=""
)


@pytest.mark.asyncio
async def test_create_function_group_and_list_them(self, ai_engine_client: AiEngine):
name = fake.company()
Expand Down

0 comments on commit 229a0c0

Please sign in to comment.