diff --git a/README.md b/README.md index d1973d0..78fcb73 100644 --- a/README.md +++ b/README.md @@ -86,23 +86,42 @@ session = await ai_engine.create_session(function_group=public_group.uuid) ```python await session.start(objective) ``` - - + #### Querying new messages -You might want to query new messages regularly ... - - - +You might want to query new messages regularly ... ```python - while True: messages: list[ApiBaseMessage] = await session.get_messages() # throttling - sleep(3) + sleep(3) ``` - +#### Execution a function on demand. +This is the first message that should be sent to the AI Engine for execution the function/s of your choice. +The main difference in here it is the AI Engine won't search, therefore decide for you, what is the apt function to fulfill your needs. + +It contains the list of function-ids you want to execute and a function group (for secondary function picks). + +Currently only supported by Next Generation personality. +Don't use this if you already sent 'start' message. + +```python +# init the AI Engine client +from ai_engine_sdk import AiEngine +ai_engine: AiEngine = AiEngine(api_key) +# Create (do not start) a Session +session = await ai_engine.create_session(function_group=function_group.uuid) + +# Execute function. You will receive no response. +await session.execute_function(function_ids=[function_uuid], objective="", context="") + +# In order to get some feedback, gather the messages as regular. +while True: + messages: list[ApiBaseMessage] = await session.get_messages() + # throttling + sleep(3) +``` #### Checking the type of the new message There are 5 different types of messages which are generated by the AI Engine and the SDK implements methods for checking the type of the respective new Message: diff --git a/ai_engine_sdk/api_models/api_models.py b/ai_engine_sdk/api_models/api_models.py index c162d41..f99364c 100644 --- a/ai_engine_sdk/api_models/api_models.py +++ b/ai_engine_sdk/api_models/api_models.py @@ -8,7 +8,7 @@ class ApiMessagePayloadTypes(str, Enum): START = "start" USER_JSON = "user_json" USER_MESSAGE = "user_message" - + EXECUTE_FUNCTIONS = "execute_functions" class ApiMessagePayload(BaseModel): session_id: str @@ -55,6 +55,15 @@ class ApiUserMessageMessage(ApiMessagePayload): user_message: str +class ApiUserMessageExecuteFunctions(ApiMessagePayload): + type: Literal[ApiMessagePayloadTypes.EXECUTE_FUNCTIONS] = ApiMessagePayloadTypes.EXECUTE_FUNCTIONS + + functions: list[str] + objective: str + context: str + + + # ----------- # class ApiNewSessionResponse(BaseModel): diff --git a/ai_engine_sdk/client.py b/ai_engine_sdk/client.py index e56d358..309b28e 100644 --- a/ai_engine_sdk/client.py +++ b/ai_engine_sdk/client.py @@ -27,7 +27,7 @@ from .api_models.api_models import ( ApiNewSessionRequest, is_api_context_json, - ApiStartMessage, ApiMessagePayload, ApiUserJsonMessage, ApiUserMessageMessage + ApiStartMessage, ApiMessagePayload, ApiUserJsonMessage, ApiUserMessageMessage, ApiUserMessageExecuteFunctions ) from .api_models.parsing_utils import get_indexed_task_options_from_raw_api_response from .llm_models import ( @@ -352,12 +352,22 @@ async def delete(self): endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}" ) + async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None): + await self._submit_message( + payload=ApiUserMessageExecuteFunctions.model_validate({ + "functions": function_ids, + "objective": objective, + "context": context or "", + 'session_id': self.session_id, + }) + ) class AiEngine: def __init__(self, api_key: str, options: Optional[dict] = None): self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url self._api_key = api_key + #### # Function groups #### @@ -464,7 +474,7 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[ if "functions" in raw_response: list( map( - lambda function_name: FunctionGroupFunctions.parse_obj({"name": function_name}), + lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}), raw_response["functions"] ) ) diff --git a/examples/execute_function.py b/examples/execute_function.py new file mode 100644 index 0000000..73531c0 --- /dev/null +++ b/examples/execute_function.py @@ -0,0 +1,88 @@ +import argparse +import asyncio +import os +from pprint import pprint + +from faker.utils.decorators import lowercase + +from ai_engine_sdk import AiEngine, FunctionGroup, ApiBaseMessage +from ai_engine_sdk.client import Session +from tests.conftest import function_groups + + +async def main( + target_environment: str, + agentverse_api_key: str, + function_uuid: str, + function_group_uuid: str +): + # Request from cli args. + options = {} + if target_environment: + options = {"api_base_url": target_environment} + + ai_engine = AiEngine(api_key=agentverse_api_key, options=options) + + session: Session = await ai_engine.create_session(function_group=function_group_uuid) + await session.execute_function(function_ids=[function_uuid], objective="", context="") + + try: + empty_count = 0 + session_ended = False + + print("Waiting for execution:") + while empty_count < 100: + messages: list[ApiBaseMessage] = await session.get_messages() + if messages: + pprint(messages) + if any((msg.type.lower() == "stop" for msg in messages)): + print("DONE") + break + if len(messages) % 10 == 0: + print("Wait...") + if len(messages) == 0: + empty_count += 1 + else: + empty_count = 0 + + + except Exception as ex: + pprint(ex) + raise + +if __name__ == '__main__': + from dotenv import load_dotenv + load_dotenv() + api_key = os.getenv("AV_API_KEY", "") + + parser = argparse.ArgumentParser() + parser.add_argument( + "-e", + "--target_environment", + type=str, + required=False, + help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production." + ) + parser.add_argument( + "-fg", + "--function_group_uuid", + type=str, + required=True, + ) + parser.add_argument( + "-f", + "--function_uuid", + type=str, + required=True, + ) + args = parser.parse_args() + + result = asyncio.run( + main( + agentverse_api_key=api_key, + target_environment=args.target_environment, + function_group_uuid=args.function_group_uuid, + function_uuid=args.function_uuid + ) + ) + pprint(result) diff --git a/examples/get_function_from_function_group_name.py b/examples/get_function_from_function_group_name.py new file mode 100644 index 0000000..6152a6c --- /dev/null +++ b/examples/get_function_from_function_group_name.py @@ -0,0 +1,63 @@ +import argparse +import asyncio +import os +from pprint import pprint + +from ai_engine_sdk import FunctionGroup, AiEngine +from tests.integration.test_ai_engine_client import api_key + + +async def main( + function_group_name: str, + agentverse_api_key: str, + target_environment: str | None = None, +): + # Request from cli args. + options = {} + if target_environment: + options = {"api_base_url": target_environment} + + ai_engine: AiEngine = AiEngine(api_key=agentverse_api_key, options=options) + function_groups: list[FunctionGroup] = await ai_engine.get_function_groups() + + target_function_group = next((g for g in function_groups if g.name == function_group_name), None) + if target_function_group is None: + raise Exception(f'Could not find "{target_function_group}" function group.') + + return await ai_engine.get_functions_by_function_group(function_group_id=target_function_group.uuid) + + + +if __name__ == "__main__": + from dotenv import load_dotenv + load_dotenv() + api_key = os.getenv("AV_API_KEY", "") + + # Parse CLI arguments + parser = argparse.ArgumentParser() + + parser.add_argument( + "-e", + "--target_environment", + type=str, + required=False, + help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production." + ) + parser.add_argument( + "-fgn", + "--fg_name", + type=str, + required=True, + ) + args = parser.parse_args() + + target_environment = args.target_environment + + res = asyncio.run( + main( + agentverse_api_key=api_key, + function_group_name=args.fg_name, + target_environment=args.target_environment + ) + ) + pprint(res) \ No newline at end of file diff --git a/examples/functions_by_function_group.py b/examples/list_functions_by_function_group_id.py similarity index 100% rename from examples/functions_by_function_group.py rename to examples/list_functions_by_function_group_id.py diff --git a/tests/conftest.py b/tests/conftest.py index b76164a..d9ea123 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,4 +34,17 @@ async def function_groups(ai_engine_client) -> list[FunctionGroup]: # session: Session = await ai_engine_client.create_session( # function_group=function_groups, opts={"model": "next-gen"} # ) -# return session \ No newline at end of file +# return session + + +@pytest.fixture(scope="session") +def valid_public_function_uuid() -> str: + # TODO: Do it programmatically (when test fails bc of it will be good moment) + # 'Cornerstone Software' from Public fg and staging + return "312712ae-eb70-42f7-bb5a-ad21ce6d73c3" + + +@pytest.fixture(scope="session") +def public_function_group() -> FunctionGroup: + # TODO: Do it programmatically (when test fails bc of it will be good moment) + return FunctionGroup(uuid="e504eabb-4bc7-458d-aa8c-7c3748f8952c", name="Public", isPrivate=False) \ No newline at end of file diff --git a/tests/integration/test_ai_engine_client.py b/tests/integration/test_ai_engine_client.py index af6ef08..9c31a3b 100644 --- a/tests/integration/test_ai_engine_client.py +++ b/tests/integration/test_ai_engine_client.py @@ -60,6 +60,16 @@ async def test_create_session(self, ai_engine_client: AiEngine): # await ai_engine_client.delete_function_group() + @pytest.mark.asyncio + async def test_execute_function(self, ai_engine_client: AiEngine, public_function_group: FunctionGroup, valid_public_function_uuid: str): + session: Session = await ai_engine_client.create_session(function_group=public_function_group.uuid) + result = await session.execute_function( + function_ids=[valid_public_function_uuid], + objective="Test software", + context="" + ) + + @pytest.mark.asyncio async def test_create_function_group_and_list_them(self, ai_engine_client: AiEngine): name = fake.company()