diff --git a/.prettierignore b/.prettierignore index 90c314b65..abad90c6e 100644 --- a/.prettierignore +++ b/.prettierignore @@ -17,3 +17,5 @@ packages/api/src/schema/schema.yaml packages/www/static-build packages/www/static-build-app packages/api/dist-esbuild/api.js +# auto-generated from AI Gateway schema +packages/api/src/schema/ai-api-schema.yaml diff --git a/package.json b/package.json index 01fa980fe..cd0721705 100644 --- a/package.json +++ b/package.json @@ -26,6 +26,8 @@ "release:dry-run": "lerna publish --exact --skip-git --skip-npm --cd-version prerelease --conventional-commits --yes", "release:alpha": "lerna publish --exact --cd-version prerelease --conventional-commits", "test": "cd packages/api && yarn run test", + "compile-schemas": "cd packages/api && yarn run compile-schemas", + "pull-ai-schema": "cd packages/api && yarn run pull-ai-schema", "dev": "touch .env.local && cp .env.local packages/www && lerna run --stream --no-sort --concurrency=999 dev", "updated": "lerna updated --json", "prettier:base": "prettier '**/*.{ts,js,css,html,md,tsx,mdx,yaml,yml}'", diff --git a/packages/api/package.json b/packages/api/package.json index 05243e513..0d9b1a5ce 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -21,7 +21,8 @@ "prepare:redoc": "redoc-cli bundle --cdn -o docs/index.html src/schema/schema.yaml", "prepare:type-check": "tsc --pretty --noEmit", "prepare": "run-s compile-schemas && run-p \"prepare:**\"", - "compile-schemas": "node -r esm src/compile-schemas.js", + "compile-schemas": "node -r esm src/schema/compile-schemas.js", + "pull-ai-schema": "node -r esm src/schema/pull-ai-schema.js", "dev-server": "run-s compile-schemas && node dist/cli.js", "redoc": "nodemon -w src/schema/schema.yaml -x npm run prepare:redoc", "siserver": "nodemon -w dist -x node -r esm dist/stream-info-service.js -e js,yaml", diff --git a/packages/api/src/controllers/generate.test.ts b/packages/api/src/controllers/generate.test.ts index 4e3e49d3a..4b2c0f078 100644 --- a/packages/api/src/controllers/generate.test.ts +++ b/packages/api/src/controllers/generate.test.ts @@ -90,6 +90,7 @@ describe("controllers/generate", () => { "image-to-image", "image-to-video", "upscale", + "segment-anything-2", ]; for (const api of apis) { aiGatewayServer.app.post(`/${api}`, async (req, res) => { @@ -213,6 +214,19 @@ describe("controllers/generate", () => { }); expect(aiGatewayCalls).toEqual({ upscale: 1 }); }); + + it("should call the AI Gateway for generate API /segment-anything-2", async () => { + const res = await client.fetch("/beta/generate/segment-anything-2", { + method: "POST", + body: buildMultipartBody({}), + }); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ + message: "success", + reqContentType: expect.stringMatching("^multipart/form-data"), + }); + expect(aiGatewayCalls).toEqual({ "segment-anything-2": 1 }); + }); }); describe("validates multipart schema", () => { diff --git a/packages/api/src/controllers/generate.ts b/packages/api/src/controllers/generate.ts index 43b875664..f9a02c85e 100644 --- a/packages/api/src/controllers/generate.ts +++ b/packages/api/src/controllers/generate.ts @@ -8,10 +8,11 @@ import sql from "sql-template-strings"; import { v4 as uuid } from "uuid"; import logger from "../logger"; import { authorizer, validateFormData, validatePost } from "../middleware"; +import { defaultModels } from "../schema/pull-ai-schema"; import { AiGenerateLog } from "../schema/types"; import { db } from "../store"; import { BadRequestError } from "../store/errors"; -import { fetchWithTimeout } from "../util"; +import { fetchWithTimeout, kebabToCamel } from "../util"; import { experimentSubjectsOnly } from "./experiment"; import { pathJoin2 } from "./helpers"; @@ -170,13 +171,24 @@ function logAiGenerateRequest( function registerGenerateHandler( type: AiGenerateType, - defaultModel: string, isJSONReq = false, // multipart by default ): RequestHandler { const path = `/${type}`; - const payloadParsers = isJSONReq - ? [validatePost(`${type}-payload`)] - : [multipart.any(), validateFormData(`${type}-payload`)]; + + let payloadParsers: RequestHandler[]; + let camelType = kebabToCamel(type); + camelType = camelType[0].toUpperCase() + camelType.slice(1); + if (isJSONReq) { + payloadParsers = [validatePost(`${camelType}Params`)]; + } else { + payloadParsers = [ + multipart.any(), + validateFormData(`Body_gen${camelType}`), + ]; + } + + const defaultModel = defaultModels[type]; + return app.post( path, authorizer({}), @@ -236,17 +248,11 @@ function registerGenerateHandler( ); } -registerGenerateHandler( - "text-to-image", - "SG161222/RealVisXL_V4.0_Lightning", - true, -); -registerGenerateHandler("image-to-image", "timbrooks/instruct-pix2pix"); -registerGenerateHandler( - "image-to-video", - "stabilityai/stable-video-diffusion-img2vid-xt-1-1", -); -registerGenerateHandler("upscale", "stabilityai/stable-diffusion-x4-upscaler"); -registerGenerateHandler("audio-to-text", "openai/whisper-large-v3"); +registerGenerateHandler("text-to-image", true); +registerGenerateHandler("image-to-image"); +registerGenerateHandler("image-to-video"); +registerGenerateHandler("upscale"); +registerGenerateHandler("audio-to-text"); +registerGenerateHandler("segment-anything-2"); export default app; diff --git a/packages/api/src/schema/README.md b/packages/api/src/schema/README.md index 09aa58e7d..bfc62c0bb 100644 --- a/packages/api/src/schema/README.md +++ b/packages/api/src/schema/README.md @@ -1,14 +1,21 @@ # Studio API Schema -Our API schema is generated from the 2 YAML files in this repository: - -- `api-schema.yaml` - the schema file for our public API -- `db-schema.yaml` - the schema file for internal fields we use in our code - -These 2 files are deep merged on a key-by-key basis to generate the final schema -file, with `api-schema.yaml` going first (so `db-schema` can override values). -It is recursive, so if you want to set only 1 key in a nested object you can set -only that and all the other fields in the objects will be left intact. +Our API schema is generated from the 3 YAML files in this repository: + +- `api-schema.yaml` - The schema file for our public API. This is the base and + used by the code and docs/SDKs +- `ai-api-schema.yaml` - The schema for the AI Gateway APIs. This is also used + by code and docs/SDKs, but is kept separate since it's pulled from + `livepeer/ai-worker`. +- `db-schema.yaml` - The schema file for internal fields we use in our code. + This is used by the code, but not for docs/SDK generation since it contains + internal abstractions. + +These files are deep merged on a key-by-key basis to generate the final schema +file, in the order specified above (the later can override the previous ones). +It is recursive, so if you want to set only 1 key in a nested object you can +specify only the nested field and all the other fields in the objects path will +be left intact. e.g. `{a:{b:{c:d:"hello"}}}` will set only the `d` field in the `c` nested obj. @@ -26,6 +33,10 @@ possible reasons to use `db-schema` instead: returned objects in our code (e.g. `password`, `createdByTokenId`) - Deprecated fields we don't want anyone using (e.g. `wowza`, `detection`) +The `ai-api-schema.yaml` file should never be edited manually. Instead, run +`yarn pull-ai-schema` to update it from the source of truth +(`livepeer/ai-worker`). + ## Outputs The schema files are used to generate the following files: @@ -39,3 +50,14 @@ The schema files are used to generate the following files: our API code to validate request payloads (`middleware/validators.js`) Check `compile-schemas.js` for more details on the whole process. + +## AI APIs + +The flow for the AI Gateway schemas is: + +- When there are changes to the upstream AI Gateway schema, a developer can run + `yarn pull-ai-schema` to update the version in the repository with it. +- The `ai-api-schema.yaml` file is merged into the code abstractions in the + `compile-schemas.js` script above. +- The `ai-api-schema.yaml` file is also used on the automatic SDK and docs + generation to include the AI APIs. diff --git a/packages/api/src/schema/ai-api-schema.yaml b/packages/api/src/schema/ai-api-schema.yaml new file mode 100644 index 000000000..39dbb90f8 --- /dev/null +++ b/packages/api/src/schema/ai-api-schema.yaml @@ -0,0 +1,869 @@ +openapi: 3.1.0 +paths: + /api/beta/generate/text-to-image: + post: + tags: + - generate + summary: Text To Image + description: Generate images from text prompts. + operationId: genTextToImage + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/TextToImageParams' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + security: + - HTTPBearer: [] + x-speakeasy-name-override: textToImage + /api/beta/generate/image-to-image: + post: + tags: + - generate + summary: Image To Image + description: Apply image transformations to a provided image. + operationId: genImageToImage + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genImageToImage' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + security: + - HTTPBearer: [] + x-speakeasy-name-override: imageToImage + /api/beta/generate/image-to-video: + post: + tags: + - generate + summary: Image To Video + description: Generate a video from a provided image. + operationId: genImageToVideo + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genImageToVideo' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/VideoResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + security: + - HTTPBearer: [] + x-speakeasy-name-override: imageToVideo + /api/beta/generate/upscale: + post: + tags: + - generate + summary: Upscale + description: Upscale an image by increasing its resolution. + operationId: genUpscale + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genUpscale' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + security: + - HTTPBearer: [] + x-speakeasy-name-override: upscale + /api/beta/generate/audio-to-text: + post: + tags: + - generate + summary: Audio To Text + description: Transcribe audio files to text. + operationId: genAudioToText + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genAudioToText' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TextResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '413': + description: Request Entity Too Large + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + security: + - HTTPBearer: [] + x-speakeasy-name-override: audioToText + /api/beta/generate/segment-anything-2: + post: + tags: + - generate + summary: Segment Anything 2 + description: Segment objects in an image. + operationId: genSegmentAnything2 + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genSegmentAnything2' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/MasksResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + security: + - HTTPBearer: [] + x-speakeasy-name-override: segmentAnything2 +components: + schemas: + APIError: + properties: + msg: + type: string + title: Msg + description: The error message. + type: object + required: + - msg + title: APIError + description: API error response model. + Body_genAudioToText: + properties: + audio: + type: string + format: binary + title: Audio + description: Uploaded audio file to be transcribed. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for transcription. + default: openai/whisper-large-v3 + type: object + required: + - audio + title: Body_genAudioToText + additionalProperties: false + Body_genImageToImage: + properties: + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide image generation. + image: + type: string + format: binary + title: Image + description: Uploaded image to modify with the pipeline. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: timbrooks/instruct-pix2pix + strength: + type: number + title: Strength + description: Degree of transformation applied to the reference image (0 to 1). + default: 0.8 + guidance_scale: + type: number + title: Guidance Scale + description: >- + Encourages model to generate images closely linked to the text + prompt (higher values may reduce image quality). + default: 7.5 + image_guidance_scale: + type: number + title: Image Guidance Scale + description: >- + Degree to which the generated image is pushed towards the initial + image. + default: 1.5 + negative_prompt: + type: string + title: Negative Prompt + description: >- + Text prompt(s) to guide what to exclude from image generation. + Ignored if guidance_scale < 1. + default: '' + safety_check: + type: boolean + title: Safety Check + description: >- + Perform a safety check to estimate if generated images could be + offensive or harmful. + default: true + seed: + type: integer + title: Seed + description: Seed for random number generation. + num_inference_steps: + type: integer + title: Num Inference Steps + description: >- + Number of denoising steps. More steps usually lead to higher quality + images but slower inference. Modulated by strength. + default: 100 + num_images_per_prompt: + type: integer + title: Num Images Per Prompt + description: Number of images to generate per prompt. + default: 1 + type: object + required: + - prompt + - image + title: Body_genImageToImage + additionalProperties: false + Body_genImageToVideo: + properties: + image: + type: string + format: binary + title: Image + description: Uploaded image to generate a video from. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for video generation. + default: stabilityai/stable-video-diffusion-img2vid-xt-1-1 + height: + type: integer + title: Height + description: The height in pixels of the generated video. + default: 576 + width: + type: integer + title: Width + description: The width in pixels of the generated video. + default: 1024 + fps: + type: integer + title: Fps + description: The frames per second of the generated video. + default: 6 + motion_bucket_id: + type: integer + title: Motion Bucket Id + description: >- + Used for conditioning the amount of motion for the generation. The + higher the number the more motion will be in the video. + default: 127 + noise_aug_strength: + type: number + title: Noise Aug Strength + description: >- + Amount of noise added to the conditioning image. Higher values + reduce resemblance to the conditioning image and increase motion. + default: 0.02 + safety_check: + type: boolean + title: Safety Check + description: >- + Perform a safety check to estimate if generated images could be + offensive or harmful. + default: true + seed: + type: integer + title: Seed + description: Seed for random number generation. + num_inference_steps: + type: integer + title: Num Inference Steps + description: >- + Number of denoising steps. More steps usually lead to higher quality + images but slower inference. Modulated by strength. + default: 25 + type: object + required: + - image + title: Body_genImageToVideo + additionalProperties: false + Body_genSegmentAnything2: + properties: + image: + type: string + format: binary + title: Image + description: Image to segment. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: 'facebook/sam2-hiera-large:' + point_coords: + type: string + title: Point Coords + description: >- + Nx2 array of point prompts to the model, where each point is in + (X,Y) in pixels. + point_labels: + type: string + title: Point Labels + description: >- + Labels for the point prompts, where 1 indicates a foreground point + and 0 indicates a background point. + box: + type: string + title: Box + description: 'A length 4 array given as a box prompt to the model, in XYXY format.' + mask_input: + type: string + title: Mask Input + description: >- + A low-resolution mask input to the model, typically from a previous + prediction iteration, with the form 1xHxW (H=W=256 for SAM). + multimask_output: + type: boolean + title: Multimask Output + description: >- + If true, the model will return three masks for ambiguous input + prompts, often producing better masks than a single prediction. + default: true + return_logits: + type: boolean + title: Return Logits + description: >- + If true, returns un-thresholded mask logits instead of a binary + mask. + default: true + normalize_coords: + type: boolean + title: Normalize Coords + description: >- + If true, the point coordinates will be normalized to the range + [0,1], with point_coords expected to be with respect to image + dimensions. + default: true + type: object + required: + - image + title: Body_genSegmentAnything2 + additionalProperties: false + Body_genUpscale: + properties: + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide upscaled image generation. + image: + type: string + format: binary + title: Image + description: Uploaded image to modify with the pipeline. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for upscaled image generation. + default: stabilityai/stable-diffusion-x4-upscaler + safety_check: + type: boolean + title: Safety Check + description: >- + Perform a safety check to estimate if generated images could be + offensive or harmful. + default: true + seed: + type: integer + title: Seed + description: Seed for random number generation. + num_inference_steps: + type: integer + title: Num Inference Steps + description: >- + Number of denoising steps. More steps usually lead to higher quality + images but slower inference. Modulated by strength. + default: 75 + type: object + required: + - prompt + - image + title: Body_genUpscale + additionalProperties: false + HTTPError: + properties: + detail: + allOf: + - $ref: '#/components/schemas/APIError' + description: Detailed error information. + type: object + required: + - detail + title: HTTPError + description: HTTP error response model. + HTTPValidationError: + properties: + detail: + items: + $ref: '#/components/schemas/ValidationError' + type: array + title: Detail + type: object + title: HTTPValidationError + ImageResponse: + properties: + images: + items: + $ref: '#/components/schemas/Media' + type: array + title: Images + description: The generated images. + type: object + required: + - images + title: ImageResponse + description: Response model for image generation. + MasksResponse: + properties: + masks: + type: string + title: Masks + description: The generated masks. + scores: + type: string + title: Scores + description: The model's confidence scores for each generated mask. + logits: + type: string + title: Logits + description: 'The raw, unnormalized predictions (logits) for the masks.' + type: object + required: + - masks + - scores + - logits + title: MasksResponse + description: Response model for object segmentation. + Media: + properties: + url: + type: string + title: Url + description: The URL where the media can be accessed. + seed: + type: integer + title: Seed + description: The seed used to generate the media. + nsfw: + type: boolean + title: Nsfw + description: Whether the media was flagged as NSFW. + type: object + required: + - url + - seed + - nsfw + title: Media + description: A media object containing information about the generated media. + TextResponse: + properties: + text: + type: string + title: Text + description: The generated text. + chunks: + items: + $ref: '#/components/schemas/chunk' + type: array + title: Chunks + description: The generated text chunks. + type: object + required: + - text + - chunks + title: TextResponse + description: Response model for text generation. + TextToImageParams: + properties: + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: SG161222/RealVisXL_V4.0_Lightning + prompt: + type: string + title: Prompt + description: >- + Text prompt(s) to guide image generation. Separate multiple prompts + with '|' if supported by the model. + height: + type: integer + title: Height + description: The height in pixels of the generated image. + default: 576 + width: + type: integer + title: Width + description: The width in pixels of the generated image. + default: 1024 + guidance_scale: + type: number + title: Guidance Scale + description: >- + Encourages model to generate images closely linked to the text + prompt (higher values may reduce image quality). + default: 7.5 + negative_prompt: + type: string + title: Negative Prompt + description: >- + Text prompt(s) to guide what to exclude from image generation. + Ignored if guidance_scale < 1. + default: '' + safety_check: + type: boolean + title: Safety Check + description: >- + Perform a safety check to estimate if generated images could be + offensive or harmful. + default: true + seed: + type: integer + title: Seed + description: Seed for random number generation. + num_inference_steps: + type: integer + title: Num Inference Steps + description: >- + Number of denoising steps. More steps usually lead to higher quality + images but slower inference. Modulated by strength. + default: 50 + num_images_per_prompt: + type: integer + title: Num Images Per Prompt + description: Number of images to generate per prompt. + default: 1 + type: object + required: + - prompt + title: TextToImageParams + additionalProperties: false + ValidationError: + properties: + loc: + items: + anyOf: + - type: string + - type: integer + type: array + title: Location + msg: + type: string + title: Message + type: + type: string + title: Error Type + type: object + required: + - loc + - msg + - type + title: ValidationError + VideoResponse: + properties: + images: + items: + $ref: '#/components/schemas/Media' + type: array + title: Images + description: The generated images. + type: object + required: + - images + title: VideoResponse + description: Response model for image generation. + chunk: + properties: + timestamp: + items: {} + type: array + title: Timestamp + description: The timestamp of the chunk. + text: + type: string + title: Text + description: The text of the chunk. + type: object + required: + - timestamp + - text + title: chunk + description: A chunk of text with a timestamp. + studio-api-error: + type: object + properties: + errors: + type: array + minItems: 1 + items: + type: string + securitySchemes: + HTTPBearer: + type: http + scheme: bearer diff --git a/packages/api/src/schema/api-schema.yaml b/packages/api/src/schema/api-schema.yaml index ab4e8f8fe..376f76feb 100644 --- a/packages/api/src/schema/api-schema.yaml +++ b/packages/api/src/schema/api-schema.yaml @@ -33,6 +33,8 @@ tags: description: Operations related to access control/signing keys api - name: task description: Operations related to tasks api + - name: generate + description: Operations related to AI generate api components: securitySchemes: apiKey: @@ -2863,161 +2865,6 @@ components: targetSegmentSizeSecs: $ref: >- #/components/schemas/new-asset-payload/properties/targetSegmentSizeSecs - # AI Generate payloads. Keep in mind that these use snake_case instead of camelCase since - # they implement the same interface as the AI Gateway Livepeer node. - audio-to-text-payload: - type: object - required: - - audio - properties: - audio: - type: string - format: binary - maxLength: 10485760 # 10MiB - model_id: - type: string - default: openai/whisper-large-v3 - enum: - - openai/whisper-large-v3 - text-to-image-payload: - type: object - required: - - prompt - additionalProperties: false - properties: - prompt: - type: string - model_id: - type: string - default: SG161222/RealVisXL_V4.0_Lightning - enum: - - SG161222/RealVisXL_V4.0_Lightning - - ByteDance/SDXL-Lightning - height: - type: integer - width: - type: integer - guidance_scale: - type: number - default: 7.5 - negative_prompt: - type: string - default: "" - safety_check: - type: boolean - default: true - seed: - type: integer - num_inference_steps: - type: integer - default: 50 - minimum: 1 - maximum: 200 - num_images_per_prompt: - type: integer - default: 1 - minimum: 1 - maximum: 20 - image-to-image-payload: - type: object - required: - - prompt - - image - additionalProperties: false - properties: - prompt: - type: string - image: - type: string - format: binary - maxLength: 10485760 # 10MiB - model_id: - type: string - default: timbrooks/instruct-pix2pix - enum: - - timbrooks/instruct-pix2pix - - ByteDance/SDXL-Lightning - - SG161222/RealVisXL_V4.0_Lightning - strength: - type: number - default: 0.8 - guidance_scale: - type: number - default: 7.5 - image_guidance_scale: - type: number - default: 1.5 - negative_prompt: - type: string - default: "" - safety_check: - type: boolean - default: true - seed: - type: integer - num_images_per_prompt: - type: integer - default: 1 - minimum: 1 - maximum: 20 - image-to-video-payload: - type: object - required: - - image - additionalProperties: false - properties: - image: - type: string - format: binary - maxLength: 10485760 # 10MiB - model_id: - type: string - default: stabilityai/stable-video-diffusion-img2vid-xt-1-1 - enum: - - stabilityai/stable-video-diffusion-img2vid-xt-1-1 - height: - type: integer - default: 576 - width: - type: integer - default: 1024 - fps: - type: integer - default: 6 - motion_bucket_id: - type: integer - default: 127 - noise_aug_strength: - type: number - default: 0.02 - seed: - type: integer - safety_check: - type: boolean - default: true - upscale-payload: - type: object - required: - - prompt - - image - additionalProperties: false - properties: - prompt: - type: string - image: - type: string - format: binary - maxLength: 10485760 # 10MiB - model_id: - type: string - default: stabilityai/stable-diffusion-x4-upscaler - enum: - - stabilityai/stable-diffusion-x4-upscaler - safety_check: - type: boolean - default: true - seed: - type: integer paths: /stream: post: @@ -5119,3 +4966,9 @@ paths: application/json: schema: $ref: "#/components/schemas/error" + default: + description: Error + content: + application/json: + schema: + $ref: "#/components/schemas/error" diff --git a/packages/api/src/compile-schemas.js b/packages/api/src/schema/compile-schemas.js similarity index 74% rename from packages/api/src/compile-schemas.js rename to packages/api/src/schema/compile-schemas.js index 4b09fd9b6..43fe5c2c3 100644 --- a/packages/api/src/compile-schemas.js +++ b/packages/api/src/schema/compile-schemas.js @@ -24,23 +24,47 @@ const write = (dir, data) => { console.log(`wrote ${dir}`); }; -const schemaDir = path.resolve(__dirname, "schema"); +// Remove the title from the schema to avoid conflicts with the TypeScript type name +function removeAllTitles(schema) { + if (schema.title) { + delete schema.title; + } + + if (schema.properties) { + for (const key in schema.properties) { + if (schema.properties[key]) { + schema.properties[key] = removeAllTitles(schema.properties[key]); + } + } + } + + if (schema.items) { + if (Array.isArray(schema.items)) { + schema.items = schema.items.map((item) => removeAllTitles(item)); + } else { + schema.items = removeAllTitles(schema.items); + } + } + + return schema; +} + +const schemaDir = path.resolve(__dirname, "."); +process.chdir(schemaDir); + const validatorDir = path.resolve(schemaDir, "validators"); const schemaDistDir = path.resolve(__dirname, "..", "dist", "schema"); fs.ensureDirSync(validatorDir); fs.ensureDirSync(schemaDistDir); -const apiSchemaStr = fs.readFileSync( - path.resolve(schemaDir, "api-schema.yaml"), - "utf8", -); -const dbSchemaStr = fs.readFileSync( - path.resolve(schemaDir, "db-schema.yaml"), - "utf8", -); -const apiData = parseYaml(apiSchemaStr); -const dbData = parseYaml(dbSchemaStr); -const data = _.merge({}, apiData, dbData); +const schemaFiles = ["api-schema.yaml", "ai-api-schema.yaml", "db-schema.yaml"]; +const subSchemas = []; +for (const file of schemaFiles) { + const schemaStr = fs.readFileSync(path.resolve(schemaDir, file), "utf8"); + const data = parseYaml(schemaStr); + subSchemas.push(data); +} +const data = _.merge({}, ...subSchemas); (async () => { const yaml = serializeYaml(data); @@ -65,7 +89,8 @@ const data = _.merge({}, apiData, dbData); const index = []; let types = []; - for (const [name, schema] of Object.entries(data.components.schemas)) { + for (let [name, schema] of Object.entries(data.components.schemas)) { + schema = removeAllTitles(schema); schema.title = name; const type = await generateTypes(schema); types.push(type); diff --git a/packages/api/src/schema/db-schema.yaml b/packages/api/src/schema/db-schema.yaml index 9df176762..6388c00a1 100644 --- a/packages/api/src/schema/db-schema.yaml +++ b/packages/api/src/schema/db-schema.yaml @@ -1453,12 +1453,14 @@ components: - image-to-image - image-to-video - upscale + - segment-anything-2 request: oneOf: - - $ref: "#/components/schemas/text-to-image-payload" - - $ref: "#/components/schemas/image-to-image-payload" - - $ref: "#/components/schemas/image-to-video-payload" - - $ref: "#/components/schemas/upscale-payload" + - $ref: "./ai-api-schema.yaml#/components/schemas/TextToImageParams" + - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToImage" + - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToVideo" + - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genUpscale" + - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genSegmentAnything2" statusCode: type: integer description: HTTP status code received from the AI gateway diff --git a/packages/api/src/schema/pull-ai-schema.js b/packages/api/src/schema/pull-ai-schema.js new file mode 100644 index 000000000..be0443f2a --- /dev/null +++ b/packages/api/src/schema/pull-ai-schema.js @@ -0,0 +1,139 @@ +import fs from "fs-extra"; +import { safeLoad as parseYaml, safeDump as serializeYaml } from "js-yaml"; +import path from "path"; + +// This downloads the AI schema from the AI worker repo and saves in the local +// ai-api-schema.yaml file, referenced by our main api-schema.yaml file. + +export const defaultModels = { + "text-to-image": "SG161222/RealVisXL_V4.0_Lightning", + "image-to-image": "timbrooks/instruct-pix2pix", + "image-to-video": "stabilityai/stable-video-diffusion-img2vid-xt-1-1", + upscale: "stabilityai/stable-diffusion-x4-upscaler", + "audio-to-text": "openai/whisper-large-v3", + "segment-anything-2": "facebook/sam2-hiera-large:", +}; +const schemaDir = path.resolve(__dirname, "."); +const aiSchemaUrl = + "https://raw.githubusercontent.com/livepeer/ai-worker/refs/heads/main/runner/gateway.openapi.yaml"; + +const studioApiErrorSchema = { + type: "object", + properties: { + errors: { + type: "array", + minItems: 1, + items: { + type: "string", + }, + }, + }, +}; + +const write = (dir, data) => { + if (fs.existsSync(dir)) { + const existing = fs.readFileSync(dir, "utf8"); + if (existing === data) { + return; + } + } + fs.writeFileSync(dir, data, "utf8"); + console.log(`wrote ${dir}`); +}; + +const mapObject = (obj, fn) => { + return Object.fromEntries( + Object.entries(obj).map(([key, value]) => fn(key, value)), + ); +}; + +const downloadAiSchema = async () => { + // download the file + const response = await fetch(aiSchemaUrl); + const data = await response.text(); + const schema = parseYaml(data); + + // remove info and servers fields + delete schema.info; + delete schema.servers; + + // add studio-api-error schema + schema.components.schemas["studio-api-error"] = studioApiErrorSchema; + + // patches to the paths section + schema.paths = mapObject(schema.paths, (path, value) => { + // prefix paths with /api/beta/generate + path = `/api/beta/generate${path}`; + // remove security field + delete value.security; + // add Studio API error as oneOf to all of the error responses + const studioErrorRef = () => ({ + $ref: "#/components/schemas/studio-api-error", + }); + value.post.responses = mapObject( + value.post.responses, + (statusCode, response) => { + if ( + statusCode !== "default" && + Math.floor(parseInt(statusCode) / 100) === 2 + ) { + return [statusCode, response]; + } + response.content["application/json"].schema = { + oneOf: [ + response.content["application/json"].schema, + studioErrorRef(), + ], + }; + return [statusCode, response]; + }, + ); + // add $ref: "#/components/schemas/error" as the default response + if (!value.post.responses["default"]) { + value.post.responses["default"] = { + description: "Error", + content: { "application/json": { schema: studioErrorRef() } }, + }; + } + return [path, value]; + }); + + // Modify the pipeline input schemas to: + // - set default model_id values for each pipeline (and make them not requried) + // - disallow additionalProperties + schema.components.schemas = mapObject( + schema.components.schemas, + (key, value) => { + let pipelineName; + if (key.endsWith("Params")) { + pipelineName = key.slice(0, -6); + } else if (key.startsWith("Body_gen")) { + pipelineName = key.slice(8); + } else { + return [key, value]; + } + // turn CamelCase3 to kebab-case-3 + pipelineName = pipelineName + .replace(/([a-z])([A-Z0-9])/g, "$1-$2") + .toLowerCase(); + + if (pipelineName in defaultModels && value.properties.model_id) { + value.properties.model_id.default = defaultModels[pipelineName]; + if (value.required) { + value.required = value.required.filter((key) => key !== "model_id"); + } + } + value.additionalProperties = false; + + return [key, value]; + }, + ); + + const yaml = serializeYaml(schema); + write(path.resolve(schemaDir, "ai-api-schema.yaml"), yaml); +}; + +downloadAiSchema().catch((err) => { + console.error(err); + process.exit(1); +});