Skip to content

Commit

Permalink
feat: Add Contract Tests for new Gen AI attributes for foundational models (#292)
Browse files Browse the repository at this point in the history

Contract tests for the new gen_ai inference parameters added in #290.

<img width="1563" alt="image"
src="https://github.com/user-attachments/assets/3ea5979d-43b2-43d6-8730-708855969d8a">

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Michael He <[email protected]>
  • Loading branch information
liustve and yiyuan-he authored Nov 22, 2024
1 parent d305721 commit 642427e
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 31 deletions.
164 changes: 148 additions & 16 deletions contract-tests/images/applications/botocore/botocore_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import tempfile
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from io import BytesIO
from threading import Thread

import boto3
import requests
from botocore.client import BaseClient
from botocore.config import Config
from botocore.exceptions import ClientError
from botocore.response import StreamingBody
from typing_extensions import Tuple, override

_PORT: int = 8080
Expand Down Expand Up @@ -285,28 +287,22 @@ def _handle_bedrock_request(self) -> None:
},
)
elif self.in_path("invokemodel/invoke-model"):
model_id, request_body, response_body = get_model_request_response(self.path)

set_main_status(200)
bedrock_runtime_client.meta.events.register(
"before-call.bedrock-runtime.InvokeModel",
inject_200_success,
)
model_id = "amazon.titan-text-premier-v1:0"
user_message = "Describe the purpose of a 'hello world' program in one line."
prompt = f"<s>[INST] {user_message} [/INST]"
body = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": 3072,
"stopSequences": [],
"temperature": 0.7,
"topP": 0.9,
},
}
lambda **kwargs: inject_200_success(
modelId=model_id,
body=response_body,
**kwargs,
),
)
accept = "application/json"
content_type = "application/json"
bedrock_runtime_client.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
bedrock_runtime_client.invoke_model(
body=request_body, modelId=model_id, accept=accept, contentType=content_type
)
else:
set_main_status(404)

Expand Down Expand Up @@ -378,6 +374,137 @@ def _end_request(self, status_code: int):
self.end_headers()


def get_model_request_response(path):
    """Build a mocked Bedrock InvokeModel request/response pair for *path*.

    Scans *path* for a known foundation-model keyword and returns a tuple of
    (model id, JSON-encoded request body, canned JSON response wrapped in a
    StreamingBody). If no keyword matches, the model id is empty and both
    bodies are empty JSON objects. When several keywords appear in *path*,
    the last matching entry in the catalog below wins.
    """
    prompt = "Describe the purpose of a 'hello world' program in one line."

    # Catalog of (path keyword, model id, request payload, canned response payload),
    # one entry per supported foundation-model family.
    model_catalog = (
        (
            "amazon.titan",
            "amazon.titan-text-premier-v1:0",
            {
                "inputText": prompt,
                "textGenerationConfig": {
                    "maxTokenCount": 3072,
                    "stopSequences": [],
                    "temperature": 0.7,
                    "topP": 0.9,
                },
            },
            {
                "inputTextTokenCount": 15,
                "results": [
                    {
                        "tokenCount": 13,
                        "outputText": "text-test-response",
                        "completionReason": "CONTENT_FILTERED",
                    },
                ],
            },
        ),
        (
            "anthropic.claude",
            "anthropic.claude-v2:1",
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 1000,
                "temperature": 0.99,
                "top_p": 1,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": prompt}],
                    },
                ],
            },
            {
                "stop_reason": "end_turn",
                "usage": {
                    "input_tokens": 15,
                    "output_tokens": 13,
                },
            },
        ),
        (
            "meta.llama",
            "meta.llama2-13b-chat-v1",
            {"prompt": prompt, "max_gen_len": 512, "temperature": 0.5, "top_p": 0.9},
            {"prompt_token_count": 31, "generation_token_count": 49, "stop_reason": "stop"},
        ),
        (
            "cohere.command",
            "cohere.command-r-v1:0",
            {
                "chat_history": [],
                "message": prompt,
                "max_tokens": 512,
                "temperature": 0.5,
                "p": 0.65,
            },
            {
                "chat_history": [
                    {"role": "USER", "message": prompt},
                    {"role": "CHATBOT", "message": "test-text-output"},
                ],
                "finish_reason": "COMPLETE",
                "text": "test-generation-text",
            },
        ),
        (
            "ai21.jamba",
            "ai21.jamba-1-5-large-v1:0",
            {
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                    },
                ],
                "top_p": 0.8,
                "temperature": 0.6,
                "max_tokens": 512,
            },
            {
                "stop_reason": "end_turn",
                "usage": {
                    "prompt_tokens": 21,
                    "completion_tokens": 24,
                },
                "choices": [
                    {"finish_reason": "stop"},
                ],
            },
        ),
        (
            "mistral",
            "mistral.mistral-7b-instruct-v0:2",
            {
                "prompt": prompt,
                "max_tokens": 4096,
                "temperature": 0.75,
                "top_p": 0.99,
            },
            {
                "outputs": [
                    {
                        "text": "test-output-text",
                        "stop_reason": "stop",
                    },
                ]
            },
        ),
    )

    model_id = ""
    request_body = {}
    response_body = {}
    # Assign on every match (no break) so later catalog entries override
    # earlier ones, mirroring the original chain of independent `if` blocks.
    for keyword, candidate_id, candidate_request, candidate_response in model_catalog:
        if keyword in path:
            model_id = candidate_id
            request_body = candidate_request
            response_body = candidate_response

    json_bytes = json.dumps(response_body).encode("utf-8")

    return model_id, json.dumps(request_body), StreamingBody(BytesIO(json_bytes), len(json_bytes))


def set_main_status(status: int) -> None:
RequestHandler.main_status = status

Expand Down Expand Up @@ -490,11 +617,16 @@ def inject_200_success(**kwargs):
guardrail_arn = kwargs.get("guardrailArn")
if guardrail_arn is not None:
response_body["guardrailArn"] = guardrail_arn
model_id = kwargs.get("modelId")
if model_id is not None:
response_body["modelId"] = model_id

HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
headers = kwargs.get("headers", {})
body = kwargs.get("body", "")
response_body["body"] = body
http_response = HTTPResponse(200, headers=headers, body=body)

return http_response, response_body


Expand Down
8 changes: 7 additions & 1 deletion contract-tests/tests/test/amazon/base/contract_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str,
self.assertIsNotNone(actual_value)
self.assertEqual(expected_value, actual_value.int_value)

def _assert_float_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: float) -> None:
    """Assert that *key* exists in *attributes_dict* and carries *expected_value* as its double value."""
    self.assertIn(key, attributes_dict)
    attribute: AnyValue = attributes_dict[key]
    self.assertIsNotNone(attribute)
    # Contract-test fixtures use exact canned doubles, so strict equality is intended here.
    self.assertEqual(expected_value, attribute.double_value)

def _assert_match_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, pattern: str) -> None:
self.assertIn(key, attributes_dict)
actual_value: AnyValue = attributes_dict[key]
Expand Down Expand Up @@ -237,5 +243,5 @@ def _is_valid_regex(self, pattern: str) -> bool:
try:
re.compile(pattern)
return True
except re.error:
except (re.error, StopIteration, RuntimeError, KeyError):
return False
Loading

0 comments on commit 642427e

Please sign in to comment.