Skip to content

Commit

Permalink
feat(sagemaker): support batch-transform (#6055)
Browse files Browse the repository at this point in the history
Co-authored-by: Joan Fontanals <[email protected]>
  • Loading branch information
deepankarm and JoanFM authored Sep 28, 2023
1 parent 4ea8bb5 commit 67c83c2
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 72 deletions.
53 changes: 27 additions & 26 deletions jina/orchestrate/deployments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Type, Union, overload

from hubble.executor.helper import replace_secret_of_hub_uri
from hubble.executor.hubio import HubIO
from rich import print
from rich.panel import Panel

Expand Down Expand Up @@ -54,7 +53,6 @@
from jina.orchestrate.pods.factory import PodFactory
from jina.parsers import set_deployment_parser, set_gateway_parser
from jina.parsers.helper import _update_gateway_args
from jina.serve.networking import GrpcConnectionPool
from jina.serve.networking.utils import host_is_local, in_docker

WRAPPED_SLICE_BASE = r'\[[-\d:]+\]'
Expand Down Expand Up @@ -103,7 +101,7 @@ def _call_add_voters(leader, voters, replica_ids, name, event_signal=None):
logger.success(
f'Replica-{str(replica_id)} successfully added as voter with address {voter_address} to leader at {leader}'
)
logger.debug(f'Adding voters to leader finished')
logger.debug('Adding voters to leader finished')
if event_signal:
event_signal.set()

Expand Down Expand Up @@ -138,7 +136,7 @@ def _add_voter_to_leader(self):
voter_addresses = [pod.runtime_ctrl_address for pod in self._pods[1:]]
replica_ids = [pod.args.replica_id for pod in self._pods[1:]]
event_signal = multiprocessing.Event()
self.logger.debug(f'Starting process to call Add Voters')
self.logger.debug('Starting process to call Add Voters')
process = multiprocessing.Process(
target=_call_add_voters,
kwargs={
Expand All @@ -159,19 +157,19 @@ def _add_voter_to_leader(self):
else:
time.sleep(1.0)
if properly_closed:
self.logger.debug(f'Add Voters process finished')
self.logger.debug('Add Voters process finished')
process.terminate()
else:
self.logger.error(f'Add Voters process did not finish successfully')
self.logger.error('Add Voters process did not finish successfully')
process.kill()
self.logger.debug(f'Add Voters process finished')
self.logger.debug('Add Voters process finished')

async def _async_add_voter_to_leader(self):
leader_address = f'{self._pods[0].runtime_ctrl_address}'
voter_addresses = [pod.runtime_ctrl_address for pod in self._pods[1:]]
replica_ids = [pod.args.replica_id for pod in self._pods[1:]]
event_signal = multiprocessing.Event()
self.logger.debug(f'Starting process to call Add Voters')
self.logger.debug('Starting process to call Add Voters')
process = multiprocessing.Process(
target=_call_add_voters,
kwargs={
Expand All @@ -192,10 +190,10 @@ async def _async_add_voter_to_leader(self):
else:
await asyncio.sleep(1.0)
if properly_closed:
self.logger.debug(f'Add Voters process finished')
self.logger.debug('Add Voters process finished')
process.terminate()
else:
self.logger.error(f'Add Voters process did not finish successfully')
self.logger.error('Add Voters process did not finish successfully')
process.kill()

@property
Expand All @@ -214,23 +212,23 @@ def join(self):
pod.join()

def wait_start_success(self):
self.logger.debug(f'Waiting for ReplicaSet to start successfully')
self.logger.debug('Waiting for ReplicaSet to start successfully')
for pod in self._pods:
pod.wait_start_success()
# should this be done only when the cluster is started ?
if self._pods[0].args.stateful:
self._add_voter_to_leader()
self.logger.debug(f'ReplicaSet started successfully')
self.logger.debug('ReplicaSet started successfully')

async def async_wait_start_success(self):
self.logger.debug(f'Waiting for ReplicaSet to start successfully')
self.logger.debug('Waiting for ReplicaSet to start successfully')
await asyncio.gather(
*[pod.async_wait_start_success() for pod in self._pods]
)
# should this be done only when the cluster is started ?
if self._pods[0].args.stateful:
await self._async_add_voter_to_leader()
self.logger.debug(f'ReplicaSet started successfully')
self.logger.debug('ReplicaSet started successfully')

def __enter__(self):
for _args in self.args:
Expand Down Expand Up @@ -481,16 +479,19 @@ def __init__(
if self.args.provider == ProviderType.SAGEMAKER:
if self._gateway_kwargs.get('port', 0) == 8080:
raise ValueError(
f'Port 8080 is reserved for Sagemaker deployment. Please use another port'
'Port 8080 is reserved for Sagemaker deployment. '
'Please use another port'
)
if self.args.port != [8080]:
warnings.warn(
f'Port is changed to 8080 for Sagemaker deployment. Port {self.args.port} is ignored'
'Port is changed to 8080 for Sagemaker deployment. '
f'Port {self.args.port} is ignored'
)
self.args.port = [8080]
if self.args.protocol != [ProtocolType.HTTP]:
warnings.warn(
f'Protocol is changed to HTTP for Sagemaker deployment. Protocol {self.args.protocol} is ignored'
'Protocol is changed to HTTP for Sagemaker deployment. '
f'Protocol {self.args.protocol} is ignored'
)
self.args.protocol = [ProtocolType.HTTP]
if self._include_gateway and ProtocolType.HTTP in self.args.protocol:
Expand Down Expand Up @@ -529,10 +530,10 @@ def __init__(

if self.args.stateful and (is_windows_os or (is_mac_os and is_37)):
if is_windows_os:
raise RuntimeError(f'Stateful feature is not available on Windows')
raise RuntimeError('Stateful feature is not available on Windows')
if is_mac_os:
raise RuntimeError(
f'Stateful feature when running on MacOS requires Python3.8 or newer version'
'Stateful feature when running on MacOS requires Python3.8 or newer version'
)
if self.args.stateful and (
ProtocolType.WEBSOCKET in self.args.protocol
Expand Down Expand Up @@ -805,7 +806,7 @@ def _copy_to_head_args(args: Namespace) -> Namespace:
if args.name:
_head_args.name = f'{args.name}/head'
else:
_head_args.name = f'head'
_head_args.name = 'head'

return _head_args

Expand Down Expand Up @@ -1209,7 +1210,7 @@ async def async_wait_start_success(self) -> None:
coros.append(self.shards[shard_id].async_wait_start_success())

await asyncio.gather(*coros)
self.logger.debug(f'Deployment started successfully')
self.logger.debug('Deployment started successfully')
except:
self.close()
raise
Expand Down Expand Up @@ -1374,7 +1375,7 @@ def _set_pod_args(self) -> Dict[int, List[Namespace]]:
peer_ports = peer_ports_all_shards.get(str(shard_id), [])
if len(peer_ports) > 0 and len(peer_ports) != replicas:
raise ValueError(
f'peer-ports argument does not match number of replicas, it will be ignored'
'peer-ports argument does not match number of replicas, it will be ignored'
)
elif len(peer_ports) == 0:
peer_ports = [random_port() for _ in range(replicas)]
Expand Down Expand Up @@ -1506,12 +1507,12 @@ def _parse_base_deployment_args(self, args):

if self.args.stateful and self.args.replicas in [1, 2]:
self.logger.debug(
f'Stateful Executor is not recommended to be used less than 3 replicas'
'Stateful Executor is not recommended to be used less than 3 replicas'
)

if self.args.stateful and self.args.workspace is None:
raise ValueError(
f'Stateful Executors need to be provided `workspace` when used in a Deployment'
'Stateful Executors need to be provided `workspace` when used in a Deployment'
)

# a gateway has no heads and uses
Expand Down Expand Up @@ -1568,7 +1569,7 @@ def _mermaid_str(self) -> List[str]:
mermaid_graph = []
secret = '&ltsecret&gt'
if self.role != DeploymentRoleType.GATEWAY and not self.external:
mermaid_graph = [f'subgraph {self.name};', f'\ndirection LR;\n']
mermaid_graph = [f'subgraph {self.name};', '\ndirection LR;\n']

uses_before_name = (
self.uses_before_args.name
Expand Down Expand Up @@ -1596,7 +1597,7 @@ def _mermaid_str(self) -> List[str]:
shard_names.append(shard_name)
shard_mermaid_graph = [
f'subgraph {shard_name};',
f'\ndirection TB;\n',
'\ndirection TB;\n',
]
names = [
args.name for args in pod_args
Expand Down
71 changes: 62 additions & 9 deletions jina/serve/runtimes/worker/http_sagemaker_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ def get_fastapi_app(
:return: fastapi app
"""
with ImportExtensions(required=True):
from fastapi import FastAPI, Response, HTTPException
import pydantic
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

import os

from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

from jina.proto import jina_pb2
from jina.serve.runtimes.gateway.models import _to_camel_case

Expand Down Expand Up @@ -83,26 +82,38 @@ def add_post_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Endpoint {endpoint_path}',
response_model=output_model,
response_model=Union[output_model, List[output_model]],
response_class=DocArrayResponse,
)

@app.api_route(**app_kwargs)
async def post(body: input_model, response: Response):
def is_valid_csv(content: str) -> bool:
    """Return True when *content* can be consumed by the csv module.

    Note that the default csv dialect is very permissive, so this is a
    coarse sanity check; per-line field-count validation happens in the
    caller.
    """
    import csv
    from io import StringIO

    try:
        # Drain the reader; any parse error surfaces during iteration.
        for _row in csv.DictReader(StringIO(content)):
            pass
    except Exception:
        return False
    return True

async def process(body) -> output_model:
req = DataRequest()
if body.header is not None:
req.header.request_id = body.header.request_id

if body.parameters is not None:
req.parameters = body.parameters
req.header.exec_endpoint = endpoint_path
req.document_array_cls = DocList[input_doc_model]

data = body.data
if isinstance(data, list):
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_list_model](data)
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_list_model]([data])
if body.header is None:
req.header.request_id = req.docs[0].id
Expand All @@ -115,6 +126,48 @@ async def post(body: input_model, response: Response):
else:
return output_model(data=resp.docs, parameters=resp.parameters)

# Single POST handler that dispatches on the request Content-Type:
# SageMaker invokes endpoints with either JSON or CSV payloads
# (batch transform sends CSV).
@app.api_route(**app_kwargs)
async def post(request: Request):
    """Handle a JSON or CSV request body and forward it to ``process``.

    Raises HTTPException(400) for malformed CSV or an unsupported
    content type.
    """
    content_type = request.headers.get('content-type')
    if content_type == 'application/json':
        json_body = await request.json()
        return await process(input_model(**json_body))

    elif content_type in ('text/csv', 'application/csv'):
        bytes_body = await request.body()
        csv_body = bytes_body.decode('utf-8')
        if not is_valid_csv(csv_body):
            raise HTTPException(
                status_code=400,
                detail='Invalid CSV input. Please check your input.',
            )

        # NOTE: Sagemaker only supports csv files without header, so we enforce
        # the header by getting the field names from the input model.
        # This will also enforce the order of the fields in the csv file.
        # This also means, all fields in the input model must be present in the
        # csv file including the optional ones.
        field_names = [f for f in input_doc_list_model.__fields__]
        data = []
        for line in csv_body.splitlines():
            # NOTE(review): a plain split(',') does not honor CSV quoting —
            # a quoted field containing a comma will fail the count check
            # below. Confirm whether quoted fields need to be supported.
            fields = line.split(',')
            if len(fields) != len(field_names):
                raise HTTPException(
                    status_code=400,
                    detail=f'Invalid CSV format. Line {fields} doesn\'t match '
                    f'the expected field order {field_names}.',
                )
            data.append(input_doc_list_model(**dict(zip(field_names, fields))))

        return await process(input_model(data=data))

    else:
        raise HTTPException(
            status_code=400,
            detail=f'Invalid content-type: {content_type}. '
            f'Please use either application/json or text/csv.',
        )

for endpoint, input_output_map in request_models_map.items():
if endpoint != '_jina_dry_run_':
input_doc_model = input_output_map['input']['model']
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/docarray_v2/sagemaker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test image for the SageMaker integration tests: runs the SampleExecutor
# on top of the jina test-pip base image.
FROM jinaai/jina:test-pip

# Copy the test executor sources into the image.
COPY . /executor_root/

WORKDIR /executor_root/SampleExecutor

# Serve the executor described by its config.yml.
ENTRYPOINT ["jina", "executor", "--uses", "config.yml"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ class SampleExecutor(Executor):
def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]:
ret = []
for doc in docs:
ret.append(EmbeddingResponseModel(embeddings=np.random.random((1, 64))))
ret.append(
EmbeddingResponseModel(id=doc.id, embeddings=np.random.random((1, 64)))
)
return DocList[EmbeddingResponseModel](ret)
10 changes: 10 additions & 0 deletions tests/integration/docarray_v2/sagemaker/invalid_input.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
abcd
efgh
ijkl
mnop
qrst
uvwx
yzab
cdef
ghij
klmn
Loading

0 comments on commit 67c83c2

Please sign in to comment.