From a8b5520dfe1dcbf8e35cbee597e7ddfb7be90b48 Mon Sep 17 00:00:00 2001 From: pm3310 Date: Wed, 14 Feb 2024 11:42:27 +0000 Subject: [PATCH] Sagemaker image generation --- sagify/llm_gateway/api/v1/endpoints/images.py | 4 +- sagify/llm_gateway/providers/aws/sagemaker.py | 57 ++++++++++++++++++- sagify/llm_gateway/schemas/chat.py | 2 +- sagify/llm_gateway/schemas/images.py | 17 ++++-- 4 files changed, 72 insertions(+), 8 deletions(-) diff --git a/sagify/llm_gateway/api/v1/endpoints/images.py b/sagify/llm_gateway/api/v1/endpoints/images.py index 71ba7a1..260a580 100644 --- a/sagify/llm_gateway/api/v1/endpoints/images.py +++ b/sagify/llm_gateway/api/v1/endpoints/images.py @@ -14,7 +14,9 @@ async def create(request: CreateImageDTO): model=request.model, prompt=request.prompt, n=request.n, - size=request.size + width=request.width, + height=request.height, + seed=request.seed, ) response = await images.generations(parsed_message) diff --git a/sagify/llm_gateway/providers/aws/sagemaker.py b/sagify/llm_gateway/providers/aws/sagemaker.py index 6714d90..1e875f5 100644 --- a/sagify/llm_gateway/providers/aws/sagemaker.py +++ b/sagify/llm_gateway/providers/aws/sagemaker.py @@ -9,7 +9,7 @@ from sagify.llm_gateway.api.v1.exceptions import InternalServerError from sagify.llm_gateway.schemas.chat import CreateCompletionDTO, ResponseCompletionDTO from sagify.llm_gateway.schemas.embeddings import CreateEmbeddingDTO, ResponseEmbeddingDTO -from sagify.llm_gateway.schemas.images import CreateImageDTO +from sagify.llm_gateway.schemas.images import CreateImageDTO, ResponseImageDTO from sagify.llm_gateway.schemas.chat import ChoiceItem, MessageItem logger = structlog.get_logger() @@ -53,7 +53,60 @@ async def embeddings(self, embedding_input: CreateEmbeddingDTO): raise InternalServerError() async def generations(self, image_input: CreateImageDTO): - pass + request = { + "model": image_input.model, + "prompt": image_input.prompt, + "n": image_input.n, + "width": image_input.width, + "height": image_input.height, + "seed": image_input.seed, + "response_format": image_input.response_format + } + try: + return self._invoke_image_creation_endpoint(**request) + except Exception as e: + logger.error(e) + raise InternalServerError() + + def _invoke_image_creation_endpoint( + self, + model, + prompt, + n, + width, + height, + seed, + response_format + ): + payload = { + "prompt": prompt, + "width": width, + "height": height, + "num_images_per_prompt": n, + "num_inference_steps": 50, + "guidance_scale": 7.5, + "seed": seed, + } + response = self.sagemaker_runtime_client.invoke_endpoint( + EndpointName=model, + Body=json.dumps(payload), + ContentType="application/json", + CustomAttributes='accept_eula=true', + Accept="application/json;jpeg" + ) + response_dict = json.loads(response['Body'].read().decode('utf-8')) + + return ResponseImageDTO( + provider='sagemaker', + model=model, + created=int(time.time()), + data=[ + { + 'url': None, + 'b64_json': _b64_json + } for _b64_json in response_dict['generated_images'] + ] + ) def _invoke_embeddings_endpoint(self, model, input): """ diff --git a/sagify/llm_gateway/schemas/chat.py b/sagify/llm_gateway/schemas/chat.py index 2fb5b21..33be095 100644 --- a/sagify/llm_gateway/schemas/chat.py +++ b/sagify/llm_gateway/schemas/chat.py @@ -17,7 +17,7 @@ class MessageItem(BaseModel): class CreateCompletionDTO(BaseModel): - provider: Optional[str] + provider: str model: str messages: List[MessageItem] temperature: float diff --git a/sagify/llm_gateway/schemas/images.py b/sagify/llm_gateway/schemas/images.py index d7e6e04..3674e8f 100644 --- a/sagify/llm_gateway/schemas/images.py +++ b/sagify/llm_gateway/schemas/images.py @@ -1,17 +1,26 @@ -from typing import List +from enum import Enum +from typing import List, Optional, Union from pydantic import BaseModel +class ResponseFormat(str, Enum): + URL = "url" + B64_JSON = "b64_json" + class CreateImageDTO(BaseModel): provider: str model: str - prompt: str + prompt: Union[List[str], str] n: int - size: str + width: int + height: int + seed: Optional[int] + response_format: Optional[ResponseFormat] = 'url' class DataItem(BaseModel): - url: str + url: Optional[str] + b64_json: Optional[str] class ResponseImageDTO(BaseModel):