diff --git a/runner/app/main.py b/runner/app/main.py index c4dd6b8e..52e19cd2 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -56,6 +56,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.upscale import UpscalePipeline return UpscalePipeline(model_id) + case "llm-generate": + from app.pipelines.llm_generate import LLMGeneratePipeline + return LLMGeneratePipeline(model_id) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -88,6 +91,10 @@ def load_route(pipeline: str) -> any: from app.routes import upscale return upscale.router + case "llm-generate": + from app.routes import llm_generate + + return llm_generate.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py new file mode 100644 index 00000000..2b72c813 --- /dev/null +++ b/runner/app/pipelines/llm_generate.py @@ -0,0 +1,97 @@ +import logging +import os +from typing import Dict, Any, Optional + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig +from app.pipelines.base import Pipeline +from app.pipelines.utils import get_model_dir, get_torch_device +from huggingface_hub import file_download, hf_hub_download + +logger = logging.getLogger(__name__) + + +class LLMGeneratePipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id + kwargs = { + "cache_dir": get_model_dir() + } + self.device = get_torch_device() + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model" + ) + folder_path = os.path.join(get_model_dir(), folder_name) + + # Check for fp16 variant + has_fp16_variant = any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files + ) + if self.device != "cpu" and has_fp16_variant: + logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id) + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + + # Load tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + + # Load model + self.model = AutoModelForCausalLM.from_pretrained( + model_id, **kwargs).to(self.device) + + # Set up generation config + self.generation_config = self.model.generation_config + + # Optional: Add optimizations + sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" + if sfast_enabled: + logger.info( + "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s", + model_id, + ) + from app.pipelines.optim.sfast import compile_model + self.model = compile_model(self.model) + + def __call__(self, prompt: str, system_msg: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, **kwargs) -> Dict[str, Any]: + if system_msg: + input_text = f"{system_msg}\n\n{prompt}" + else: + input_text = prompt + + input_ids = self.tokenizer.encode( + input_text, return_tensors="pt").to(self.device) + + # Update generation config + gen_kwargs = {} + if temperature is not None: + gen_kwargs['temperature'] = temperature + if max_tokens is not None: + gen_kwargs['max_new_tokens'] = max_tokens + + # Merge generation config with provided kwargs + gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs} + + # Generate response + with torch.no_grad(): + output = self.model.generate( + input_ids, + **gen_kwargs + ) + + # Decode the response + response = self.tokenizer.decode(output[0], skip_special_tokens=True) + + # Calculate 
tokens used + tokens_used = len(output[0]) + + return { + "response": response.strip(), + "tokens_used": tokens_used + } + + def __str__(self) -> str: + return f"LLMPipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 85f37cdb..4bada2f0 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -22,6 +22,7 @@ StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.models import AutoencoderKL from huggingface_hub import file_download, hf_hub_download from safetensors.torch import load_file @@ -34,6 +35,7 @@ class ModelName(Enum): SDXL_LIGHTNING = "ByteDance/SDXL-Lightning" SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers" + REALISTIC_VISION_V6 = "SG161222/Realistic_Vision_V6.0_B1_noVAE" @classmethod def list(cls): @@ -71,6 +73,11 @@ def __init__(self, model_id: str): if os.environ.get("BFLOAT16"): logger.info("TextToImagePipeline using bfloat16 precision for %s", model_id) kwargs["torch_dtype"] = torch.bfloat16 + + # Load VAE for specific models. + if ModelName.REALISTIC_VISION_V6.value in model_id: + vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema") + kwargs["vae"] = vae # Special case SDXL-Lightning because the unet for SDXL needs to be swapped if ModelName.SDXL_LIGHTNING.value in model_id: diff --git a/runner/app/routes/llm_generate.py b/runner/app/routes/llm_generate.py new file mode 100644 index 00000000..d43120d3 --- /dev/null +++ b/runner/app/routes/llm_generate.py @@ -0,0 +1,64 @@ +import logging +import os +from typing import Annotated, Optional +from fastapi import APIRouter, Depends, Form, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error + +router = APIRouter() + +logger = logging.getLogger(__name__) + +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES) +@router.post("/llm-generate/", response_model=LlmResponse, responses=RESPONSES, include_in_schema=False) +async def llm_generate( + prompt: Annotated[str, Form()], + model_id: Annotated[str, Form()] = "", + system_msg: Annotated[str, Form()] = None, + temperature: Annotated[float, Form()] = None, + max_tokens: Annotated[int, Form()] = None, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}" + ), + ) + + try: + result = pipeline( + prompt=prompt, + system_msg=system_msg, + temperature=temperature, + max_tokens=max_tokens + ) + return JSONResponse(content=result) + except Exception as e: + logger.error(f"LLM processing error: {str(e)}") + return JSONResponse( + 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("Internal server error during LLM processing."), ) diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index 96736305..e0fc914a 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -34,6 +34,11 @@ class TextResponse(BaseModel): chunks: List[chunk] +class LlmResponse(BaseModel): + response: str + tokens_used: int + + class APIError(BaseModel): msg: str diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 1528a60b..c47fdade 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -56,10 +56,15 @@ function download_all_models() { huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"} + huggingface-cli download SG161222/Realistic_Vision_V6.0_B1_noVAE --include "*.fp16.safetensors" "*.json" "*.txt" "*.bin" --exclude ".onnx" ".onnx_data" --cache-dir models # Download image-to-video models. huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models + # Download LLM models (Warning: large model size; gated repo, requires HF token) + huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"} + + #Download frame-interpolation model. wget -O models/film_net_fp16.pt https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.2/film_net_fp16.pt } diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index 198102db..b5f86f04 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -13,6 +13,7 @@ text_to_image, frame_interpolation, upscale, + llm_generate, ) from fastapi.openapi.utils import get_openapi @@ -85,6 +86,7 @@ def write_openapi(fname, entrypoint="runner"): app.include_router(image_to_image.router) app.include_router(image_to_video.router) app.include_router(audio_to_text.router) + app.include_router(llm_generate.router) app.include_router(frame_interpolation.router) app.include_router(upscale.router) diff --git a/runner/openapi.json b/runner/openapi.json index 2e2bb4e7..17521c10 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -477,6 +477,79 @@ } ] } + }, + "/llm-generate": { + "post": { + "summary": "Llm Generate", + "operationId": "llm_generate", + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Body_llm_generate_llm_generate_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LlmResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": 
"#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } } }, "components": { @@ -667,6 +740,36 @@ ], "title": "Body_image_to_video_image_to_video_post" }, + "Body_llm_generate_llm_generate_post": { + "properties": { + "prompt": { + "type": "string", + "title": "Prompt" + }, + "model_id": { + "type": "string", + "title": "Model Id", + "default": "" + }, + "system_msg": { + "type": "string", + "title": "System Msg" + }, + "temperature": { + "type": "number", + "title": "Temperature" + }, + "max_tokens": { + "type": "integer", + "title": "Max Tokens" + } + }, + "type": "object", + "required": [ + "prompt" + ], + "title": "Body_llm_generate_llm_generate_post" + }, "Body_upscale_upscale_post": { "properties": { "prompt": { @@ -757,6 +860,24 @@ ], "title": "ImageResponse" }, + "LlmResponse": { + "properties": { + "response": { + "type": "string", + "title": "Response" + }, + "tokens_used": { + "type": "integer", + "title": "Tokens Used" + } + }, + "type": "object", + "required": [ + "response", + "tokens_used" + ], + "title": "LlmResponse" + }, "Media": { "properties": { "url": { diff --git a/runner/requirements.txt b/runner/requirements.txt index 82877c24..d877048f 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -12,7 +12,7 @@ torchvision --index-url https://download.pytorch.org/whl/cu121 torchaudio --index-url https://download.pytorch.org/whl/cu121 huggingface_hub==0.23.2 xformers==0.0.23 -triton>=2.1.0 +triton>=0.1.0 peft==0.11.1 deepcache==0.1.1 safetensors==0.4.3 diff --git a/worker/docker.go b/worker/docker.go index 8d7f97e0..ce510493 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -35,6 +35,7 @@ var containerHostPorts = map[string]string{ "image-to-video": "8002", "upscale": "8003", "audio-to-text": "8004", + "llm": "8005", } type DockerManager struct { diff --git a/worker/multipart.go b/worker/multipart.go index 865b9114..7a87bcae 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -240,3 +240,41 @@ func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestB return mw, nil } + +func NewLlmGenerateMultipartWriter(w io.Writer, req BodyLlmGenerateLlmGeneratePost) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + + if err := mw.WriteField("prompt", req.Prompt); err != nil { + return nil, fmt.Errorf("failed to write prompt field: %w", err) + } + + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, fmt.Errorf("failed to write model_id field: %w", err) + } + } + + if req.SystemMsg != nil { + if err := mw.WriteField("system_msg", *req.SystemMsg); err != nil { + return nil, fmt.Errorf("failed to write system_msg field: %w", err) + } + } + + if req.Temperature != nil { + if err := mw.WriteField("temperature", fmt.Sprintf("%f", *req.Temperature)); err != nil { + return nil, fmt.Errorf("failed to write temperature field: %w", err) + } + } + + if req.MaxTokens != nil { + if err := mw.WriteField("max_tokens", strconv.Itoa(*req.MaxTokens)); err != nil { + return nil, fmt.Errorf("failed to write max_tokens field: %w", err) + } + } + + if err := mw.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 788a7e82..4c9ad354 100644 --- 
a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -75,6 +75,15 @@ type BodyImageToVideoImageToVideoPost struct { Width *int `json:"width,omitempty"` } +// BodyLlmGenerateLlmGeneratePost defines model for Body_llm_generate_llm_generate_post. +type BodyLlmGenerateLlmGeneratePost struct { + MaxTokens *int `json:"max_tokens,omitempty"` + ModelId *string `json:"model_id,omitempty"` + Prompt string `json:"prompt"` + SystemMsg *string `json:"system_msg,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` +} + // BodyUpscaleUpscalePost defines model for Body_upscale_upscale_post. type BodyUpscaleUpscalePost struct { Image openapi_types.File `json:"image"` @@ -105,6 +114,12 @@ type ImageResponse struct { Images []Media `json:"images"` } +// LlmResponse defines model for LlmResponse. +type LlmResponse struct { + Response string `json:"response"` + TokensUsed int `json:"tokens_used"` +} + // Media defines model for Media. type Media struct { Nsfw bool `json:"nsfw"` @@ -173,6 +188,9 @@ type ImageToImageMultipartRequestBody = BodyImageToImageImageToImagePost // ImageToVideoMultipartRequestBody defines body for ImageToVideo for multipart/form-data ContentType. type ImageToVideoMultipartRequestBody = BodyImageToVideoImageToVideoPost +// LlmGenerateFormdataRequestBody defines body for LlmGenerate for application/x-www-form-urlencoded ContentType. +type LlmGenerateFormdataRequestBody = BodyLlmGenerateLlmGeneratePost + // TextToImageJSONRequestBody defines body for TextToImage for application/json ContentType. type TextToImageJSONRequestBody = TextToImageParams @@ -329,6 +347,11 @@ type ClientInterface interface { // ImageToVideoWithBody request with any body ImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // LlmGenerateWithBody request with any body + LlmGenerateWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + LlmGenerateWithFormdataBody(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + // TextToImageWithBody request with any body TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -398,6 +421,30 @@ func (c *Client) ImageToVideoWithBody(ctx context.Context, contentType string, b return c.Client.Do(req) } +func (c *Client) LlmGenerateWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewLlmGenerateRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) LlmGenerateWithFormdataBody(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewLlmGenerateRequestWithFormdataBody(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewTextToImageRequestWithBody(c.Server, contentType, body) if err != nil { @@ -577,6 +624,46 @@ func 
NewImageToVideoRequestWithBody(server string, contentType string, body io.R return req, nil } +// NewLlmGenerateRequestWithFormdataBody calls the generic LlmGenerate builder with application/x-www-form-urlencoded body +func NewLlmGenerateRequestWithFormdataBody(server string, body LlmGenerateFormdataRequestBody) (*http.Request, error) { + var bodyReader io.Reader + bodyStr, err := runtime.MarshalForm(body, nil) + if err != nil { + return nil, err + } + bodyReader = strings.NewReader(bodyStr.Encode()) + return NewLlmGenerateRequestWithBody(server, "application/x-www-form-urlencoded", bodyReader) +} + +// NewLlmGenerateRequestWithBody generates requests for LlmGenerate with any type of body +func NewLlmGenerateRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/llm-generate") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewTextToImageRequest calls the generic TextToImage builder with application/json body func NewTextToImageRequest(server string, body TextToImageJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader @@ -704,6 +791,11 @@ type ClientWithResponsesInterface interface { // ImageToVideoWithBodyWithResponse request with any body ImageToVideoWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ImageToVideoResponse, error) + // LlmGenerateWithBodyWithResponse request with any body + LlmGenerateWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) + + LlmGenerateWithFormdataBodyWithResponse(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) + // TextToImageWithBodyWithResponse request with any body TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) @@ -840,6 +932,32 @@ func (r ImageToVideoResponse) StatusCode() int { return 0 } +type LlmGenerateResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *LlmResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r LlmGenerateResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r LlmGenerateResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type TextToImageResponse struct { Body []byte HTTPResponse *http.Response @@ -937,6 +1055,23 @@ func (c *ClientWithResponses) ImageToVideoWithBodyWithResponse(ctx context.Conte return ParseImageToVideoResponse(rsp) } +// LlmGenerateWithBodyWithResponse request with arbitrary body returning *LlmGenerateResponse +func (c *ClientWithResponses) LlmGenerateWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) { + rsp, err := 
c.LlmGenerateWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseLlmGenerateResponse(rsp) +} + +func (c *ClientWithResponses) LlmGenerateWithFormdataBodyWithResponse(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) { + rsp, err := c.LlmGenerateWithFormdataBody(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseLlmGenerateResponse(rsp) +} + // TextToImageWithBodyWithResponse request with arbitrary body returning *TextToImageResponse func (c *ClientWithResponses) TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) { rsp, err := c.TextToImageWithBody(ctx, contentType, body, reqEditors...) @@ -1212,6 +1347,60 @@ func ParseImageToVideoResponse(rsp *http.Response) (*ImageToVideoResponse, error return response, nil } +// ParseLlmGenerateResponse parses an HTTP response from a LlmGenerateWithResponse call +func ParseLlmGenerateResponse(rsp *http.Response) (*LlmGenerateResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &LlmGenerateResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest LlmResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseTextToImageResponse parses an HTTP response from a TextToImageWithResponse call func ParseTextToImageResponse(rsp *http.Response) (*TextToImageResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1337,6 +1526,9 @@ type ServerInterface interface { // Image To Video // (POST /image-to-video) ImageToVideo(w http.ResponseWriter, r *http.Request) + // Llm Generate + // (POST /llm-generate) + LlmGenerate(w http.ResponseWriter, r *http.Request) // Text To Image // (POST /text-to-image) TextToImage(w http.ResponseWriter, r *http.Request) @@ -1379,6 +1571,12 @@ func (_ Unimplemented) ImageToVideo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } +// Llm Generate +// (POST /llm-generate) +func (_ Unimplemented) LlmGenerate(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Text To Image // (POST /text-to-image) func (_ Unimplemented) TextToImage(w http.ResponseWriter, r *http.Request) { @@ -1483,6 +1681,23 @@ func (siw *ServerInterfaceWrapper) 
ImageToVideo(w http.ResponseWriter, r *http.R handler.ServeHTTP(w, r.WithContext(ctx)) } +// LlmGenerate operation middleware +func (siw *ServerInterfaceWrapper) LlmGenerate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.LlmGenerate(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // TextToImage operation middleware func (siw *ServerInterfaceWrapper) TextToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1645,6 +1860,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-video", wrapper.ImageToVideo) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/llm-generate", wrapper.LlmGenerate) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.TextToImage) }) @@ -1658,33 +1876,34 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xZWW/bOhb+KwRnHp14aTMZ+C1Jt2C6BI3beSgCg5GObbYSqUtSaX0D//cLHsoSJVOR", - "jTS+QK6fvOgs31m+w0X3NJJpJgUIo+n4nupoASnDr2dXl6+Vksp+z5TMQBkO+CTVc/thuEmAjukHPac9", - "apaZ/aGN4mJOV6seVfBHzhXEdPwNVW56pUppu9STt98hMnTVo+cyXk5ZHnM5NXJq4Jdp/MqkNpugUMZ+", - "mUmVMkPH9JYLppbU84oiG1B7NJUxJFMeW/UYZixPrL6n+cEKkMu4M06Hwot0u2ja0jBTLIUpFwZUJhNm", - "uBTB/8Ip4Smbw/DhnFw6mUBSUHu0hfaoVXsac9WaU9Qlr7gKqtvwXKi6ZmHkGbAy5I2TKW1YzTmoZll3", - "rmSp3CzmTjVpK6xLj5HFl8bPcDnnOY+ZiGCqI2bheFk5PT6pUL4t5Mg1ypUQRJ7eusSgly0q217YB7AM", - "fSyuyN2IHsHAHhUwZ4bfwTRTMs1Mq42PhRy5cnIhU3nqaqCnGaiQwaFnL08JBqjJFagNq14nolkxAwWY", - "MwNZvauHg0HD7FqYXKNwyGgFbq3ZHpdmMzDLabSA6EfNs1E5VK6vUYxcoFhp5lbKBJhAOwA1Ol3b3yFw", - "2igQc7OoORsc/9fztZbYaIcGE7N1VK5tm3zcgkqdLLzjMcjmzzALZ43S/aeC86alUAvg80W9i05OPb13", - "7nlI9TFMfRSnUolD7DaPfoBpGhmOTn0rVpKco2TNmk8AyTVMWT6ftjTGwBvsH60wOcvnpL1Hujk1Otmd", - "UnunyU8eN1IxHIxeVp7+j883NRsU6WBGe3u3MSPPcLCXnw9sMP6O7uyq/enJsxqnuw3EYO0ChX43mVy1", - "7PBjMIwn9tu/FczomP6rX50T+sUhoV/u4psAC3UPWOWrBchXlvAYN06dkLiBVHdha9pbVVheOUslEKYU", - "W2IMPtqmgRBuYIlZXKyboI5XG2byelfST/+j/vqHAqF9aLUwVA4C/pFbn0FnUmhoYafeOmMfIObMz5Pb", - "2oTytDF6tF/rOqwAbudpA6/Qs58+GT7a34+arrlKfLkvKunc9ucoo51FRORF5oAHIprAL9NeiGiRix/b", - "FwLF/UJcOP1mIXrUHiD9AC2MzgiNEypAedHVgmgJciKxuldMMRfIUx1Rqj3TFrukf/jp4eS5HR7KXdGO", - "26AiqEZP13s20Nida08ioxp7mVh+mtHxt/uNXN1vQLzxiPxeRugmQOXmnRpo3bJxcn9UooiZTOy/XdS3", - "cThXhaSXqS3Wu69239g+5qrbmjJROy48zfG2Plc1rnjCC1Hh3g+phjcQkJu0G4FsN1atnxS0YWnmh+rh", - "npTPO6AbX9A684JwGDfAI52iXHGzvLZ5dMjtxuUcmAJVXuYiB91fpZGFMRldrfCebSYdpXWkeIbNOaZn", - "grAsS7jrVmIkUbkgZ5ck4xkkXLhirJua30EGoOzzz7kQ6OgOlHa2BsfD44HNlsxAsIzTMX2Bf/VoxswC", - "YffxSvTIyKN16tfnDVsWBHEZry9wJ7Koh80gaGP3vLjKSmFAoFaaJ4ZnTJm+PZgcxcyw6nK7qx23u7Fd", - "1WtoJyH+4ZoNoxoNBg1cXlL737VNz7agamsz+q5X7DqPItB6liekEuvRl78RQrWFD/g/ZzH57Orh/A73", - "4/eLYLlZSMX/hBgdD1/sx3ERLHktDDdLMpGSvGdq7rI+Gv1WEBtnmU04lQgpzzsn+yo+XsQLlpBrUHeg", - "SHUoXI8oXCv94fTtZnXTozpPU6aWa2aTiSTIbavaD9ypt08GXCIua7JPOyB2eguw52FRP4AdpkX7tDgQ", - "dVeiItFInWlI1wXeVeAhEAIEdVcZ9Am73r8s2bbnV35oBUSMBk9xdkNSXnGG5w5SrThgPPHE2eI9x2HO", - "HObMM5kz7sXxRLorkgYp8QVGJynx+LcvUra/YtkzKeuH3gMpD6R8AlI6aiEp7ZF4i4XSu4h7kJKPOyLX", - "r/oOy+GBec+Eeba5G6th8Xq3nXJfCoGnXQGDb5sPzDsw75kwb82ildOyZjQq1T2Vt+AXicxjciHTNBfc", - 
"LMlbZuAnW9LibTTevetxvx8rYOnR3D09Tgr148iq09XN6q8AAAD///BmUXaaLQAA", + "H4sIAAAAAAAC/+xZW2/bOhL+KwR3H53YSZvNwm9JttsGm7ZB7XQfikBgpLHNRiK1vCTxBv7vByR1oW6V", + "jDQ+QI6fbFHDmW/I+YYz1DMOeZJyBkxJPH3GMlxBQuzfs+vLD0JwYf6ngqcgFAX7JpFL86OoigFP8We5", + "xCOs1ql5kEpQtsSbzQgL+J+mAiI8/WGn3I6KKYXuYh6/+wmhwpsRPufROiA6ojxQPFDwpGpPKZeqCcrK", + "mD8LLhKi8BTfUUbEGntWrUgD6ggnPII4oJGZHsGC6NjM92Z+NgLoMur106HwPB3mTdcy0IQswYi6P7XH", + "9oVYahoRFkIgQ2IgeC6dHp6UyD5mcmhm5QoITCd3IAwEa+XXS3ppRVqW1CH8BZYjH4tVg/oRvWCjRpjB", + "kij6AEEqeJKqTh1fMjl07eTaVOnE7YEMUhBtCo88fTpB1kGJrkE0tFKmYOncs2rZAgTYNVOQyqrSyaSm", + "NhdGMyvcprQEl8/s9kuSBah1EK4gvK9YVkJDaXpmxdCFFSvU3HEeA2FWD0DkW5yZ5zZwUglgS7WqGJsc", + "/tOzlUs0wqFGvTT3yoVtnYMDqNTLwgcaAa8/trNwUdu6f5Rw/t2xUSugy1U1ik5OvXmf3Pu2qS9h6os4", + "lXBFOQvudHgPqq7k6PjU12Ik0bmVrGjzCcCphIDoZdARGJNjjwBGGJ3pJeqOkX5OHZ9sT6md0+SRRrWl", + "OJocvy8t/de+b86sUaSHGd3h3cWMOE6CJTAQREH1oZ0VCXkKFL8HJisFBHlCczfa5vyLAnSr7LeWCpKg", + "Vt7M7ChqrXJGWEGSGo+1AH/S3BsemLjq29Kztl1bolN71ha/7RvxpyWMPjqenrypE267M6p171o2+tN8", + "ft1Rm0egCI3Nv78LWOAp/tu4rPDHWXk/LurvOsBsugestNUB5DuJaURMcu+FRBUksg9bXd+mxPIvp6kA", + "QoQga+uDj7auoA03kFitLvIgqOKViihdjUr89T/YL0msQFsvUJ7VpYEW+5Zb30CmnEnoYKccvGKfIaLE", + "XydXbbatU+M0kP5eV2G14L6Kk27UwnuTa2wq81KnTfmBllUmuZMA3cghhBKefk+d55MPucUjt3YNX5hc", + "PPqgvpjnFx3hWsS+3I2Ie5tJbWWk02gReX454C0ezeFJdW9SuNLsfnhoWXE/tC7c/HpomYPwSVVPwCfV", + "66FyQhkoz7uKEx1OzrmN12siiHPktfrgsjAfUIr/xVvUk7fWoRal95a1drOqa8ZsS2D3nqYxDyvsJWz9", + "dYGnP54ba/XcgHjrEfmKh9ZMC5Xr93sgZUcp6AZKUYsZzc1oH/WNH85UJumt1IAT/LtpTrrT3EKQpHaC", + "bnmU1tNb3rw7xT1Ha2bed6mCt8Uhl2kbjgxLq8ZOAlKRJPVd9XDPi/c90JUvaIx5TjiMDfCWTqEWVK1n", + "Zh0dclOKnQMRIIqLZctBN1QoWSmV4o3RQdmCO0rLUNDUBucUnzFE0jSmLlqR4khohs4uUUpTiClzm5EH", + "NX2AFECY9980Y9bQAwjpdE0Ojw4nZrV4CoykFE/xOzs0wilRKwt7bK9nDxQ/yJc+76C47egoZ5dRfpk8", + "59l+mBUEqUwVb09ZzhQwOyvRsaIpEWpsWq2DiChSXrT3heOw2+NNdQ9NJrQDLtisV8eTSQ2Xt6jjn9Is", + "z1BQlbPZ2q7u2EyHIUi50DEqxUb4/W+EUDYlLfbPSYS+uf1wdo92Y/eGEa1WXND/Q2QNH73bjeHMWfSB", + "KarWaM45uiJi6Vb9+Pi3gmh0Z004pQgqOriTXW3+JVMgGInRDMQDCFS2uXmKsmeln5x+3G5uR1jqJCFi", + "nTMbzTmy3DZTxyvbztmqElpygev28Ctyzu8nh1Ju4zuVQbTe2LLQZLjiFqg9xdlSJatYXjnHDbid33GW", + "q/bC+zTXneb2GWbbDOM+d86567lqpLTX7r2ktPXkrkjZ/WFgx6SsVtF7Uu5J+QqkdNSypIzj5CD/7NJN", + "yas4+ZgL/YqRvu9PB4+PjweWmVrEwEIeuQuJLfjZ84Vox9z0L1r3zNwz8/cx8ypOUEEwy0vT+w4oYL0b", + "t8HE3L4Xrt7p7cvUPe/eCO9McNeq1OzLdDflbjKB161MWz+U75m3Z94bYV7Ooo2bZdRIO6lqqbjuvoi5", + "jtAFTxLNqFqjj0TBI1nj7LOzvWSX0/E4EkCSg6V7exhn0w9DMx1vbjd/BAAA///6mbgGDy4AAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 7877f6dd..0dadceb9 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -304,6 +304,54 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return resp.JSON200, nil } +func (w *Worker) LlmGenerate(ctx context.Context, req BodyLlmGenerateLlmGeneratePost) (*LlmResponse, error) { + c, err := w.borrowContainer(ctx, "llm-generate", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + var buf bytes.Buffer + mw, err := NewLlmGenerateMultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.LlmGenerateWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 400", slog.String("err", string(val))) + return nil, errors.New("llm-generate container 
returned 400") + } + + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 401", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 401") + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 500", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 500") + } + + return resp.JSON200, nil +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)