From a8b5520dfe1dcbf8e35cbee597e7ddfb7be90b48 Mon Sep 17 00:00:00 2001
From: pm3310
Date: Wed, 14 Feb 2024 11:42:27 +0000
Subject: [PATCH] Sagemaker image generation
---
sagify/llm_gateway/api/v1/endpoints/images.py | 4 +-
sagify/llm_gateway/providers/aws/sagemaker.py | 57 ++++++++++++++++++-
sagify/llm_gateway/schemas/chat.py | 2 +-
sagify/llm_gateway/schemas/images.py | 17 ++++--
4 files changed, 72 insertions(+), 8 deletions(-)
diff --git a/sagify/llm_gateway/api/v1/endpoints/images.py b/sagify/llm_gateway/api/v1/endpoints/images.py
index 71ba7a1..260a580 100644
--- a/sagify/llm_gateway/api/v1/endpoints/images.py
+++ b/sagify/llm_gateway/api/v1/endpoints/images.py
@@ -14,7 +14,9 @@ async def create(request: CreateImageDTO):
model=request.model,
prompt=request.prompt,
n=request.n,
- size=request.size
+ width=request.width,
+ height=request.height,
+ seed=request.seed,
)
response = await images.generations(parsed_message)
diff --git a/sagify/llm_gateway/providers/aws/sagemaker.py b/sagify/llm_gateway/providers/aws/sagemaker.py
index 6714d90..1e875f5 100644
--- a/sagify/llm_gateway/providers/aws/sagemaker.py
+++ b/sagify/llm_gateway/providers/aws/sagemaker.py
@@ -9,7 +9,7 @@
from sagify.llm_gateway.api.v1.exceptions import InternalServerError
from sagify.llm_gateway.schemas.chat import CreateCompletionDTO, ResponseCompletionDTO
from sagify.llm_gateway.schemas.embeddings import CreateEmbeddingDTO, ResponseEmbeddingDTO
-from sagify.llm_gateway.schemas.images import CreateImageDTO
+from sagify.llm_gateway.schemas.images import CreateImageDTO, ResponseImageDTO
from sagify.llm_gateway.schemas.chat import ChoiceItem, MessageItem
logger = structlog.get_logger()
@@ -53,7 +53,60 @@ async def embeddings(self, embedding_input: CreateEmbeddingDTO):
raise InternalServerError()
async def generations(self, image_input: CreateImageDTO):
- pass
+ request = {
+ "model": image_input.model,
+ "prompt": image_input.prompt,
+ "n": image_input.n,
+ "width": image_input.width,
+ "height": image_input.height,
+ "seed": image_input.seed,
+ "response_format": image_input.response_format
+ }
+ try:
+ return self._invoke_image_creation_endpoint(**request)
+ except Exception as e:
+ logger.error(e)
+ raise InternalServerError()
+
+ def _invoke_image_creation_endpoint(
+ self,
+ model,
+ prompt,
+ n,
+ width,
+ height,
+ seed,
+ response_format
+ ):
+ payload = {
+ "prompt": prompt,
+ "width": width,
+ "height": height,
+ "num_images_per_prompt": n,
+ "num_inference_steps": 50,
+ "guidance_scale": 7.5,
+ "seed": seed,
+ }
+ response = self.sagemaker_runtime_client.invoke_endpoint(
+ EndpointName=model,
+ Body=json.dumps(payload),
+ ContentType="application/json",
+ CustomAttributes='accept_eula=true',
+ Accept="application/json;jpeg"
+ )
+ response_dict = json.loads(response['Body'].read().decode('utf-8'))
+
+ return ResponseImageDTO(
+ provider='sagemaker',
+ model=model,
+ created=int(time.time()),
+ data=[
+ {
+ 'url': None,
+ 'b64_json': _b64_json
+ } for _b64_json in response_dict['generated_images']
+ ]
+ )
def _invoke_embeddings_endpoint(self, model, input):
"""
diff --git a/sagify/llm_gateway/schemas/chat.py b/sagify/llm_gateway/schemas/chat.py
index 2fb5b21..33be095 100644
--- a/sagify/llm_gateway/schemas/chat.py
+++ b/sagify/llm_gateway/schemas/chat.py
@@ -17,7 +17,7 @@ class MessageItem(BaseModel):
class CreateCompletionDTO(BaseModel):
- provider: Optional[str]
+ provider: str
model: str
messages: List[MessageItem]
temperature: float
diff --git a/sagify/llm_gateway/schemas/images.py b/sagify/llm_gateway/schemas/images.py
index d7e6e04..3674e8f 100644
--- a/sagify/llm_gateway/schemas/images.py
+++ b/sagify/llm_gateway/schemas/images.py
@@ -1,17 +1,26 @@
-from typing import List
+from enum import Enum
+from typing import List, Optional, Union
from pydantic import BaseModel
+class ResponseFormat(str, Enum):
+ URL = "url"
+ B64_JSON = "b64_json"
+
class CreateImageDTO(BaseModel):
provider: str
model: str
- prompt: str
+ prompt: Union[List[str], str]
n: int
- size: str
+ width: int
+ height: int
+ seed: Optional[int]
+ response_format: Optional[ResponseFormat] = 'url'
class DataItem(BaseModel):
- url: str
+ url: Optional[str]
+ b64_json: Optional[str]
class ResponseImageDTO(BaseModel):