From cac159fb41c3b7700643b3fb67b192617634dfea Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Sun, 14 Jan 2024 22:15:09 +0000 Subject: [PATCH] Generate TypeScript definitions from source --- index.d.ts | 226 ----------------------- index.js | 70 ++++++- index.test.ts | 3 +- integration/commonjs/package-lock.json | 2 + integration/esm/package-lock.json | 1 + integration/typescript/package-lock.json | 5 +- integration/typescript/types.test.ts | 84 +++++++++ jsconfig.json | 4 +- lib/collections.js | 10 +- lib/deployments.js | 8 +- lib/hardware.js | 3 +- lib/identifier.js | 12 +- lib/models.js | 29 ++- lib/predictions.js | 40 ++-- lib/stream.js | 10 +- lib/trainings.js | 16 +- lib/types.js | 71 +++++++ package.json | 7 +- tsconfig.json | 6 +- 19 files changed, 323 insertions(+), 284 deletions(-) delete mode 100644 index.d.ts create mode 100644 integration/typescript/types.test.ts create mode 100644 lib/types.js diff --git a/index.d.ts b/index.d.ts deleted file mode 100644 index 5620f3b..0000000 --- a/index.d.ts +++ /dev/null @@ -1,226 +0,0 @@ -declare module "replicate" { - type Status = "starting" | "processing" | "succeeded" | "failed" | "canceled"; - type Visibility = "public" | "private"; - type WebhookEventType = "start" | "output" | "logs" | "completed"; - - export interface ApiError extends Error { - request: Request; - response: Response; - } - - export interface Collection { - name: string; - slug: string; - description: string; - models?: Model[]; - } - - export interface Hardware { - sku: string; - name: string; - } - - export interface Model { - url: string; - owner: string; - name: string; - description?: string; - visibility: "public" | "private"; - github_url?: string; - paper_url?: string; - license_url?: string; - run_count: number; - cover_image_url?: string; - default_example?: Prediction; - latest_version?: ModelVersion; - } - - export interface ModelVersion { - id: string; - created_at: string; - cog_version: string; - openapi_schema: object; - } - - export interface Prediction { - id: string; - status: Status; - model: string; - version: string; - input: object; - output?: any; - source: "api" | "web"; - error?: any; - logs?: string; - metrics?: { - predict_time?: number; - }; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - created_at: string; - started_at?: string; - completed_at?: string; - urls: { - get: string; - cancel: string; - stream?: string; - }; - } - - export type Training = Prediction; - - export interface Page { - previous?: string; - next?: string; - results: T[]; - } - - export interface ServerSentEvent { - event: string; - data: string; - id?: string; - retry?: number; - } - - export default class Replicate { - constructor(options?: { - auth?: string; - userAgent?: string; - baseUrl?: string; - fetch?: ( - input: Request | string, - init?: RequestInit - ) => Promise; - }); - - auth: string; - userAgent?: string; - baseUrl?: string; - fetch: (input: Request | string, init?: RequestInit) => Promise; - - run( - identifier: `${string}/${string}` | `${string}/${string}:${string}`, - options: { - input: object; - wait?: { interval?: number }; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - signal?: AbortSignal; - }, - progress?: (prediction: Prediction) => void - ): Promise; - - stream( - identifier: `${string}/${string}` | `${string}/${string}:${string}`, - options: { - input: object; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - signal?: AbortSignal; - } - ): AsyncGenerator; - - request( - route: string | URL, - options: { - method?: string; - headers?: object | Headers; - params?: object; - data?: object; - } - ): Promise; - - paginate(endpoint: () => Promise>): AsyncGenerator<[T]>; - - wait( - prediction: Prediction, - options?: { - interval?: number; - }, - stop?: (prediction: Prediction) => Promise - ): Promise; - - collections: { - list(): Promise>; - get(collection_slug: string): Promise; - }; - - deployments: { - predictions: { - create( - deployment_owner: string, - deployment_name: string, - options: { - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - }; - }; - - hardware: { - list(): Promise; - }; - - models: { - get(model_owner: string, model_name: string): Promise; - list(): Promise>; - create( - model_owner: string, - model_name: string, - options: { - visibility: Visibility; - hardware: string; - description?: string; - github_url?: string; - paper_url?: string; - license_url?: string; - cover_image_url?: string; - } - ): Promise; - versions: { - list(model_owner: string, model_name: string): Promise; - get( - model_owner: string, - model_name: string, - version_id: string - ): Promise; - }; - }; - - predictions: { - create( - options: { - model?: string; - version?: string; - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } & ({ version: string } | { model: string }) - ): Promise; - get(prediction_id: string): Promise; - cancel(prediction_id: string): Promise; - list(): Promise>; - }; - - trainings: { - create( - model_owner: string, - model_name: string, - version_id: string, - options: { - destination: `${string}/${string}`; - input: object; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - get(training_id: string): Promise; - cancel(training_id: string): Promise; - list(): Promise>; - }; - } -} diff --git a/index.js b/index.js index ce407f9..eb4d33f 100644 --- a/index.js +++ b/index.js @@ -34,34 +34,59 @@ class Replicate { /** * Create a new Replicate API client instance. * - * @param {object} options - Configuration options for the client - * @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. - * @param {string} options.userAgent - Identifier of your app + * @example + * // Create a new Replicate API client instance + * const Replicate = require("replicate"); + * const replicate = new Replicate({ + * // get your token from https://replicate.com/account + * auth: process.env.REPLICATE_API_TOKEN, + * userAgent: "my-app/1.2.3" + * }); + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + * + * @param {Object} [options] - Configuration options for the client + * @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. + * @param {string} [options.userAgent] - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` */ constructor(options = {}) { + /** @type {string} */ this.auth = options.auth || process.env.REPLICATE_API_TOKEN; + + /** @type {string} */ this.userAgent = options.userAgent || `replicate-javascript/${packageJSON.version}`; + + /** @type {string} */ this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; + + /** @type {fetch} */ this.fetch = options.fetch || globalThis.fetch; + /** @type {collections} */ this.collections = { list: collections.list.bind(this), get: collections.get.bind(this), }; + /** @type {deployments} */ this.deployments = { predictions: { create: deployments.predictions.create.bind(this), }, }; + /** @type {hardware} */ this.hardware = { list: hardware.list.bind(this), }; + /** @type {models} */ this.models = { get: models.get.bind(this), list: models.list.bind(this), @@ -72,6 +97,7 @@ class Replicate { }, }; + /** @type {predictions} */ this.predictions = { create: predictions.create.bind(this), get: predictions.get.bind(this), @@ -79,6 +105,7 @@ class Replicate { list: predictions.list.bind(this), }; + /** @type {trainings} */ this.trainings = { create: trainings.create.bind(this), get: trainings.get.bind(this), @@ -90,18 +117,18 @@ class Replicate { /** * Run a model and wait for its output. * - * @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" + * @param {`${string}/${string}` | `${string}/${string}:${string}`} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" * @param {object} options * @param {object} options.input - Required. An object with the model inputs * @param {object} [options.wait] - Options for waiting for the prediction to finish * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500 * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction * @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed. * @throws {Error} If the reference is invalid * @throws {Error} If the prediction failed - * @returns {Promise} - Resolves with the output of running the model + * @returns {Promise} - Resolves with the output of running the model */ async run(ref, options, progress) { const { wait, ...data } = options; @@ -237,7 +264,7 @@ class Replicate { /** * Stream a model and wait for its output. * - * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" + * @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}" * @param {object} options * @param {object} options.input - Required. An object with the model inputs * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output @@ -285,8 +312,10 @@ class Replicate { * for await (const page of replicate.paginate(replicate.predictions.list) { * console.log(page); * } - * @param {Function} endpoint - Function that returns a promise for the next page of results - * @yields {object[]} Each page of results + * @template T + * @param {() => Promise>} endpoint - Function that returns a promise for the next page of results + * @yields {T[]} Each page of results + * @returns {AsyncGenerator} */ async *paginate(endpoint) { const response = await endpoint(); @@ -312,7 +341,7 @@ class Replicate { * @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling. * @throws {Error} If the prediction doesn't complete within the maximum number of attempts * @throws {Error} If the prediction failed - * @returns {Promise} Resolves with the completed prediction object + * @returns {Promise} Resolves with the completed prediction object */ async wait(prediction, options, stop) { const { id } = prediction; @@ -359,3 +388,24 @@ class Replicate { } module.exports = Replicate; + +// - Type Definitions + +/** + * @typedef {import("./lib/error")} ApiError + * @typedef {import("./lib/types").Collection} Collection + * @typedef {import("./lib/types").ModelVersion} ModelVersion + * @typedef {import("./lib/types").Hardware} Hardware + * @typedef {import("./lib/types").Model} Model + * @typedef {import("./lib/types").Prediction} Prediction + * @typedef {import("./lib/types").Training} Training + * @typedef {import("./lib/types").ServerSentEvent} ServerSentEvent + * @typedef {import("./lib/types").Status} Status + * @typedef {import("./lib/types").Visibility} Visibility + * @typedef {import("./lib/types").WebhookEventType} WebhookEventType + */ + +/** + * @template T + * @typedef {import("./lib/types").Page} Page + */ diff --git a/index.test.ts b/index.test.ts index 5b5a1dd..8d36880 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,5 +1,5 @@ import { expect, jest, test } from "@jest/globals"; -import Replicate, { ApiError, Model, Prediction } from "replicate"; +import Replicate, { ApiError, Model, Prediction } from "./"; import nock from "nock"; import fetch from "cross-fetch"; @@ -838,7 +838,6 @@ describe("Replicate client", () => { }); test("Calls the correct API routes for a model", async () => { - const firstPollingRequest = true; nock(BASE_URL) .post("/models/replicate/hello-world/predictions") diff --git a/integration/commonjs/package-lock.json b/integration/commonjs/package-lock.json index 1584af5..5a3fc99 100644 --- a/integration/commonjs/package-lock.json +++ b/integration/commonjs/package-lock.json @@ -12,6 +12,7 @@ } }, "../..": { + "name": "replicate", "version": "0.25.2", "license": "Apache-2.0", "devDependencies": { @@ -21,6 +22,7 @@ "cross-fetch": "^3.1.5", "jest": "^29.6.2", "nock": "^13.3.0", + "publint": "^0.2.7", "ts-jest": "^29.1.0", "typescript": "^5.0.2" }, diff --git a/integration/esm/package-lock.json b/integration/esm/package-lock.json index 2a17c88..433b81f 100644 --- a/integration/esm/package-lock.json +++ b/integration/esm/package-lock.json @@ -22,6 +22,7 @@ "cross-fetch": "^3.1.5", "jest": "^29.6.2", "nock": "^13.3.0", + "publint": "^0.2.7", "ts-jest": "^29.1.0", "typescript": "^5.0.2" }, diff --git a/integration/typescript/package-lock.json b/integration/typescript/package-lock.json index f309b1b..7e036ff 100644 --- a/integration/typescript/package-lock.json +++ b/integration/typescript/package-lock.json @@ -1,11 +1,11 @@ { - "name": "replicate-app-esm", + "name": "replicate-app-typescript", "version": "0.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "replicate-app-esm", + "name": "replicate-app-typescript", "version": "0.0.0", "dependencies": { "@types/node": "^20.11.0", @@ -24,6 +24,7 @@ "cross-fetch": "^3.1.5", "jest": "^29.6.2", "nock": "^13.3.0", + "publint": "^0.2.7", "ts-jest": "^29.1.0", "typescript": "^5.0.2" }, diff --git a/integration/typescript/types.test.ts b/integration/typescript/types.test.ts new file mode 100644 index 0000000..d58484b --- /dev/null +++ b/integration/typescript/types.test.ts @@ -0,0 +1,84 @@ +import { ApiError, Collection, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate"; + +export type Equals = + (() => T extends X ? 1 : 2) extends + (() => T extends Y ? 1 : 2) ? true : false; + + +type AssertFalse = A + +// @ts-expect-error +export type TestAssertion = AssertFalse> + +export type TestApiError = AssertFalse> +export type TestCollection = AssertFalse> +export type TestHardware = AssertFalse> +export type TestModel = AssertFalse> +export type TestModelVersion = AssertFalse> +export type TestPage = AssertFalse, any>> +export type TestPrediction = AssertFalse> +export type TestStatus = AssertFalse> +export type TestTraining = AssertFalse> +export type TestVisibility = AssertFalse> +export type TestWebhookEventType = AssertFalse> + + +// NOTE: We export the constants to avoid unused varaible issues. + +export const collection: Collection = { name: "", slug: "", description: "", models: [] }; +export const status: Status = "starting"; +export const visibility: Visibility = "public"; +export const webhookType: WebhookEventType = "start"; +export const err: ApiError = Object.assign(new Error(), {request: new Request("file://"), response: new Response()}); +export const hardware: Hardware = { sku: "", name: "" }; +export const model: Model = { + url: "", + owner: "", + name: "", + description: "", + visibility: "public", + github_url: "", + paper_url: "", + license_url: "", + run_count: 10, + cover_image_url: "", + default_example: undefined, + latest_version: undefined, +}; +export const version: ModelVersion = { + id: "", + created_at: "", + cog_version: "", + openapi_schema: "", +}; +export const prediction: Prediction = { + id: "", + status: "starting", + model: "", + version: "", + input: {}, + output: {}, + source: "api", + error: undefined, + logs: "", + metrics: { + predict_time: 100, + }, + webhook: "", + webhook_events_filter: [], + created_at: "", + started_at: "", + completed_at: "", + urls: { + get: "", + cancel: "", + stream: "", + }, +}; +export const training: Training = prediction; + +export const page: Page = { + previous: "", + next: "", + results: [version], +}; diff --git a/jsconfig.json b/jsconfig.json index b83b3f3..3d6fa2f 100644 --- a/jsconfig.json +++ b/jsconfig.json @@ -6,9 +6,11 @@ "target": "ES2020", "resolveJsonModule": true, "strictNullChecks": true, - "strictFunctionTypes": true + "strictFunctionTypes": true, + "types": [], }, "exclude": [ + "dist", "node_modules", "**/node_modules/*" ] diff --git a/lib/collections.js b/lib/collections.js index 9332aaa..4175934 100644 --- a/lib/collections.js +++ b/lib/collections.js @@ -1,8 +1,14 @@ +/** @typedef {import("./types").Collection} Collection */ +/** + * @template T + * @typedef {import("./types").Page} Page + */ + /** * Fetch a model collection * * @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections - * @returns {Promise} - Resolves with the collection data + * @returns {Promise} - Resolves with the collection data */ async function getCollection(collection_slug) { const response = await this.request(`/collections/${collection_slug}`, { @@ -15,7 +21,7 @@ async function getCollection(collection_slug) { /** * Fetch a list of model collections * - * @returns {Promise} - Resolves with the collections data + * @returns {Promise>} - Resolves with the collections data */ async function listCollections() { const response = await this.request("/collections", { diff --git a/lib/deployments.js b/lib/deployments.js index 6f32cdb..c071e1d 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -1,14 +1,16 @@ +/** @typedef {import("./types").Prediction} Prediction */ + /** * Create a new prediction with a deployment * * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment * @param {object} options - * @param {object} options.input - Required. An object with the model inputs + * @param {unknown} options.input - Required. An object with the model inputs * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the created prediction data + * @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @returns {Promise} Resolves with the created prediction data */ async function createPrediction(deployment_owner, deployment_name, options) { const { stream, ...data } = options; diff --git a/lib/hardware.js b/lib/hardware.js index d717548..755bd91 100644 --- a/lib/hardware.js +++ b/lib/hardware.js @@ -1,7 +1,8 @@ +/** @typedef {import("./types").Hardware} Hardware */ /** * List hardware * - * @returns {Promise} Resolves with the array of hardware + * @returns {Promise} Resolves with the array of hardware */ async function listHardware() { const response = await this.request("/hardware", { diff --git a/lib/identifier.js b/lib/identifier.js index 86e23ee..f9e9786 100644 --- a/lib/identifier.js +++ b/lib/identifier.js @@ -2,10 +2,10 @@ * A reference to a model version in the format `owner/name` or `owner/name:version`. */ class ModelVersionIdentifier { - /* - * @param {string} Required. The model owner. - * @param {string} Required. The model name. - * @param {string} The model version. + /** + * @param {string} owner Required. The model owner. + * @param {string} name Required. The model name. + * @param {string | null=} version The model version. */ constructor(owner, name, version = null) { this.owner = owner; @@ -13,10 +13,10 @@ class ModelVersionIdentifier { this.version = version; } - /* + /** * Parse a reference to a model version * - * @param {string} + * @param {string} ref * @returns {ModelVersionIdentifier} * @throws {Error} If the reference is invalid. */ diff --git a/lib/models.js b/lib/models.js index c6a02fc..e7cbcd8 100644 --- a/lib/models.js +++ b/lib/models.js @@ -1,9 +1,18 @@ +/** @typedef {import("./types").Model} Model */ +/** @typedef {import("./types").ModelVersion} ModelVersion */ +/** @typedef {import("./types").Prediction} Prediction */ +/** @typedef {import("./types").Visibility} Visibility */ +/** + * @template T + * @typedef {import("./types").Page} Page + */ + /** * Get information about a model * * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model - * @returns {Promise} Resolves with the model data + * @returns {Promise} Resolves with the model data */ async function getModel(model_owner, model_name) { const response = await this.request(`/models/${model_owner}/${model_name}`, { @@ -18,7 +27,7 @@ async function getModel(model_owner, model_name) { * * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model - * @returns {Promise} Resolves with the list of model versions + * @returns {Promise>} Resolves with the list of model versions */ async function listModelVersions(model_owner, model_name) { const response = await this.request( @@ -37,7 +46,7 @@ async function listModelVersions(model_owner, model_name) { * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model * @param {string} version_id - Required. The model version - * @returns {Promise} Resolves with the model version data + * @returns {Promise} Resolves with the model version data */ async function getModelVersion(model_owner, model_name, version_id) { const response = await this.request( @@ -53,7 +62,7 @@ async function getModelVersion(model_owner, model_name, version_id) { /** * List all public models * - * @returns {Promise} Resolves with the model version data + * @returns {Promise>} Resolves with the model version data */ async function listModels() { const response = await this.request("/models", { @@ -69,14 +78,14 @@ async function listModels() { * @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. * @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization. * @param {object} options - * @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. + * @param {Visibility} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. * @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`. * @param {string} options.description - A description of the model. - * @param {string} options.github_url - A URL for the model's source code on GitHub. - * @param {string} options.paper_url - A URL for the model's paper. - * @param {string} options.license_url - A URL for the model's license. - * @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file. - * @returns {Promise} Resolves with the model version data + * @param {string=} options.github_url - A URL for the model's source code on GitHub. + * @param {string=} options.paper_url - A URL for the model's paper. + * @param {string=} options.license_url - A URL for the model's license. + * @param {string=} options.cover_image_url - A URL for the model's cover image. This should be an image file. + * @returns {Promise} Resolves with the model version data */ async function createModel(model_owner, model_name, options) { const data = { owner: model_owner, name: model_name, ...options }; diff --git a/lib/predictions.js b/lib/predictions.js index 294e8d9..23ad5df 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -1,14 +1,30 @@ +/** + * @template T + * @typedef {import("./types").Page} Page + */ +/** @typedef {import("./types").Prediction} Prediction */ + +/** + * @typedef {Object} BasePredictionOptions + * @property {unknown} input - Required. An object with the model inputs + * @property {string} [webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @property {string[]} [webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @property {boolean} [stream] - Whether to stream the prediction output. Defaults to false + * + * @typedef {Object} ModelPredictionOptions + * @property {string} model The model name (for official models) + * @property {never=} version + * + * @typedef {Object} VersionPredictionOptions + * @property {string} version The model version + * @property {never=} model + */ + /** * Create a new prediction * - * @param {object} options - * @param {string} options.model - The model. - * @param {string} options.version - The model version. - * @param {object} options.input - Required. An object with the model inputs - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false - * @returns {Promise} Resolves with the created prediction + * @param {BasePredictionOptions & (ModelPredictionOptions | VersionPredictionOptions)} options + * @returns {Promise} Resolves with the created prediction */ async function createPrediction(options) { const { model, version, stream, ...data } = options; @@ -43,8 +59,8 @@ async function createPrediction(options) { /** * Fetch a prediction by ID * - * @param {number} prediction_id - Required. The prediction ID - * @returns {Promise} Resolves with the prediction data + * @param {string} prediction_id - Required. The prediction ID + * @returns {Promise} Resolves with the prediction data */ async function getPrediction(prediction_id) { const response = await this.request(`/predictions/${prediction_id}`, { @@ -58,7 +74,7 @@ async function getPrediction(prediction_id) { * Cancel a prediction by ID * * @param {string} prediction_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training + * @returns {Promise} Resolves with the data for the training */ async function cancelPrediction(prediction_id) { const response = await this.request(`/predictions/${prediction_id}/cancel`, { @@ -71,7 +87,7 @@ async function cancelPrediction(prediction_id) { /** * List all predictions * - * @returns {Promise} - Resolves with a page of predictions + * @returns {Promise>} - Resolves with a page of predictions */ async function listPredictions() { const response = await this.request("/predictions", { diff --git a/lib/stream.js b/lib/stream.js index 012d6d0..ca1c38a 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -1,4 +1,5 @@ // Attempt to use readable-stream if available, attempt to use the built-in stream module. +/** @type {import("stream").Readable} */ let Readable; try { Readable = require("readable-stream").Readable; @@ -49,7 +50,7 @@ class Stream extends Readable { * Create a new stream of server-sent events. * * @param {string} url The URL to connect to. - * @param {object} options The fetch options. + * @param {RequestInit=} options The fetch options. */ constructor(url, options) { if (!Readable) { @@ -63,11 +64,18 @@ class Stream extends Readable { this.options = options; this.event = null; + + /** @type {unknown[]} */ this.data = []; + + /** @type {string | null} */ this.lastEventId = null; + + /** @type {number | null} */ this.retry = null; } + /** @param {string=} line */ decode(line) { if (!line) { if (!this.event && !this.data.length && !this.lastEventId) { diff --git a/lib/trainings.js b/lib/trainings.js index 6b13dca..e469b96 100644 --- a/lib/trainings.js +++ b/lib/trainings.js @@ -1,3 +1,9 @@ +/** + * @template T + * @typedef {import("./types").Page} Page + */ +/** @typedef {import("./types").Training} Training */ + /** * Create a new training * @@ -6,10 +12,10 @@ * @param {string} version_id - Required. The version ID * @param {object} options * @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}" - * @param {object} options.input - Required. An object with the model inputs + * @param {unknown} options.input - Required. An object with the model inputs * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the data for the created training + * @returns {Promise} Resolves with the data for the created training */ async function createTraining(model_owner, model_name, version_id, options) { const { ...data } = options; @@ -38,7 +44,7 @@ async function createTraining(model_owner, model_name, version_id, options) { * Fetch a training by ID * * @param {string} training_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training + * @returns {Promise} Resolves with the data for the training */ async function getTraining(training_id) { const response = await this.request(`/trainings/${training_id}`, { @@ -52,7 +58,7 @@ async function getTraining(training_id) { * Cancel a training by ID * * @param {string} training_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training + * @returns {Promise} Resolves with the data for the training */ async function cancelTraining(training_id) { const response = await this.request(`/trainings/${training_id}/cancel`, { @@ -65,7 +71,7 @@ async function cancelTraining(training_id) { /** * List all trainings * - * @returns {Promise} - Resolves with a page of trainings + * @returns {Promise>} - Resolves with a page of trainings */ async function listTrainings() { const response = await this.request("/trainings", { diff --git a/lib/types.js b/lib/types.js new file mode 100644 index 0000000..fd05845 --- /dev/null +++ b/lib/types.js @@ -0,0 +1,71 @@ +/** + * @typedef {"starting" | "processing" | "succeeded" | "failed" | "canceled"} Status + * @typedef {"public" | "private"} Visibility + * @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType + * + * @typedef {Object} Collection + * @property {string} name + * @property {string} slug + * @property {string} description + * @property {Model[]=} models + * + * @typedef {Object} Hardware + * @property {string} sku + * @property {string} name + * + * @typedef {Object} Model + * @property {string} url + * @property {string} owner + * @property {string} name + * @property {string=} description + * @property {Visibility} visibility + * @property {string=} github_url + * @property {string=} paper_url + * @property {string=} license_url + * @property {number} run_count + * @property {string=} cover_image_url + * @property {Prediction=} default_example + * @property {ModelVersion=} latest_version + * + * @typedef {Object} ModelVersion + * @property {string} id + * @property {string} created_at + * @property {string} cog_version + * @property {string} openapi_schema + * + * @typedef {Object} Prediction + * @property {string} id + * @property {Status} status + * @property {string=} model + * @property {string} version + * @property {object} input + * @property {unknown=} output + * @property {"api" | "web"} source + * @property {unknown=} error + * @property {string=} logs + * @property {{predict_time?: number}=} metrics + * @property {string=} webhook + * @property {WebhookEventType[]=} webhook_events_filter + * @property {string} created_at + * @property {string=} started_at + * @property {string=} completed_at + * @property {{get: string; cancel: string; stream?: string}} urls + * + * @typedef {Prediction} Training + * + * @typedef {Object} ServerSentEvent + * @property {string} event + * @property {string} data + * @property {string=} id + * @property {number=} retry + */ + +/** + * @template T + * @typedef {Object} Page + * @property {string=} previous + * @property {string=} next + * @property {T[]} results + */ + +module.exports = {}; diff --git a/package.json b/package.json index 61b4c87..4af50e1 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,7 @@ "license": "Apache-2.0", "main": "index.js", "type": "commonjs", - "types": "index.d.ts", + "types": "dist/types/index.d.ts", "files": [ "CONTRIBUTING.md", "LICENSE", @@ -25,12 +25,15 @@ "yarn": ">=1.7.0" }, "scripts": { + "build-types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js", "check": "tsc", "format": "biome format . --write", "lint-biome": "biome lint .", "lint-publint": "publint", "lint": "npm run lint-biome && npm run lint-publint", - "test": "jest" + "test": "jest", + "test-integration": "npm run build-types; for x in commonjs esm typescript; do npm --prefix integration/$x install --omit=dev && npm --prefix integration/$x test; done;", + "test-all": "npm run check; npm run test; npm run test-integration" }, "optionalDependencies": { "readable-stream": ">=4.0.0" diff --git a/tsconfig.json b/tsconfig.json index 7a564ee..c7961ea 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -2,9 +2,13 @@ "compilerOptions": { "esModuleInterop": true, "noEmit": true, - "strict": true + "strict": true, + "allowJs": true, }, + "types": ["node"], "exclude": [ + "dist", + "integration", "**/node_modules" ] }