From 8a58dfa90cdc3d26f457cddf3b6a1bd8097cffe1 Mon Sep 17 00:00:00 2001 From: Zac Li Date: Thu, 29 Feb 2024 21:41:15 +0800 Subject: [PATCH] fix: batch transform update for sagemaker reranker integration (#6145) Co-authored-by: Joan Martinez --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 2 +- jina/serve/executors/__init__.py | 5 +- .../runtimes/worker/http_sagemaker_app.py | 51 +++- .../SampleRerankerExecutor/README.md | 2 + .../SampleRerankerExecutor/config.yml | 8 + .../SampleRerankerExecutor/executor.py | 53 ++++ .../SampleRerankerExecutor/requirements.txt | 0 .../docarray_v2/sagemaker/test_embedding.py | 235 +++++++++++++++++ .../docarray_v2/sagemaker/test_reranking.py | 91 +++++++ .../docarray_v2/sagemaker/test_sagemaker.py | 243 ------------------ .../sagemaker/valid_reranker_input.csv | 2 + 12 files changed, 447 insertions(+), 247 deletions(-) create mode 100644 tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/README.md create mode 100644 tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/config.yml create mode 100644 tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/executor.py create mode 100644 tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/requirements.txt create mode 100644 tests/integration/docarray_v2/sagemaker/test_embedding.py create mode 100644 tests/integration/docarray_v2/sagemaker/test_reranking.py delete mode 100644 tests/integration/docarray_v2/sagemaker/test_sagemaker.py create mode 100644 tests/integration/docarray_v2/sagemaker/valid_reranker_input.csv diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 6362dab74d535..4cef01bea8fdc 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -151,7 +151,7 @@ jobs: pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py pytest 
--suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py - pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker/test_sagemaker.py + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/docker echo "flag it as jina for codeoverage" echo "codecov_flag=jina" >> $GITHUB_OUTPUT diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0738c7359eae..312eb6b655ea1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -485,7 +485,7 @@ jobs: pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml 
--timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py - pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker/test_sagemaker.py + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/docker echo "flag it as jina for codeoverage" echo "codecov_flag=jina" >> $GITHUB_OUTPUT diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index fefdcf981360c..a14d64fe22bb6 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -68,7 +68,10 @@ def is_pydantic_model(annotation: Type) -> bool: :param annotation: The annotation from which to extract PydantiModel. 
:return: boolean indicating if a Pydantic model is inside the annotation """ - from typing import get_args, get_origin + try: + from typing import get_args, get_origin + except ImportError: + from typing_extensions import get_args, get_origin from pydantic import BaseModel diff --git a/jina/serve/runtimes/worker/http_sagemaker_app.py b/jina/serve/runtimes/worker/http_sagemaker_app.py index db293487bf348..12ef0d2bb476a 100644 --- a/jina/serve/runtimes/worker/http_sagemaker_app.py +++ b/jina/serve/runtimes/worker/http_sagemaker_app.py @@ -76,7 +76,15 @@ def add_post_route( input_doc_list_model=None, output_doc_list_model=None, ): + import json + from typing import List, Type, Union + try: + from typing import get_args, get_origin + except ImportError: + from typing_extensions import get_args, get_origin + from docarray.base_doc.docarray_response import DocArrayResponse + from pydantic import BaseModel, ValidationError, parse_obj_as app_kwargs = dict( path=f'/{endpoint_path.strip("/")}', @@ -145,6 +153,47 @@ async def post(request: Request): detail='Invalid CSV input. 
Please check your input.', ) + def construct_model_from_line( + model: Type[BaseModel], line: List[str] + ) -> BaseModel: + parsed_fields = {} + model_fields = model.__fields__ + + for field_str, (field_name, field_info) in zip( + line, model_fields.items() + ): + field_type = field_info.outer_type_ + + # Handle Union types by attempting to parse each potential type + if get_origin(field_type) is Union: + for possible_type in get_args(field_type): + if possible_type is str: + parsed_fields[field_name] = field_str + break + else: + try: + parsed_fields[field_name] = parse_obj_as( + possible_type, json.loads(field_str) + ) + break + except (json.JSONDecodeError, ValidationError): + continue + # Handle list of nested models + elif get_origin(field_type) is list: + list_item_type = get_args(field_type)[0] + parsed_list = json.loads(field_str) + if issubclass(list_item_type, BaseModel): + parsed_fields[field_name] = parse_obj_as( + List[list_item_type], parsed_list + ) + else: + parsed_fields[field_name] = parsed_list + # Handle direct assignment for basic types + else: + parsed_fields[field_name] = field_info.type_(field_str) + + return model(**parsed_fields) + # NOTE: Sagemaker only supports csv files without header, so we enforce # the header by getting the field names from the input model. # This will also enforce the order of the fields in the csv file. @@ -165,7 +214,7 @@ async def post(request: Request): detail=f'Invalid CSV format. 
Line {line} doesn\'t match ' f'the expected field order {field_names}.', ) - data.append(input_doc_list_model(**dict(zip(field_names, line)))) + data.append(construct_model_from_line(input_doc_list_model, line)) return await process(input_model(data=data)) diff --git a/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/README.md b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/README.md new file mode 100644 index 0000000000000..75ae6efad805e --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/README.md @@ -0,0 +1,2 @@ +# SampleRerankerExecutor + diff --git a/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/config.yml b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/config.yml new file mode 100644 index 0000000000000..18d2799485ce1 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/config.yml @@ -0,0 +1,8 @@ +jtype: SampleRerankerExecutor +py_modules: + - executor.py +metas: + name: SampleRerankerExecutor + description: + url: + keywords: [] diff --git a/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/executor.py b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/executor.py new file mode 100644 index 0000000000000..7c38e3468433e --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/executor.py @@ -0,0 +1,53 @@ +from docarray import BaseDoc, DocList +from pydantic import Field +from typing import Union, Optional, List +from jina import Executor, requests + + +class TextDoc(BaseDoc): + text: str = Field(description="The text of the document", default="") + + +class RerankerInput(BaseDoc): + query: Union[str, TextDoc] + + documents: List[TextDoc] + + top_n: Optional[int] + + +class RankedObjectOutput(BaseDoc): + index: int + document: Optional[TextDoc] + + relevance_score: float + + +class RankedOutput(BaseDoc): + results: DocList[RankedObjectOutput] + + +class 
SampleRerankerExecutor(Executor): + @requests(on="/rerank") + def foo(self, docs: DocList[RerankerInput], **kwargs) -> DocList[RankedOutput]: + ret = [] + for doc in docs: + ret.append( + RankedOutput( + results=[ + RankedObjectOutput( + id=doc.id, + index=0, + document=TextDoc(text="first result"), + relevance_score=-1, + ), + RankedObjectOutput( + id=doc.id, + index=1, + document=TextDoc(text="second result"), + relevance_score=-2, + ), + ] + ) + ) + return DocList[RankedOutput](ret) diff --git a/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/requirements.txt b/tests/integration/docarray_v2/sagemaker/SampleRerankerExecutor/requirements.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/docarray_v2/sagemaker/test_embedding.py b/tests/integration/docarray_v2/sagemaker/test_embedding.py new file mode 100644 index 0000000000000..eb86a6e2178b2 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/test_embedding.py @@ -0,0 +1,235 @@ +import csv +import io +import os +import time + +import pytest +import requests + +from jina import Deployment +from jina.helper import random_port +from jina.orchestrate.pods import Pod +from jina.parsers import set_pod_parser + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +sagemaker_port = 8080 + + +@pytest.fixture +def replica_docker_image_built(): + import docker + + client = docker.from_env() + client.images.build(path=os.path.join(cur_dir), tag='sampler-executor') + client.close() + yield + time.sleep(2) + client = docker.from_env() + client.containers.prune() + + +def test_provider_sagemaker_pod_inference(): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + os.path.join( + os.path.dirname(__file__), "SampleExecutor", "config.yml" + ), + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + resp = 
requests.get(f'http://localhost:{sagemaker_port}/ping') + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint for inference + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f'http://localhost:{sagemaker_port}/invocations', + json={ + 'data': [ + {'text': 'hello world'}, + ] + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 + + +@pytest.mark.parametrize( + "filename", + [ + "valid_input_1.csv", + "valid_input_2.csv", + ], +) +def test_provider_sagemaker_pod_batch_transform_valid(filename): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + os.path.join( + os.path.dirname(__file__), "SampleExecutor", "config.yml" + ), + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # Test `POST /invocations` endpoint for batch-transform with valid input + texts = [] + with open(os.path.join(os.path.dirname(__file__), filename), "r") as f: + csv_data = f.read() + + for line in csv.reader( + io.StringIO(csv_data), + delimiter=",", + quoting=csv.QUOTE_NONE, + escapechar="\\", + ): + texts.append(line[1]) + + resp = requests.post( + f"http://localhost:{sagemaker_port}/invocations", + headers={ + "accept": "application/json", + "content-type": "text/csv", + }, + data=csv_data, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["data"]) == 10 + for idx, d in enumerate(resp_json["data"]): + assert d["text"] == texts[idx] + assert len(d["embeddings"][0]) == 64 + + +def test_provider_sagemaker_pod_batch_transform_invalid(): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + os.path.join( + os.path.dirname(__file__), "SampleExecutor", "config.yml" + ), + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # Test `POST /invocations` 
endpoint for batch-transform with invalid input + with open( + os.path.join(os.path.dirname(__file__), 'invalid_input.csv'), 'r' + ) as f: + csv_data = f.read() + + resp = requests.post( + f'http://localhost:{sagemaker_port}/invocations', + headers={ + 'accept': 'application/json', + 'content-type': 'text/csv', + }, + data=csv_data, + ) + assert resp.status_code == 400 + assert ( + resp.json()['detail'] + == "Invalid CSV format. Line ['abcd'] doesn't match the expected field " + "order ['id', 'text']." + ) + + +def test_provider_sagemaker_deployment_inference(): + dep_port = random_port() + with Deployment(uses=os.path.join( + os.path.dirname(__file__), "SampleExecutor", "config.yml" + ), provider='sagemaker', port=dep_port): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + rsp = requests.get(f'http://localhost:{dep_port}/ping') + assert rsp.status_code == 200 + assert rsp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + rsp = requests.post( + f'http://localhost:{dep_port}/invocations', + json={ + 'data': [ + {'text': 'hello world'}, + ] + }, + ) + assert rsp.status_code == 200 + resp_json = rsp.json() + assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 + + +def test_provider_sagemaker_deployment_inference_docker(replica_docker_image_built): + dep_port = random_port() + with Deployment( + uses='docker://sampler-executor', provider='sagemaker', port=dep_port + ): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + rsp = requests.get(f'http://localhost:{dep_port}/ping') + assert rsp.status_code == 200 + assert rsp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + rsp = requests.post( + f'http://localhost:{dep_port}/invocations', + json={ + 'data': [ + {'text': 'hello world'}, + ] + }, + ) + assert rsp.status_code == 200 + resp_json = rsp.json() + 
assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 + + +@pytest.mark.skip('Sagemaker with Deployment for batch-transform is not supported yet') +def test_provider_sagemaker_deployment_batch(): + dep_port = random_port() + with Deployment(uses=os.path.join( + os.path.dirname(__file__), "SampleExecutor", "config.yml" + ), provider='sagemaker', port=dep_port): + # Test the `POST /invocations` endpoint for batch-transform + with open( + os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r' + ) as f: + csv_data = f.read() + + rsp = requests.post( + f'http://localhost:{dep_port}/invocations', + headers={ + 'accept': 'application/json', + 'content-type': 'text/csv', + }, + data=csv_data, + ) + assert rsp.status_code == 200 + resp_json = rsp.json() + assert len(resp_json['data']) == 10 + for d in resp_json['data']: + assert len(d['embeddings'][0]) == 64 + + +def test_provider_sagemaker_deployment_wrong_port(): + # Sagemaker executor would start on 8080. + # If we use the same port for deployment, it should raise an error. 
+ with pytest.raises(ValueError): + with Deployment(uses=os.path.join( + os.path.dirname(__file__), "SampleExecutor", "config.yml" + ), provider='sagemaker', port=8080): + pass diff --git a/tests/integration/docarray_v2/sagemaker/test_reranking.py b/tests/integration/docarray_v2/sagemaker/test_reranking.py new file mode 100644 index 0000000000000..b5044f566639e --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/test_reranking.py @@ -0,0 +1,91 @@ +import csv +import io +import os + +import requests +from jina.orchestrate.pods import Pod +from jina.parsers import set_pod_parser + +sagemaker_port = 8080 + + +def test_provider_sagemaker_pod_inference(): + args, _ = set_pod_parser().parse_known_args( + [ + "--uses", + os.path.join( + os.path.dirname(__file__), "SampleRerankerExecutor", "config.yml" + ), + "--provider", + "sagemaker", + "serve", # This is added by sagemaker + ] + ) + with Pod(args): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + resp = requests.get(f"http://localhost:{sagemaker_port}/ping") + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint for inference + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f"http://localhost:{sagemaker_port}/invocations", + json={ + "data": { + "documents": [ + {"text": "the dog is in the house"}, + {"text": "hey Peter"}, + ], + "query": "where is the dog", + "top_n": 2, + } + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["data"]) == 1 + assert resp_json["data"][0]["results"][0]["document"]["text"] == "first result" + + +def test_provider_sagemaker_pod_batch_transform_for_reranker_valid(): + args, _ = set_pod_parser().parse_known_args( + [ + "--uses", + os.path.join( + os.path.dirname(__file__), "SampleRerankerExecutor", "config.yml" + ), + "--provider", + "sagemaker", + "serve", # This is added by sagemaker + ] + ) + with Pod(args): + # Test `POST 
/invocations` endpoint for batch-transform with valid input + with open( + os.path.join(os.path.dirname(__file__), "valid_reranker_input.csv"), "r" + ) as f: + csv_data = f.read() + + text = [] + for line in csv.reader( + io.StringIO(csv_data), + delimiter=",", + quoting=csv.QUOTE_NONE, + escapechar="\\", + ): + text.append(line) + + resp = requests.post( + f"http://localhost:{sagemaker_port}/invocations", + headers={ + "accept": "application/json", + "content-type": "text/csv", + }, + data=csv_data, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["data"]) == 2 + assert resp_json["data"][0]["results"][0]["document"]["text"] == "first result" + assert resp_json["data"][1]["results"][1]["document"]["text"] == "second result" diff --git a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py deleted file mode 100644 index 2e785464f0010..0000000000000 --- a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py +++ /dev/null @@ -1,243 +0,0 @@ -import csv -import io -import os -import time -from contextlib import AbstractContextManager - -import pytest -import requests - -from jina import Deployment -from jina.helper import random_port -from jina.orchestrate.pods import Pod -from jina.parsers import set_pod_parser - -cur_dir = os.path.dirname(os.path.abspath(__file__)) -sagemaker_port = 8080 - - -@pytest.fixture -def replica_docker_image_built(): - import docker - - client = docker.from_env() - client.images.build(path=cur_dir, tag='sampler-executor') - client.close() - yield - time.sleep(2) - client = docker.from_env() - client.containers.prune() - - -class chdir(AbstractContextManager): - def __init__(self, path): - self.path = path - self._old_cwd = [] - - def __enter__(self): - self._old_cwd.append(os.getcwd()) - os.chdir(self.path) - - def __exit__(self, *excinfo): - os.chdir(self._old_cwd.pop()) - - -def test_provider_sagemaker_pod_inference(): - with 
chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - args, _ = set_pod_parser().parse_known_args( - [ - '--uses', - 'config.yml', - '--provider', - 'sagemaker', - 'serve', # This is added by sagemaker - ] - ) - with Pod(args): - # Test the `GET /ping` endpoint (added by jina for sagemaker) - resp = requests.get(f'http://localhost:{sagemaker_port}/ping') - assert resp.status_code == 200 - assert resp.json() == {} - - # Test the `POST /invocations` endpoint for inference - # Note: this endpoint is not implemented in the sample executor - resp = requests.post( - f'http://localhost:{sagemaker_port}/invocations', - json={ - 'data': [ - {'text': 'hello world'}, - ] - }, - ) - assert resp.status_code == 200 - resp_json = resp.json() - assert len(resp_json['data']) == 1 - assert len(resp_json['data'][0]['embeddings'][0]) == 64 - - -@pytest.mark.parametrize( - "filename", - [ - "valid_input_1.csv", - "valid_input_2.csv", - ], -) -def test_provider_sagemaker_pod_batch_transform_valid(filename): - with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - args, _ = set_pod_parser().parse_known_args( - [ - '--uses', - 'config.yml', - '--provider', - 'sagemaker', - 'serve', # This is added by sagemaker - ] - ) - with Pod(args): - # Test `POST /invocations` endpoint for batch-transform with valid input - texts = [] - with open(os.path.join(os.path.dirname(__file__), filename), "r") as f: - csv_data = f.read() - - for line in csv.reader( - io.StringIO(csv_data), - delimiter=",", - quoting=csv.QUOTE_NONE, - escapechar="\\", - ): - texts.append(line[1]) - - resp = requests.post( - f"http://localhost:{sagemaker_port}/invocations", - headers={ - "accept": "application/json", - "content-type": "text/csv", - }, - data=csv_data, - ) - assert resp.status_code == 200 - resp_json = resp.json() - assert len(resp_json["data"]) == 10 - for idx, d in enumerate(resp_json["data"]): - assert d["text"] == texts[idx] - assert len(d["embeddings"][0]) == 64 - - -def 
test_provider_sagemaker_pod_batch_transform_invalid(): - with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - args, _ = set_pod_parser().parse_known_args( - [ - '--uses', - 'config.yml', - '--provider', - 'sagemaker', - 'serve', # This is added by sagemaker - ] - ) - with Pod(args): - # Test `POST /invocations` endpoint for batch-transform with invalid input - with open( - os.path.join(os.path.dirname(__file__), 'invalid_input.csv'), 'r' - ) as f: - csv_data = f.read() - - resp = requests.post( - f'http://localhost:{sagemaker_port}/invocations', - headers={ - 'accept': 'application/json', - 'content-type': 'text/csv', - }, - data=csv_data, - ) - assert resp.status_code == 400 - assert ( - resp.json()['detail'] - == "Invalid CSV format. Line ['abcd'] doesn't match the expected field " - "order ['id', 'text']." - ) - - -def test_provider_sagemaker_deployment_inference(): - with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - dep_port = random_port() - with Deployment(uses='config.yml', provider='sagemaker', port=dep_port): - # Test the `GET /ping` endpoint (added by jina for sagemaker) - rsp = requests.get(f'http://localhost:{dep_port}/ping') - assert rsp.status_code == 200 - assert rsp.json() == {} - - # Test the `POST /invocations` endpoint - # Note: this endpoint is not implemented in the sample executor - rsp = requests.post( - f'http://localhost:{dep_port}/invocations', - json={ - 'data': [ - {'text': 'hello world'}, - ] - }, - ) - assert rsp.status_code == 200 - resp_json = rsp.json() - assert len(resp_json['data']) == 1 - assert len(resp_json['data'][0]['embeddings'][0]) == 64 - - -def test_provider_sagemaker_deployment_inference_docker(replica_docker_image_built): - dep_port = random_port() - with Deployment( - uses='docker://sampler-executor', provider='sagemaker', port=dep_port - ): - # Test the `GET /ping` endpoint (added by jina for sagemaker) - rsp = requests.get(f'http://localhost:{dep_port}/ping') - assert 
rsp.status_code == 200 - assert rsp.json() == {} - - # Test the `POST /invocations` endpoint - # Note: this endpoint is not implemented in the sample executor - rsp = requests.post( - f'http://localhost:{dep_port}/invocations', - json={ - 'data': [ - {'text': 'hello world'}, - ] - }, - ) - assert rsp.status_code == 200 - resp_json = rsp.json() - assert len(resp_json['data']) == 1 - assert len(resp_json['data'][0]['embeddings'][0]) == 64 - - -@pytest.mark.skip('Sagemaker with Deployment for batch-transform is not supported yet') -def test_provider_sagemaker_deployment_batch(): - with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - dep_port = random_port() - with Deployment(uses='config.yml', provider='sagemaker', port=dep_port): - # Test the `POST /invocations` endpoint for batch-transform - with open( - os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r' - ) as f: - csv_data = f.read() - - rsp = requests.post( - f'http://localhost:{dep_port}/invocations', - headers={ - 'accept': 'application/json', - 'content-type': 'text/csv', - }, - data=csv_data, - ) - assert rsp.status_code == 200 - resp_json = rsp.json() - assert len(resp_json['data']) == 10 - for d in resp_json['data']: - assert len(d['embeddings'][0]) == 64 - - -def test_provider_sagemaker_deployment_wrong_port(): - # Sagemaker executor would start on 8080. - # If we use the same port for deployment, it should raise an error. 
- with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - with pytest.raises(ValueError): - with Deployment(uses='config.yml', provider='sagemaker', port=8080): - pass diff --git a/tests/integration/docarray_v2/sagemaker/valid_reranker_input.csv b/tests/integration/docarray_v2/sagemaker/valid_reranker_input.csv new file mode 100644 index 0000000000000..5b18db7242a89 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/valid_reranker_input.csv @@ -0,0 +1,2 @@ +07937cd63aed436b9e6c34a86783f55f,where's the dog,[{\"text\": \"the dog is in my house\"}\, {\"text\": \"the cat looks good\"}\, {\"text\": \"fish chips\"}],2 +286076113143433cb5f2655c5f70a30e,do you like cat,[{\"text\": \"cash is king\"}\, {\"text\": \"yes I do like cat\"}\, {\"text\": \"morning\"}], 3