From f4862aee345aa29cf53a618578e04485a3ec40fe Mon Sep 17 00:00:00 2001 From: Max Holland Date: Wed, 2 Oct 2024 14:22:52 +0100 Subject: [PATCH 1/3] Add AI LLM endpoint --- packages/api/src/controllers/generate.test.ts | 13 ++ packages/api/src/controllers/generate.ts | 1 + packages/api/src/schema/ai-api-schema.yaml | 124 +++++++++++++++++- packages/api/src/schema/db-schema.yaml | 2 + packages/api/src/schema/pull-ai-schema.js | 1 + 5 files changed, 140 insertions(+), 1 deletion(-) diff --git a/packages/api/src/controllers/generate.test.ts b/packages/api/src/controllers/generate.test.ts index 4b2c0f078..8a5ef8b55 100644 --- a/packages/api/src/controllers/generate.test.ts +++ b/packages/api/src/controllers/generate.test.ts @@ -227,6 +227,19 @@ describe("controllers/generate", () => { }); expect(aiGatewayCalls).toEqual({ "segment-anything-2": 1 }); }); + + it("should call the AI Gateway for generate API /llm", async () => { + const res = await client.fetch("/beta/generate/llm", { + method: "POST", + body: buildMultipartBody({}), + }); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ + message: "success", + reqContentType: expect.stringMatching("^multipart/form-data"), + }); + expect(aiGatewayCalls).toEqual({ llm: 1 }); + }); }); describe("validates multipart schema", () => { diff --git a/packages/api/src/controllers/generate.ts b/packages/api/src/controllers/generate.ts index f9a02c85e..fb2906a37 100644 --- a/packages/api/src/controllers/generate.ts +++ b/packages/api/src/controllers/generate.ts @@ -254,5 +254,6 @@ registerGenerateHandler("image-to-video"); registerGenerateHandler("upscale"); registerGenerateHandler("audio-to-text"); registerGenerateHandler("segment-anything-2"); +registerGenerateHandler("llm"); export default app; diff --git a/packages/api/src/schema/ai-api-schema.yaml b/packages/api/src/schema/ai-api-schema.yaml index 7dbd0284d..a194ccb32 100644 --- a/packages/api/src/schema/ai-api-schema.yaml +++ b/packages/api/src/schema/ai-api-schema.yaml @@ -368,6 +368,65 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: segmentAnything2 + /api/beta/generate/llm: + post: + tags: + - generate + summary: LLM + description: Generate text using a language model. + operationId: genLLM + requestBody: + content: + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Body_genLLM' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/LLMResponse' + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + x-speakeasy-name-override: llm components: schemas: APIError: @@ -414,6 +473,14 @@ components: title: Model Id description: Hugging Face model ID used for image generation. default: timbrooks/instruct-pix2pix + loras: + type: string + title: Loras + description: >- + A LoRA (Low-Rank Adaptation) model and its corresponding weight for + image generation. Example: { "latent-consistency/lcm-lora-sdxl": + 1.0, "nerijs/pixel-art-xl": 1.2}. + default: '' strength: type: number title: Strength @@ -533,6 +600,40 @@ components: - image title: Body_genImageToVideo additionalProperties: false + Body_genLLM: + properties: + prompt: + type: string + title: Prompt + model_id: + type: string + title: Model Id + default: meta-llama/Meta-Llama-3.1-8B-Instruct + system_msg: + type: string + title: System Msg + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + history: + type: string + title: History + default: '[]' + stream: + type: boolean + title: Stream + default: false + type: object + required: + - prompt + title: Body_genLLM + additionalProperties: false Body_genSegmentAnything2: properties: image: @@ -544,7 +645,7 @@ components: type: string title: Model Id description: Hugging Face model ID used for image generation. - default: 'facebook/sam2-hiera-large' + default: facebook/sam2-hiera-large point_coords: type: string title: Point Coords @@ -667,6 +768,19 @@ components: - images title: ImageResponse description: Response model for image generation. + LLMResponse: + properties: + response: + type: string + title: Response + tokens_used: + type: integer + title: Tokens Used + type: object + required: + - response + - tokens_used + title: LLMResponse MasksResponse: properties: masks: @@ -734,6 +848,14 @@ components: title: Model Id description: Hugging Face model ID used for image generation. default: SG161222/RealVisXL_V4.0_Lightning + loras: + type: string + title: Loras + description: >- + A LoRA (Low-Rank Adaptation) model and its corresponding weight for + image generation. Example: { "latent-consistency/lcm-lora-sdxl": + 1.0, "nerijs/pixel-art-xl": 1.2}. + default: '' prompt: type: string title: Prompt diff --git a/packages/api/src/schema/db-schema.yaml b/packages/api/src/schema/db-schema.yaml index 6388c00a1..ce8843c6d 100644 --- a/packages/api/src/schema/db-schema.yaml +++ b/packages/api/src/schema/db-schema.yaml @@ -1454,6 +1454,7 @@ components: - image-to-video - upscale - segment-anything-2 + - llm request: oneOf: - $ref: "./ai-api-schema.yaml#/components/schemas/TextToImageParams" @@ -1461,6 +1462,7 @@ components: - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToVideo" - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genUpscale" - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genSegmentAnything2" + - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genLLM" statusCode: type: integer description: HTTP status code received from the AI gateway diff --git a/packages/api/src/schema/pull-ai-schema.js b/packages/api/src/schema/pull-ai-schema.js index 02645051e..d41cc63bb 100644 --- a/packages/api/src/schema/pull-ai-schema.js +++ b/packages/api/src/schema/pull-ai-schema.js @@ -12,6 +12,7 @@ export const defaultModels = { upscale: "stabilityai/stable-diffusion-x4-upscaler", "audio-to-text": "openai/whisper-large-v3", "segment-anything-2": "facebook/sam2-hiera-large", + llm: "meta-llama/Meta-Llama-3.1-8B-Instruct", }; const schemaDir = path.resolve(__dirname, "."); const aiSchemaUrl = From 110d060e96ae91552498cc7d16277e8afe84a513 Mon Sep 17 00:00:00 2001 From: Max Holland Date: Wed, 2 Oct 2024 17:08:22 +0100 Subject: [PATCH 2/3] Fix type generation --- packages/api/src/schema/compile-schemas.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/api/src/schema/compile-schemas.js b/packages/api/src/schema/compile-schemas.js index 43fe5c2c3..d0d0b2b23 100644 --- a/packages/api/src/schema/compile-schemas.js +++ b/packages/api/src/schema/compile-schemas.js @@ -46,6 +46,10 @@ function removeAllTitles(schema) { } } + if (schema.oneOf && Array.isArray(schema.oneOf)) { + schema.oneOf = schema.oneOf.map((item) => removeAllTitles(item)); + } + return schema; } From 22345dfed6319cd9331d236769ead4d0e847f16b Mon Sep 17 00:00:00 2001 From: Max Holland Date: Wed, 2 Oct 2024 20:54:21 +0100 Subject: [PATCH 3/3] fix validator lookup and test --- packages/api/src/controllers/generate.test.ts | 14 ++++++++++---- packages/api/src/controllers/generate.ts | 4 ++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/packages/api/src/controllers/generate.test.ts b/packages/api/src/controllers/generate.test.ts index 8a5ef8b55..fc5bbcf71 100644 --- a/packages/api/src/controllers/generate.test.ts +++ b/packages/api/src/controllers/generate.test.ts @@ -91,6 +91,7 @@ describe("controllers/generate", () => { "image-to-video", "upscale", "segment-anything-2", + "llm", ]; for (const api of apis) { aiGatewayServer.app.post(`/${api}`, async (req, res) => { @@ -135,13 +136,18 @@ describe("controllers/generate", () => { textFields: Record, multipartField = { name: "image", contentType: "image/png" }, ) => { + const form = buildForm(textFields); + form.append(multipartField.name, "dummy", { + contentType: multipartField.contentType, + }); + return form; + }; + + const buildForm = (textFields: Record) => { const form = new FormData(); for (const [k, v] of Object.entries(textFields)) { form.append(k, v); } - form.append(multipartField.name, "dummy", { - contentType: multipartField.contentType, - }); return form; }; @@ -231,7 +237,7 @@ describe("controllers/generate", () => { it("should call the AI Gateway for generate API /llm", async () => { const res = await client.fetch("/beta/generate/llm", { method: "POST", - body: buildMultipartBody({}), + body: buildForm({ prompt: "foo" }), }); expect(res.status).toBe(200); expect(await res.json()).toEqual({ diff --git a/packages/api/src/controllers/generate.ts b/packages/api/src/controllers/generate.ts index fb2906a37..17ea64f9b 100644 --- a/packages/api/src/controllers/generate.ts +++ b/packages/api/src/controllers/generate.ts @@ -15,6 +15,7 @@ import { BadRequestError } from "../store/errors"; import { fetchWithTimeout, kebabToCamel } from "../util"; import { experimentSubjectsOnly } from "./experiment"; import { pathJoin2 } from "./helpers"; +import validators from "../schema/validators"; const AI_GATEWAY_TIMEOUT = 10 * 60 * 1000; // 10 minutes const RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute @@ -181,6 +182,9 @@ function registerGenerateHandler( if (isJSONReq) { payloadParsers = [validatePost(`${camelType}Params`)]; } else { + if (!validators[`Body_gen${camelType}`]) { + camelType = type.toUpperCase(); + } payloadParsers = [ multipart.any(), validateFormData(`Body_gen${camelType}`),