Skip to content

Commit

Permalink
Image url for sagemaker image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
pm3310 committed Feb 14, 2024
1 parent f7c8441 commit 3b9b317
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions sagify/llm_gateway/providers/aws/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from PIL import Image
import base64
import json
from io import BytesIO
import os
import time
import uuid
Expand All @@ -9,7 +12,7 @@
from sagify.llm_gateway.api.v1.exceptions import InternalServerError
from sagify.llm_gateway.schemas.chat import CreateCompletionDTO, ResponseCompletionDTO
from sagify.llm_gateway.schemas.embeddings import CreateEmbeddingDTO, ResponseEmbeddingDTO
from sagify.llm_gateway.schemas.images import CreateImageDTO, ResponseImageDTO
from sagify.llm_gateway.schemas.images import CreateImageDTO, ResponseImageDTO, ResponseFormat
from sagify.llm_gateway.schemas.chat import ChoiceItem, MessageItem

logger = structlog.get_logger()
Expand All @@ -20,12 +23,14 @@ def __init__(self):
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
aws_region_name = os.environ.get("AWS_REGION_NAME")
self._bucket_name = os.environ.get("S3_BUCKET_NAME")
self.boto_session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name
)
self.sagemaker_runtime_client = self.boto_session.client('sagemaker-runtime')
self.s3_client = self.boto_session.client('s3')

async def completions(self, message: CreateCompletionDTO):
request = {
Expand Down Expand Up @@ -101,12 +106,40 @@ def _invoke_image_creation_endpoint(
model=model,
created=int(time.time()),
data=[
{
'url': None,
'b64_json': _b64_json
} for _b64_json in response_dict['generated_images']
self._prepare_image_item_response(
response_format, _base64_string
) for _base64_string in response_dict['generated_images']
]
)

def _prepare_image_item_response(self, response_format, base64_string):
    """
    Build one entry of the ``data`` list of an image-generation response.

    :param response_format: requested output format; compared against
        ``ResponseFormat.URL``
    :param base64_string: base64-encoded image taken from the endpoint response

    :return: ``{'url': <uploaded image url>}`` when the URL format is requested,
        otherwise ``{'b64_json': base64_string}``
    """
    if response_format == ResponseFormat.URL:
        # `_generated_image_url` is called as a bound method, so `self` is
        # supplied implicitly. Passing it again would shift the arguments and
        # raise TypeError (2 positional arguments expected, 3 given).
        return {
            'url': self._generated_image_url(base64_string),
        }

    return {
        'b64_json': base64_string
    }

def _generated_image_url(self, base64_string):
    """
    Upload a base64-encoded image to S3 as a PNG and return its URL.

    :param base64_string: base64-encoded image data

    :return: URL of the form ``https://<bucket>.s3.amazonaws.com/<uuid>.png``
        where ``<bucket>`` is ``self._bucket_name``
    """
    # Decode the base64 payload into raw image bytes
    image_bytes = base64.b64decode(base64_string)

    # Re-encode through PIL so the uploaded object is written in PNG format
    png_buffer = BytesIO()
    Image.open(BytesIO(image_bytes)).save(png_buffer, format='PNG')
    # Rewind so the upload reads from the start of the buffer
    png_buffer.seek(0)

    # A random UUID key avoids collisions between generated images
    key = '{}.png'.format(uuid.uuid4())
    self.s3_client.upload_fileobj(png_buffer, self._bucket_name, key)

    # The bucket name lives on the instance; the bare name `bucket_name`
    # is not defined in this scope and would raise NameError.
    # NOTE(review): the URL is only reachable if the bucket/object allows
    # public reads - confirm against the bucket policy.
    return f"https://{self._bucket_name}.s3.amazonaws.com/{key}"

def _invoke_embeddings_endpoint(self, model, input):
"""
Expand Down

0 comments on commit 3b9b317

Please sign in to comment.