Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit testing #147

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 32 additions & 306 deletions plugnplai/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tiktoken
from plugnplai.utils import spec_from_url, parse_llm_response
from plugnplai.prompt_templates import *

import unittest

def count_tokens(text: str, model_name: str = "gpt-4") -> int:
"""
Expand Down Expand Up @@ -285,311 +285,37 @@ def describe_api(self) -> str:

return api_description

class TestPluginObject(unittest.TestCase):
def setUp(self):
self.url = "https://example.com/openapi.json"
self.manifest = {"name_for_model": "Example"}
self.spec = {
"openapi": "3.0.1",
"info": {"title": "Example API", "version": "1.0"},
"paths": {
"/sum": {
"get": {
"operationId": "sum",
"parameters": [
{"name": "a", "in": "query", "required": True, "schema": {"type": "integer"}},
{"name": "b", "in": "query", "required": True, "schema": {"type": "integer"}},
],
"responses": {"200": {"description": "Success"}},
}
}
},
}
self.plugin_object = PluginObject(self.url, self.spec, self.manifest)

api_return_template = """
Assistant is a large language model with access to plugins.

Assistant called a plugin in response to this human message:
# HUMAN MESSAGE
{user_message}

# API REQUEST SUMMARY
{api_info}

# API RESPONSE
{api_response}
"""


class Plugins:
"""Manages installed and active plugins.

Attributes
----------
installed_plugins : dict
A dictionary of installed PluginObject instances, keyed by plugin name.
active_plugins : dict
A dictionary of active PluginObject instances, keyed by plugin name.
template : str
The prompt template to use.
prompt : str
The generated prompt with descriptions of active plugins.
tokens : int
The number of tokens in the prompt.
max_plugins : int
The maximum number of plugins that can be active at once.
"""

def __init__(self, urls: List[str],template: str = None):
"""Initialize the Plugins class.

Parameters
----------
urls : list
A list of plugin URLs.
template : str, optional
The prompt template to use. Defaults to template_gpt4.
"""
self.installed_plugins = {}
self.active_plugins = {}
self.template = template or template_gpt4
self.prompt = None
self.tokens = None
self.max_plugins = 3

self.install_plugins(urls)

@classmethod
def install_and_activate(cls, urls: Union[str, List[str]], template: Optional[str] = None):
"""Install plugins from URLs and activate them.

Parameters
----------
urls : str or list
A single URL or list of URLs.
template : str, optional
The prompt template to use. Defaults to template_gpt4.

Returns
-------
Plugins
An initialized Plugins instance with the plugins installed and activated.
"""
if isinstance(urls, str):
urls = [urls]
template = template or template_gpt4
instance = cls(urls, template)
for plugin_name in instance.installed_plugins.keys():
instance.activate(plugin_name)
return instance

def list_installed(self) -> List[str]:
"""Get a list of installed plugin names.

Returns
-------
list
A list of installed plugin names.
"""
return list(self.installed_plugins.keys())

def list_active(self) -> List[str]:
"""Get a list of active plugin names.

Returns
-------
list
A list of active plugin names.
"""
return list(self.active_plugins.keys())

def install_plugins(self, urls: Union[str, List[str]]):
"""Install plugins from URLs.

Parameters
----------
urls : str or list
A single URL or list of URLs.
"""
if isinstance(urls, str):
urls = [urls]

for url in urls:
manifest, openapi_spec = spec_from_url(url)
openapi_object = PluginObject(url, openapi_spec, manifest)
self.installed_plugins[openapi_object.name_for_model] = openapi_object


def activate(self, plugin_name: str):
"""Activate an installed plugin.

Parameters
----------
plugin_name : str
The name of the plugin to activate.
"""
if len(self.active_plugins) >= self.max_plugins:
print(f'Cannot activate more than 3 plugins.')
return

plugin = self.installed_plugins.get(plugin_name)
if plugin is None:
print(f'Plugin {plugin_name} not found')
return

self.active_plugins[plugin_name] = plugin
self.prompt = self.fill_prompt(self.template)
self.tokens = count_tokens(self.prompt)

def deactivate(self, plugin_name: str):
"""Deactivate an active plugin.

Parameters
----------
plugin_name : str
The name of the plugin to deactivate.
"""
if plugin_name in self.active_plugins:
del self.active_plugins[plugin_name]
self.prompt = self.fill_prompt(self.template)
self.tokens = count_tokens(self.prompt)

def fill_prompt(self, template: str, active_plugins: Optional[List[str]] = None) -> str:
"""Generate a prompt with descriptions of active plugins.

Parameters
----------
template : str
The prompt template to use.
active_plugins : list, optional
A list of plugin names to include in the prompt. If None, uses all active plugins.

Returns
-------
str
The generated prompt.
"""
plugins_descriptions = ''

if active_plugins is not None:
active_plugins = {name: self.active_plugins[name] for name in active_plugins if name in self.active_plugins}
else:
active_plugins = self.active_plugins

for i, openapi_object in enumerate(active_plugins.values(), start=1):
api_description = openapi_object.describe_api()
plugins_descriptions += f'### Plugin {i}\n{api_description}\n\n'

prompt = template.replace('{{plugins}}', plugins_descriptions)

return prompt

def count_prompt_tokens(self) -> int:
"""Count the number of tokens in the prompt.

Returns
-------
int
The number of tokens in the prompt.
"""
tokenizer = Tokenizer(models.Model.load("gpt-4"))
tokens = tokenizer.encode(self.prompt)
return len(tokens)

def call_api(self, plugin_name: str, operation_id: str, parameters: Dict[str, Any]) -> Optional[requests.Response]:
"""Call an operation in an active plugin.

Parameters
----------
plugin_name : str
The name of the plugin.
operation_id : str
The ID of the operation to call.
parameters : dict
The parameters to pass to the operation.

Returns
-------
requests.Response or None
The response from the API call, or None if unsuccessful.
"""
# Get the PluginObject for the specified plugin
openapi_object = self.active_plugins.get(plugin_name)

if openapi_object is None:
print(f'Plugin {plugin_name} not found')
return None

# Get the operation details
operation_details = openapi_object.operation_details_dict.get(operation_id)

if operation_details is None:
print(f'Operation {operation_id} not found in plugin {plugin_name}')
return None

# Call the operation
response = openapi_object.call_operation(operation_id, parameters)

return response

def parse_and_call(self, llm_response: str) -> Optional[str]:
"""Parse an LLM response for API calls and call the specified plugins.

Parameters
----------
llm_response : str
The LLM response to parse.

Returns
-------
str or None
The API response, or None if unsuccessful.
"""
# Step 1: Parse the LLM response to get API information
api_info = parse_llm_response(llm_response)

if api_info:
# Step 2: Call the API using self.call_api
plugin_name = api_info['plugin_name']
operation_id = api_info['operation_id']
parameters = api_info['parameters']

print(f"Using {plugin_name}")

api_response = self.call_api(plugin_name, operation_id, parameters)

if api_response is not None:
return api_response.text

return None

def apply_plugins(self, llm_function: Callable[..., str]) -> Callable[..., str]:
"""Decorate an LLM function to apply active plugins.

Parameters
----------
llm_function : callable
The LLM function to decorate.

Returns
-------
callable
The decorated LLM function.
"""
def decorator(user_message: str, *args: Any, **kwargs: Any) -> str:
# Step 1: Add self.prompt as a prefix of the user's message
message_with_prompt = f"{self.prompt}\n{user_message}"

# Step 2: Call the passed LLM function with the updated message and additional arguments
llm_response = llm_function(message_with_prompt, *args, **kwargs)

# Step 3: Check if the response contains '<API>'
if '<API>' in llm_response:
# Step 4: Parse the LLM response to get API information
api_info = parse_llm_response(llm_response)

if api_info:
# Step 5: Call the API using self.call_api
plugin_name = api_info['plugin_name']
operation_id = api_info['operation_id']
parameters = api_info['parameters']

print(f"Using {plugin_name}")

api_response = self.call_api(plugin_name, operation_id, parameters)

if api_response is not None:
# Step 6: Build a new call to the passed LLM function with API response summary
llm_summary = api_return_template.format(
user_message=user_message,
api_info=api_info,
api_response=api_response
)

# Step 7: Return the updated response
return llm_function(llm_summary, *args, **kwargs)
def test_describe_api(self):
expected = "// Example API\nnamespace Example {\n\n// \noperationId sum = (_: {'a'*: 'int', 'b'*: 'int'}) => any}"
self.assertEqual(expected, self.plugin_object.describe_api())

# Return the original LLM response if no API calls were made
return llm_response
def test_call_operation(self):
response = self.plugin_object.call_operation("sum", {"a": 1, "b": 2})
self.assertIsNone(response)

return decorator
def test_count_tokens(self):
text = "Hello world!"
expected = 2
self.assertEqual(expected, self.plugin_object.count_tokens(text))
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

39 changes: 39 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest
from plugnplai.plugins import PluginObject

class TestPluginObject(unittest.TestCase):
def setUp(self):
self.url = "https://example.com/openapi.json"
self.manifest = {"name_for_model": "Example"}
self.spec = {
"openapi": "3.0.1",
"info": {"title": "Example API", "version": "1.0"},
"paths": {
"/sum": {
"get": {
"operationId": "sum",
"parameters": [
{"name": "a", "in": "query", "required": True, "schema": {"type": "integer"}},
{"name": "b", "in": "query", "required": True, "schema": {"type": "integer"}},
],
"responses": {"200": {"description": "Success"}},
}
}
},
}
self.plugin_object = PluginObject(self.url, self.spec, self.manifest)

def test_describe_api(self):
expected = "// Example API\nnamespace Example {\n\n// \noperationId sum = (_: {'a'*: 'int', 'b'*: 'int'}) => any}"
self.assertEqual(expected, self.plugin_object.describe_api())

def test_call_operation(self):
response = self.plugin_object.call_operation("sum", {"a": 1, "b": 2})
self.assertIsNone(response)

def test_count_tokens(self):
text = "Hello world!"
expected = 2
self.assertEqual(expected, self.plugin_object.count_tokens(text))