Sagemaker image generation
pm3310 committed Feb 14, 2024
1 parent bb84064 commit a8b5520
Showing 4 changed files with 72 additions and 8 deletions.
4 changes: 3 additions & 1 deletion sagify/llm_gateway/api/v1/endpoints/images.py
@@ -14,7 +14,9 @@ async def create(request: CreateImageDTO):
         model=request.model,
         prompt=request.prompt,
         n=request.n,
-        size=request.size
+        width=request.width,
+        height=request.height,
+        seed=request.seed,
     )
 
     response = await images.generations(parsed_message)
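For reference, a hedged sketch of the request body the endpoint would accept after this change; only the field names come from this diff, while the route, endpoint name, and values are illustrative assumptions. The fixed size string is replaced by explicit width/height plus an optional seed.

# Illustrative only: the shape of an image-generation request after this commit.
# "my-sdxl-endpoint" is a hypothetical SageMaker endpoint name.
example_body = {
    "provider": "sagemaker",
    "model": "my-sdxl-endpoint",
    "prompt": "a watercolor fox in a forest",
    "n": 2,
    "width": 1024,
    "height": 1024,
    "seed": 42,
    "response_format": "b64_json",
}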
57 changes: 55 additions & 2 deletions sagify/llm_gateway/providers/aws/sagemaker.py
@@ -9,7 +9,7 @@
 from sagify.llm_gateway.api.v1.exceptions import InternalServerError
 from sagify.llm_gateway.schemas.chat import CreateCompletionDTO, ResponseCompletionDTO
 from sagify.llm_gateway.schemas.embeddings import CreateEmbeddingDTO, ResponseEmbeddingDTO
-from sagify.llm_gateway.schemas.images import CreateImageDTO
+from sagify.llm_gateway.schemas.images import CreateImageDTO, ResponseImageDTO
 from sagify.llm_gateway.schemas.chat import ChoiceItem, MessageItem
 
 logger = structlog.get_logger()
@@ -53,7 +53,60 @@ async def embeddings(self, embedding_input: CreateEmbeddingDTO):
             raise InternalServerError()
 
     async def generations(self, image_input: CreateImageDTO):
-        pass
+        request = {
+            "model": image_input.model,
+            "prompt": image_input.prompt,
+            "n": image_input.n,
+            "width": image_input.width,
+            "height": image_input.height,
+            "seed": image_input.seed,
+            "response_format": image_input.response_format
+        }
+        try:
+            return self._invoke_image_creation_endpoint(**request)
+        except Exception as e:
+            logger.error(e)
+            raise InternalServerError()
+
+    def _invoke_image_creation_endpoint(
+        self,
+        model,
+        prompt,
+        n,
+        width,
+        height,
+        seed,
+        response_format
+    ):
+        payload = {
+            "prompt": prompt,
+            "width": width,
+            "height": height,
+            "num_images_per_prompt": n,
+            "num_inference_steps": 50,
+            "guidance_scale": 7.5,
+            "seed": seed,
+        }
+        response = self.sagemaker_runtime_client.invoke_endpoint(
+            EndpointName=model,
+            Body=json.dumps(payload),
+            ContentType="application/json",
+            CustomAttributes='accept_eula=true',
+            Accept="application/json;jpeg"
+        )
+        response_dict = json.loads(response['Body'].read().decode('utf-8'))
+
+        return ResponseImageDTO(
+            provider='sagemaker',
+            model=model,
+            created=int(time.time()),
+            data=[
+                {
+                    'url': None,
+                    'b64_json': _b64_json
+                } for _b64_json in response_dict['generated_images']
+            ]
+        )
 
     def _invoke_embeddings_endpoint(self, model, input):
         """
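A minimal usage sketch for the new generations path, assuming the provider class above is instantiated elsewhere with a boto3 sagemaker-runtime client (its constructor is not part of this diff), that ResponseImageDTO.data items expose url/b64_json attributes, and that the target endpoint returns a generated_images list of base64 strings, as _invoke_image_creation_endpoint expects:

import asyncio
import base64

from sagify.llm_gateway.schemas.images import CreateImageDTO


async def main(provider_client):
    # provider_client is assumed to be an instance of the SageMaker provider
    # class shown above; how it is constructed is outside this diff.
    dto = CreateImageDTO(
        provider="sagemaker",
        model="my-sdxl-endpoint",  # hypothetical SageMaker endpoint name
        prompt="a watercolor fox in a forest",
        n=1,
        width=512,
        height=512,
        seed=42,
        response_format="b64_json",
    )
    result = await provider_client.generations(dto)

    # For this provider, url is None and each data item carries a
    # base64-encoded image (assumed to be parsed into DataItem models).
    for i, item in enumerate(result.data):
        with open(f"image_{i}.jpg", "wb") as f:
            f.write(base64.b64decode(item.b64_json))

# asyncio.run(main(provider_client)) once a provider instance is available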
2 changes: 1 addition & 1 deletion sagify/llm_gateway/schemas/chat.py
@@ -17,7 +17,7 @@ class MessageItem(BaseModel):
 
 
 class CreateCompletionDTO(BaseModel):
-    provider: Optional[str]
+    provider: str
     model: str
     messages: List[MessageItem]
     temperature: float
17 changes: 13 additions & 4 deletions sagify/llm_gateway/schemas/images.py
@@ -1,17 +1,26 @@
-from typing import List
+from enum import Enum
+from typing import List, Optional, Union
 from pydantic import BaseModel
 
 
+class ResponseFormat(str, Enum):
+    URL = "url"
+    B64_JSON = "b64_json"
+
 class CreateImageDTO(BaseModel):
     provider: str
     model: str
-    prompt: str
+    prompt: Union[List[str], str]
     n: int
-    size: str
+    width: int
+    height: int
+    seed: Optional[int]
+    response_format: Optional[ResponseFormat] = 'url'
 
 
 class DataItem(BaseModel):
-    url: str
+    url: Optional[str]
+    b64_json: Optional[str]
 
 
 class ResponseImageDTO(BaseModel):
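And a small sketch of the widened schema itself: prompt now accepts a single string or a list of strings, and response_format defaults to 'url' when omitted. The model name below is an assumption, not taken from this diff.

from sagify.llm_gateway.schemas.images import CreateImageDTO, ResponseFormat

dto = CreateImageDTO(
    provider="sagemaker",
    model="my-sdxl-endpoint",  # hypothetical endpoint name
    prompt=["a red bicycle", "a blue bicycle"],  # list prompts are now allowed
    n=1,
    width=768,
    height=768,
    seed=None,
)
assert dto.response_format == ResponseFormat.URL  # 'url' is the default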
