diff --git a/README.md b/README.md index bd7e43d..6ec5694 100644 --- a/README.md +++ b/README.md @@ -20,20 +20,18 @@ npm install replicate ## Usage -Create the client: +Set your `REPLICATE_API_TOKEN` in your environment: -```js -import Replicate from "replicate"; - -const replicate = new Replicate({ - // get your token from https://replicate.com/account - auth: "my api token", // defaults to process.env.REPLICATE_API_TOKEN -}); +```sh +# get your token from https://replicate.com/account +export REPLICATE_API_TOKEN="r8_123..." ``` Run a model and await the result: ```js +import replicate from "replicate"; + const model = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"; const input = { prompt: "a 19th century portrait of a raccoon gentleman wearing a suit", @@ -94,8 +92,12 @@ const output = await replicate.run(model, { input }); ### Constructor +You can create a custom instance of the Replicate client using the `replicate.Replicate` constructor: + ```js -const replicate = new Replicate(options); +import replicate from "replicate"; + +const replicate = new replicate.Replicate(options); ``` | name | type | description | @@ -121,10 +123,10 @@ you can install a fetch function from an external package like and pass it to the `fetch` option in the constructor. ```js -import Replicate from "replicate"; +import replicate from "replicate"; import fetch from "cross-fetch"; -const replicate = new Replicate({ fetch }); +const replicate = new replicate.Replicate({ fetch }); ``` You can also use the `fetch` option to add custom behavior to client requests, @@ -778,4 +780,50 @@ You can call this method directly to make other requests to the API. ## TypeScript -The `Replicate` constructor and all `replicate.*` methods are fully typed. +The `Replicate` constructor and all `replicate.*` methods are fully typed. Types are accessible +via the named exports: + +```ts +import type { Model, Prediction } from "replicate"; +``` + +## Deprecated Constructor + +Earlier versions of the Replicate library exported the `Replicate` constructor as the default +export. This will be removed in a future version, to migrate please update your code to use +the following pattern: + +If you don't need to customize your Replicate client you can just remove the constructor code +entirely: + +```js +// Deprecated +import Replicate from "replicate"; + +const replicate = new Replicate(); + +replicate.run(...); + +// Fixed +import replicate from "replicate"; + +replicate.run(...); +``` + +If you need the Replicate construtor it's available on the `replicate` object. + +```js +// Deprecated +import Replicate from "replicate"; + +const replicate = new Replicate({auth: "my-token"}); + +replicate.run(...); + +// Fixed +import replicate from "replicate"; + +replicate = new replicate.Replicate({auth: "my-token"}); + +replicate.run(...); +``` diff --git a/index.d.ts b/index.d.ts index 5620f3b..4ea4868 100644 --- a/index.d.ts +++ b/index.d.ts @@ -1,7 +1,7 @@ declare module "replicate" { - type Status = "starting" | "processing" | "succeeded" | "failed" | "canceled"; - type Visibility = "public" | "private"; - type WebhookEventType = "start" | "output" | "logs" | "completed"; + export type Status = "starting" | "processing" | "succeeded" | "failed" | "canceled"; + export type Visibility = "public" | "private"; + export type WebhookEventType = "start" | "output" | "logs" | "completed"; export interface ApiError extends Error { request: Request; @@ -82,7 +82,7 @@ declare module "replicate" { retry?: number; } - export default class Replicate { + export class Replicate { constructor(options?: { auth?: string; userAgent?: string; @@ -223,4 +223,74 @@ declare module "replicate" { list(): Promise>; }; } + + /** @deprecated */ + class DeprecatedReplicate extends Replicate { + /** @deprecated Use `const Replicate = require("replicate").Replicate` instead */ + constructor(...args: ConstructorParameters); + } + + + /** + * Default instance of the Replicate class that gets the access token + * from the REPLICATE_API_TOKEN environment variable. + * + * Create a new Replicate API client instance. + * + * @example + * + * import replicate from "replicate"; + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + * + * @remarks + * + * NOTE: Use of this object as a constructor is deprecated and will + * be removed in a future version. Import the Replicate constructor + * instead: + * + * ``` + * const Replicate = require("replicate").Replicate; + * ``` + * + * Or using esm: + * + * ``` + * import replicate from "replicate"; + * const client = new replicate.Replicate({...}); + * ``` + * + * @type { Replicate & typeof DeprecatedReplicate & {ApiError: ApiError, Replicate: Replicate} } + */ + const replicate: Replicate & typeof DeprecatedReplicate & { + /** + * Create a new Replicate API client instance. + * + * @example + * // Create a new Replicate API client instance + * const Replicate = require("replicate"); + * const replicate = new Replicate({ + * // get your token from https://replicate.com/account + * auth: process.env.REPLICATE_API_TOKEN, + * userAgent: "my-app/1.2.3" + * }); + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + * + * @param {object} options - Configuration options for the client + * @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. + * @param {string} options.userAgent - Identifier of your app + * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 + * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` + */ + Replicate: typeof Replicate + }; + + export default replicate; } diff --git a/index.js b/index.js index ce407f9..08bd8d0 100644 --- a/index.js +++ b/index.js @@ -1,361 +1,72 @@ +const Replicate = require("./lib/replicate"); const ApiError = require("./lib/error"); -const ModelVersionIdentifier = require("./lib/identifier"); -const { Stream } = require("./lib/stream"); -const { withAutomaticRetries } = require("./lib/util"); -const collections = require("./lib/collections"); -const deployments = require("./lib/deployments"); -const hardware = require("./lib/hardware"); -const models = require("./lib/models"); -const predictions = require("./lib/predictions"); -const trainings = require("./lib/trainings"); +/** + * Placeholder class used to warn of deprecated constructor. + * @deprecated use exported Replicate class instead + */ +class DeprecatedReplicate extends Replicate { + /** @deprecated Use `import { Replicate } from "replicate";` instead */ + // biome-ignore lint/complexity/noUselessConstructor: exists for the tsdoc comment + constructor(...args) { + super(...args); + } +} -const packageJSON = require("./package.json"); +const named = { ApiError, Replicate }; +const singleton = new Replicate(); /** - * Replicate API client library + * Default instance of the Replicate class that gets the access token + * from the REPLICATE_API_TOKEN environment variable. + * + * Create a new Replicate API client instance. * - * @see https://replicate.com/docs/reference/http * @example - * // Create a new Replicate API client instance - * const Replicate = require("replicate"); - * const replicate = new Replicate({ - * // get your token from https://replicate.com/account - * auth: process.env.REPLICATE_API_TOKEN, - * userAgent: "my-app/1.2.3" - * }); + * + * import replicate from "replicate"; * * // Run a model and await the result: * const model = 'owner/model:version-id' * const input = {text: 'Hello, world!'} * const output = await replicate.run(model, { input }); + * + * @remarks + * + * NOTE: Use of this object as a constructor is deprecated and will + * be removed in a future version. Import the Replicate constructor + * instead: + * + * ``` + * const Replicate = require("replicate").Replicate; + * ``` + * + * Or in commonjs: + * + * ``` + * import { Replicate } from "replicate"; + * const client = new Replicate({...}); + * ``` + * + * @type { Replicate & typeof DeprecatedReplicate & {ApiError: ApiError, Replicate: Replicate} } */ -class Replicate { - /** - * Create a new Replicate API client instance. - * - * @param {object} options - Configuration options for the client - * @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. - * @param {string} options.userAgent - Identifier of your app - * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 - * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` - */ - constructor(options = {}) { - this.auth = options.auth || process.env.REPLICATE_API_TOKEN; - this.userAgent = - options.userAgent || `replicate-javascript/${packageJSON.version}`; - this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; - this.fetch = options.fetch || globalThis.fetch; - - this.collections = { - list: collections.list.bind(this), - get: collections.get.bind(this), - }; - - this.deployments = { - predictions: { - create: deployments.predictions.create.bind(this), - }, - }; - - this.hardware = { - list: hardware.list.bind(this), - }; - - this.models = { - get: models.get.bind(this), - list: models.list.bind(this), - create: models.create.bind(this), - versions: { - list: models.versions.list.bind(this), - get: models.versions.get.bind(this), - }, - }; - - this.predictions = { - create: predictions.create.bind(this), - get: predictions.get.bind(this), - cancel: predictions.cancel.bind(this), - list: predictions.list.bind(this), - }; - - this.trainings = { - create: trainings.create.bind(this), - get: trainings.get.bind(this), - cancel: trainings.cancel.bind(this), - list: trainings.list.bind(this), - }; - } - - /** - * Run a model and wait for its output. - * - * @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" - * @param {object} options - * @param {object} options.input - Required. An object with the model inputs - * @param {object} [options.wait] - Options for waiting for the prediction to finish - * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500 - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction - * @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed. - * @throws {Error} If the reference is invalid - * @throws {Error} If the prediction failed - * @returns {Promise} - Resolves with the output of running the model - */ - async run(ref, options, progress) { - const { wait, ...data } = options; - - const identifier = ModelVersionIdentifier.parse(ref); - - let prediction; - if (identifier.version) { - prediction = await this.predictions.create({ - ...data, - version: identifier.version, - }); - } else if (identifier.owner && identifier.name) { - prediction = await this.predictions.create({ - ...data, - model: `${identifier.owner}/${identifier.name}`, - }); - } else { - throw new Error("Invalid model version identifier"); - } - - // Call progress callback with the initial prediction object - if (progress) { - progress(prediction); - } - - const { signal } = options; - - prediction = await this.wait( - prediction, - wait || {}, - async (updatedPrediction) => { - // Call progress callback with the updated prediction object - if (progress) { - progress(updatedPrediction); - } - - if (signal && signal.aborted) { - await this.predictions.cancel(updatedPrediction.id); - return true; // stop polling - } - - return false; // continue polling - } - ); - - // Call progress callback with the completed prediction object - if (progress) { - progress(prediction); - } - - if (prediction.status === "failed") { - throw new Error(`Prediction failed: ${prediction.error}`); - } - - return prediction.output; +const replicate = new Proxy(DeprecatedReplicate, { + get(target, prop, receiver) { + // Should mostly behave like the singleton. + if (named[prop]) { + return named[prop]; + } + // Provide Replicate & ApiError constructors. + if (singleton[prop]) { + return singleton[prop]; + } + // Fallback to Replicate constructor properties. + return Reflect.get(target, prop, receiver); + }, + set(_target, prop, newValue, _receiver) { + singleton[prop] = newValue; + return true; } +}); - /** - * Make a request to the Replicate API. - * - * @param {string} route - REST API endpoint path - * @param {object} options - Request parameters - * @param {string} [options.method] - HTTP method. Defaults to GET - * @param {object} [options.params] - Query parameters - * @param {object|Headers} [options.headers] - HTTP headers - * @param {object} [options.data] - Body parameters - * @returns {Promise} - Resolves with the response object - * @throws {ApiError} If the request failed - */ - async request(route, options) { - const { auth, baseUrl, userAgent } = this; - - let url; - if (route instanceof URL) { - url = route; - } else { - url = new URL( - route.startsWith("/") ? route.slice(1) : route, - baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/` - ); - } - - const { method = "GET", params = {}, data } = options; - - for (const [key, value] of Object.entries(params)) { - url.searchParams.append(key, value); - } - - const headers = {}; - if (auth) { - headers["Authorization"] = `Token ${auth}`; - } - headers["Content-Type"] = "application/json"; - headers["User-Agent"] = userAgent; - if (options.headers) { - for (const [key, value] of Object.entries(options.headers)) { - headers[key] = value; - } - } - - const init = { - method, - headers, - body: data ? JSON.stringify(data) : undefined, - }; - - const shouldRetry = - method === "GET" - ? (response) => response.status === 429 || response.status >= 500 - : (response) => response.status === 429; - - // Workaround to fix `TypeError: Illegal invocation` error in Cloudflare Workers - // https://github.com/replicate/replicate-javascript/issues/134 - const _fetch = this.fetch; // eslint-disable-line no-underscore-dangle - const response = await withAutomaticRetries(async () => _fetch(url, init), { - shouldRetry, - }); - - if (!response.ok) { - const request = new Request(url, init); - const responseText = await response.text(); - throw new ApiError( - `Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`, - request, - response - ); - } - - return response; - } - - /** - * Stream a model and wait for its output. - * - * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" - * @param {object} options - * @param {object} options.input - Required. An object with the model inputs - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction - * @throws {Error} If the prediction failed - * @yields {ServerSentEvent} Each streamed event from the prediction - */ - async *stream(ref, options) { - const { wait, ...data } = options; - - const identifier = ModelVersionIdentifier.parse(ref); - - let prediction; - if (identifier.version) { - prediction = await this.predictions.create({ - ...data, - version: identifier.version, - stream: true, - }); - } else if (identifier.owner && identifier.name) { - prediction = await this.predictions.create({ - ...data, - model: `${identifier.owner}/${identifier.name}`, - stream: true, - }); - } else { - throw new Error("Invalid model version identifier"); - } - - if (prediction.urls && prediction.urls.stream) { - const { signal } = options; - const stream = new Stream(prediction.urls.stream, { signal }); - yield* stream; - } else { - throw new Error("Prediction does not support streaming"); - } - } - - /** - * Paginate through a list of results. - * - * @generator - * @example - * for await (const page of replicate.paginate(replicate.predictions.list) { - * console.log(page); - * } - * @param {Function} endpoint - Function that returns a promise for the next page of results - * @yields {object[]} Each page of results - */ - async *paginate(endpoint) { - const response = await endpoint(); - yield response.results; - if (response.next) { - const nextPage = () => - this.request(response.next, { method: "GET" }).then((r) => r.json()); - yield* this.paginate(nextPage); - } - } - - /** - * Wait for a prediction to finish. - * - * If the prediction has already finished, - * this function returns immediately. - * Otherwise, it polls the API until the prediction finishes. - * - * @async - * @param {object} prediction - Prediction object - * @param {object} options - Options - * @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 500 - * @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling. - * @throws {Error} If the prediction doesn't complete within the maximum number of attempts - * @throws {Error} If the prediction failed - * @returns {Promise} Resolves with the completed prediction object - */ - async wait(prediction, options, stop) { - const { id } = prediction; - if (!id) { - throw new Error("Invalid prediction"); - } - - if ( - prediction.status === "succeeded" || - prediction.status === "failed" || - prediction.status === "canceled" - ) { - return prediction; - } - - // eslint-disable-next-line no-promise-executor-return - const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); - - const interval = (options && options.interval) || 500; - - let updatedPrediction = await this.predictions.get(id); - - while ( - updatedPrediction.status !== "succeeded" && - updatedPrediction.status !== "failed" && - updatedPrediction.status !== "canceled" - ) { - /* eslint-disable no-await-in-loop */ - if (stop && (await stop(updatedPrediction)) === true) { - break; - } - - await sleep(interval); - updatedPrediction = await this.predictions.get(prediction.id); - /* eslint-enable no-await-in-loop */ - } - - if (updatedPrediction.status === "failed") { - throw new Error(`Prediction failed: ${updatedPrediction.error}`); - } - - return updatedPrediction; - } -} - -module.exports = Replicate; +module.exports = replicate; diff --git a/index.test.ts b/index.test.ts index 5b5a1dd..b6c38a8 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,41 +1,42 @@ import { expect, jest, test } from "@jest/globals"; -import Replicate, { ApiError, Model, Prediction } from "replicate"; +import replicate, { ApiError, Model, Prediction, Replicate } from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; +import assert from "node:assert"; + +assert(process.env.REPLICATE_API_TOKEN === "test-token", `set REPLICATE_API_TOKEN to "test-token"`) -let client: Replicate; const BASE_URL = "https://api.replicate.com/v1"; nock.disableNetConnect(); -describe("Replicate client", () => { - let unmatched: any[] = []; - const handleNoMatch = (req: unknown, options: any, body: string) => - unmatched.push({ req, options, body }); - - beforeEach(() => { - client = new Replicate({ auth: "test-token" }); - client.fetch = fetch; - - unmatched = []; - nock.emitter.on("no match", handleNoMatch); - }); +describe(`const replicate = require("replicate");`, () => { + testInstance(() => { + replicate.fetch = fetch; + return replicate; + }) +}); - afterEach(() => { - nock.emitter.off("no match", handleNoMatch); - expect(unmatched).toStrictEqual([]); +describe(`const Replicate = require("replicate"); (deprecated)`, () => { + testConstructor((opts) => new replicate({ auth: "test-token", fetch, ...opts })) + testInstance((opts) => new replicate({ auth: "test-token", fetch, ...opts })) +}); - nock.abortPendingRequests(); - nock.cleanAll(); - }); +describe(`const Replicate = require("replicate").Replicate;`, () => { + testConstructor((opts) => new replicate.Replicate({ auth: "test-token", fetch, ...opts })) + testInstance((opts) => new replicate.Replicate({ auth: "test-token", fetch, ...opts })) +}); +/** Test suite to exercise the Replicate constructor */ +function testConstructor(createClient: (opts?: object) => Replicate) { describe("constructor", () => { test("Sets default baseUrl", () => { + const client = createClient(); expect(client.baseUrl).toBe("https://api.replicate.com/v1"); }); test("Sets custom baseUrl", () => { - const clientWithCustomBaseUrl = new Replicate({ + const clientWithCustomBaseUrl = createClient({ baseUrl: "https://example.com/", auth: "test-token", }); @@ -43,7 +44,7 @@ describe("Replicate client", () => { }); test("Sets custom userAgent", () => { - const clientWithCustomUserAgent = new Replicate({ + const clientWithCustomUserAgent = createClient({ userAgent: "my-app/1.2.3", auth: "test-token", }); @@ -54,7 +55,7 @@ describe("Replicate client", () => { process.env.REPLICATE_API_TOKEN = "test-token"; expect(() => { - const clientWithImplicitAuth = new Replicate(); + const clientWithImplicitAuth = createClient(); expect(clientWithImplicitAuth.auth).toBe("test-token"); }).not.toThrow(); @@ -62,11 +63,32 @@ describe("Replicate client", () => { test("Does not throw error if blank auth token is provided", () => { expect(() => { - new Replicate({ auth: "" }); + createClient({ auth: "" }); }).not.toThrow(); }); }); + } + +/** Test suite to exercise the Replicate instance */ +function testInstance(createClient: (opts?: object) => Replicate) { + let unmatched: any[] = []; + const handleNoMatch = (req: unknown, options: any, body: string) => + unmatched.push({ req, options, body }); + + beforeEach(() => { + unmatched = []; + nock.emitter.on("no match", handleNoMatch); + }); + + afterEach(() => { + nock.emitter.off("no match", handleNoMatch); + expect(unmatched).toStrictEqual([]); + + nock.abortPendingRequests(); + nock.cleanAll(); + }); + describe("collections.list", () => { test("Calls the correct API route", async () => { nock(BASE_URL) @@ -89,6 +111,7 @@ describe("Replicate client", () => { previous: null, }); + const client = createClient(); const collections = await client.collections.list(); expect(collections.results.length).toBe(2); }); @@ -105,6 +128,7 @@ describe("Replicate client", () => { models: [], }); + const client = createClient(); const collection = await client.collections.get("super-resolution"); expect(collection.name).toBe("Super resolution"); }); @@ -128,6 +152,7 @@ describe("Replicate client", () => { latest_version: {}, }); + const client = createClient(); await client.models.get("replicate", "hello-world"); }); // Add more tests for error handling, edge cases, etc. @@ -150,6 +175,7 @@ describe("Replicate client", () => { }); const results: Model[] = []; + const client = createClient(); for await (const batch of client.paginate(client.models.list)) { results.push(...batch); } @@ -188,6 +214,7 @@ describe("Replicate client", () => { logs: null, metrics: {}, }); + const client = createClient(); const prediction = await client.predictions.create({ version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", @@ -208,6 +235,7 @@ describe("Replicate client", () => { return body; }); + const client = createClient(); await client.predictions.create({ version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", @@ -220,6 +248,7 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { + const client = createClient(); await client.predictions.create({ version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", @@ -244,6 +273,7 @@ describe("Replicate client", () => { try { expect.hasAssertions(); + const client = createClient(); await client.predictions.create({ version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", @@ -271,6 +301,7 @@ describe("Replicate client", () => { .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", }); + const client = createClient(); const prediction = await client.predictions.create({ version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", @@ -290,6 +321,7 @@ describe("Replicate client", () => { { "Content-Type": "application/json" } ); + const client = createClient(); await expect( client.predictions.create({ version: @@ -335,6 +367,7 @@ describe("Replicate client", () => { predict_time: 4.484541, }, }); + const client = createClient(); const prediction = await client.predictions.get( "rrr4z55ocneqzikepnug6xezpe" ); @@ -356,6 +389,7 @@ describe("Replicate client", () => { id: "rrr4z55ocneqzikepnug6xezpe", }); + const client = createClient(); const prediction = await client.predictions.get( "rrr4z55ocneqzikepnug6xezpe" ); @@ -377,6 +411,7 @@ describe("Replicate client", () => { id: "rrr4z55ocneqzikepnug6xezpe", }); + const client = createClient(); const prediction = await client.predictions.get( "rrr4z55ocneqzikepnug6xezpe" ); @@ -411,6 +446,7 @@ describe("Replicate client", () => { metrics: {}, }); + const client = createClient(); const prediction = await client.predictions.cancel( "ufawqhfynnddngldkgtslldrkq" ); @@ -447,6 +483,7 @@ describe("Replicate client", () => { ], }); + const client = createClient(); const predictions = await client.predictions.list(); expect(predictions.results.length).toBe(1); expect(predictions.results[0].id).toBe("jpzd7hm5gfcapbfyt4mqytarku"); @@ -468,6 +505,7 @@ describe("Replicate client", () => { }); const results: Prediction[] = []; + const client = createClient(); for await (const batch of client.paginate(client.predictions.list)) { results.push(...batch); } @@ -502,6 +540,7 @@ describe("Replicate client", () => { completed_at: null, }); + const client = createClient(); const training = await client.trainings.create( "owner", "model", @@ -517,6 +556,7 @@ describe("Replicate client", () => { }); test("Throws an error if webhook is not a valid URL", async () => { + const client = createClient(); await expect( client.trainings.create( "owner", @@ -559,6 +599,7 @@ describe("Replicate client", () => { completed_at: null, }); + const client = createClient(); const training = await client.trainings.get("zz4ibbonubfz7carwiefibzgga"); expect(training.status).toBe("succeeded"); }); @@ -589,6 +630,7 @@ describe("Replicate client", () => { completed_at: null, }); + const client = createClient(); const training = await client.trainings.cancel( "zz4ibbonubfz7carwiefibzgga" ); @@ -625,6 +667,7 @@ describe("Replicate client", () => { ], }); + const client = createClient(); const trainings = await client.trainings.list(); expect(trainings.results.length).toBe(1); expect(trainings.results[0].id).toBe("jpzd7hm5gfcapbfyt4mqytarku"); @@ -646,6 +689,7 @@ describe("Replicate client", () => { }); const results: Prediction[] = []; + const client = createClient(); for await (const batch of client.paginate(client.trainings.list)) { results.push(...batch); } @@ -684,6 +728,7 @@ describe("Replicate client", () => { logs: null, metrics: {}, }); + const client = createClient(); const prediction = await client.deployments.predictions.create( "replicate", "greeter", @@ -721,6 +766,7 @@ describe("Replicate client", () => { get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci", }, }); + const client = createClient(); const prediction = await client.predictions.create({ model: "meta/llama-2-70b-chat", input: { @@ -745,6 +791,7 @@ describe("Replicate client", () => { { name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" }, ]); + const client = createClient(); const hardware = await client.hardware.list(); expect(hardware.length).toBe(4); expect(hardware[0].name).toBe("CPU"); @@ -763,6 +810,7 @@ describe("Replicate client", () => { description: "A test model", }); + const client = createClient(); const model = await client.models.create("test-owner", "test-model", { visibility: "public", hardware: "cpu", @@ -779,8 +827,6 @@ describe("Replicate client", () => { describe("run", () => { test("Calls the correct API routes for a version", async () => { - const firstPollingRequest = true; - nock(BASE_URL) .post("/predictions") .reply(201, { @@ -802,6 +848,7 @@ describe("Replicate client", () => { const progress = jest.fn(); + const client = createClient(); const output = await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { @@ -861,6 +908,7 @@ describe("Replicate client", () => { const progress = jest.fn(); + const client = createClient(); const output = await client.run( "replicate/hello-world", { @@ -910,12 +958,14 @@ describe("Replicate client", () => { output: "foobar", }); + const client = createClient(); await expect( client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } }) ).resolves.not.toThrow(); }); test("Throws an error for invalid identifiers", async () => { + const client = createClient(); const options = { input: { text: "Hello, world!" } }; // @ts-expect-error @@ -928,6 +978,7 @@ describe("Replicate client", () => { }); test("Throws an error if webhook URL is invalid", async () => { + const client = createClient(); await expect(async () => { await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", @@ -966,6 +1017,7 @@ describe("Replicate client", () => { status: "canceled", }); + const client = createClient(); await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { @@ -981,4 +1033,5 @@ describe("Replicate client", () => { }); // Continue with tests for other methods -}); +} + diff --git a/integration/commonjs/constructor.test.js b/integration/commonjs/constructor.test.js new file mode 100644 index 0000000..79048b0 --- /dev/null +++ b/integration/commonjs/constructor.test.js @@ -0,0 +1,21 @@ +const { test } = require('node:test'); +const assert = require('node:assert'); +const Replicate = require('replicate').Replicate; + +const replicate = new Replicate(); + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Claire CommonJS" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Claire CommonJS"); +}); diff --git a/integration/commonjs/deprecated.test.js b/integration/commonjs/deprecated.test.js new file mode 100644 index 0000000..49b5579 --- /dev/null +++ b/integration/commonjs/deprecated.test.js @@ -0,0 +1,21 @@ +const { test } = require('node:test'); +const assert = require('node:assert'); +const Replicate = require('replicate'); + +const replicate = new Replicate(); + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Claire CommonJS" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Claire CommonJS"); +}); diff --git a/integration/commonjs/index.test.js b/integration/commonjs/index.test.js deleted file mode 100644 index 5ef7b63..0000000 --- a/integration/commonjs/index.test.js +++ /dev/null @@ -1,8 +0,0 @@ -const { test } = require('node:test'); -const assert = require('node:assert'); -const main = require('./index'); - -test('main', async () => { - const output = await main(); - assert.equal(output, "hello Claire CommonJS"); -}); diff --git a/integration/commonjs/package.json b/integration/commonjs/package.json index 7fb6fc8..d6e6e47 100644 --- a/integration/commonjs/package.json +++ b/integration/commonjs/package.json @@ -5,7 +5,7 @@ "description": "CommonJS integration tests", "main": "index.js", "scripts": { - "test": "node --test ./index.test.js" + "test": "node --test ./*.test.js" }, "dependencies": { "replicate": "file:../../" diff --git a/integration/commonjs/singleton.test.js b/integration/commonjs/singleton.test.js new file mode 100644 index 0000000..1211080 --- /dev/null +++ b/integration/commonjs/singleton.test.js @@ -0,0 +1,19 @@ +const { test } = require('node:test'); +const assert = require('node:assert'); +const replicate = require('replicate'); + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Claire CommonJS" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Claire CommonJS"); +}); diff --git a/integration/esm/constructor.test.js b/integration/esm/constructor.test.js new file mode 100644 index 0000000..206342f --- /dev/null +++ b/integration/esm/constructor.test.js @@ -0,0 +1,21 @@ +import { test } from 'node:test'; +import assert from 'node:assert'; +import replicate_ from "replicate"; + +const replicate = new replicate_.Replicate(); + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Evelyn ESM" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Evelyn ESM"); +}); diff --git a/integration/esm/index.js b/integration/esm/deprecated.test.js similarity index 50% rename from integration/esm/index.js rename to integration/esm/deprecated.test.js index 547b726..0a5c7ee 100644 --- a/integration/esm/index.js +++ b/integration/esm/deprecated.test.js @@ -1,10 +1,10 @@ +import { test } from 'node:test'; +import assert from 'node:assert'; import Replicate from "replicate"; -const replicate = new Replicate({ - auth: process.env.REPLICATE_API_TOKEN, -}); +const replicate = new Replicate(); -export default async function main() { +async function main() { return await replicate.run( "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { @@ -14,3 +14,8 @@ export default async function main() { } ); }; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Evelyn ESM"); +}); diff --git a/integration/esm/index.test.js b/integration/esm/index.test.js deleted file mode 100644 index 2bd276f..0000000 --- a/integration/esm/index.test.js +++ /dev/null @@ -1,8 +0,0 @@ -import { test } from 'node:test'; -import assert from 'node:assert'; -import main from './index.js'; - -test('main', async () => { - const output = await main(); - assert.equal(output, "hello Evelyn ESM"); -}); diff --git a/integration/esm/package.json b/integration/esm/package.json index 51076d7..ca093ac 100644 --- a/integration/esm/package.json +++ b/integration/esm/package.json @@ -6,7 +6,7 @@ "main": "index.js", "type": "module", "scripts": { - "test": "node --test ./index.test.js" + "test": "node --test ./*.test.js" }, "dependencies": { "replicate": "file:../../" diff --git a/integration/esm/singleton.test.js b/integration/esm/singleton.test.js new file mode 100644 index 0000000..59860ad --- /dev/null +++ b/integration/esm/singleton.test.js @@ -0,0 +1,19 @@ +import { test } from 'node:test'; +import assert from 'node:assert'; +import replicate from "replicate"; + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Evelyn ESM" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Evelyn ESM"); +}); diff --git a/integration/typescript/constructor.test.ts b/integration/typescript/constructor.test.ts new file mode 100644 index 0000000..b8f2b4d --- /dev/null +++ b/integration/typescript/constructor.test.ts @@ -0,0 +1,21 @@ +import { test } from 'node:test'; +import assert from 'node:assert'; +import replicate_ from "replicate"; + +const replicate = new replicate_.Replicate(); + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Tracy TypeScript" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Tracy TypeScript"); +}); diff --git a/integration/typescript/index.ts b/integration/typescript/deprecated.test.ts similarity index 50% rename from integration/typescript/index.ts rename to integration/typescript/deprecated.test.ts index 8e27a3b..c3f85b1 100644 --- a/integration/typescript/index.ts +++ b/integration/typescript/deprecated.test.ts @@ -1,10 +1,10 @@ +import { test } from 'node:test'; +import assert from 'node:assert'; import Replicate from "replicate"; -const replicate = new Replicate({ - auth: process.env.REPLICATE_API_TOKEN, -}); +const replicate = new Replicate(); -export default async function main() { +async function main() { return await replicate.run( "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { @@ -14,3 +14,8 @@ export default async function main() { } ); }; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Tracy TypeScript"); +}); diff --git a/integration/typescript/index.test.ts b/integration/typescript/index.test.ts deleted file mode 100644 index be4ab90..0000000 --- a/integration/typescript/index.test.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { test } from 'node:test'; -import assert from 'node:assert'; -import main from './index.js'; - -// Verify exported types. -import type { - Status, - Visibility, - WebhookEventType, - ApiError, - Collection, - Hardware, - Model, - ModelVersion, - Prediction, - Training, - Page, - ServerSentEvent, -} from "replicate"; - -test('main', async () => { - const output = await main(); - assert.equal(output, "hello Tracy TypeScript"); -}); diff --git a/integration/typescript/package.json b/integration/typescript/package.json index 4adae99..8dc9a5c 100644 --- a/integration/typescript/package.json +++ b/integration/typescript/package.json @@ -6,7 +6,7 @@ "main": "index.js", "type": "module", "scripts": { - "test": "tsc && node --test ./dist/index.test.js" + "test": "tsc && node --test ./dist/*.test.js" }, "dependencies": { "@types/node": "^20.11.0", diff --git a/integration/typescript/singleton.test.ts b/integration/typescript/singleton.test.ts new file mode 100644 index 0000000..c34f8fb --- /dev/null +++ b/integration/typescript/singleton.test.ts @@ -0,0 +1,19 @@ +import { test } from 'node:test'; +import assert from 'node:assert'; +import replicate from "replicate"; + +async function main() { + return await replicate.run( + "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Tracy TypeScript" + } + } + ); +}; + +test('main', async () => { + const output = await main(); + assert.equal(output, "hello Tracy TypeScript"); +}); diff --git a/lib/replicate.js b/lib/replicate.js new file mode 100644 index 0000000..1fac1e4 --- /dev/null +++ b/lib/replicate.js @@ -0,0 +1,374 @@ +const ApiError = require("./error"); +const ModelVersionIdentifier = require("./identifier"); +const { Stream } = require("./stream"); +const { withAutomaticRetries } = require("./util"); + +const collections = require("./collections"); +const deployments = require("./deployments"); +const hardware = require("./hardware"); +const models = require("./models"); +const predictions = require("./predictions"); +const trainings = require("./trainings"); + +const packageJSON = require("../package.json"); + +/** + * Replicate API client library + * + * @see https://replicate.com/docs/reference/http + * @example + * + * // Create a new Replicate API client instance + * const Replicate = require("replicate").Replicate; + * const replicate = new Replicate({ + * // get your token from https://replicate.com/account + * auth: process.env.REPLICATE_API_TOKEN, + * userAgent: "my-app/1.2.3" + * }); + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + */ +module.exports = class Replicate { + /** + * Create a new Replicate API client instance. + * + * @example + * // Create a new Replicate API client instance + * const Replicate = require("replicate"); + * const replicate = new Replicate({ + * // get your token from https://replicate.com/account + * auth: process.env.REPLICATE_API_TOKEN, + * userAgent: "my-app/1.2.3" + * }); + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + * + * @param {object} options - Configuration options for the client + * @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. + * @param {string} options.userAgent - Identifier of your app + * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 + * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` + */ + constructor(options = {}) { + this.auth = options.auth || process.env.REPLICATE_API_TOKEN; + this.userAgent = + options.userAgent || `replicate-javascript/${packageJSON.version}`; + this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; + this.fetch = options.fetch || globalThis.fetch; + + this.collections = { + list: collections.list.bind(this), + get: collections.get.bind(this), + }; + + this.deployments = { + predictions: { + create: deployments.predictions.create.bind(this), + }, + }; + + this.hardware = { + list: hardware.list.bind(this), + }; + + this.models = { + get: models.get.bind(this), + list: models.list.bind(this), + create: models.create.bind(this), + versions: { + list: models.versions.list.bind(this), + get: models.versions.get.bind(this), + }, + }; + + this.predictions = { + create: predictions.create.bind(this), + get: predictions.get.bind(this), + cancel: predictions.cancel.bind(this), + list: predictions.list.bind(this), + }; + + this.trainings = { + create: trainings.create.bind(this), + get: trainings.get.bind(this), + cancel: trainings.cancel.bind(this), + list: trainings.list.bind(this), + }; + } + + /** + * Run a model and wait for its output. + * + * @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" + * @param {object} options + * @param {object} options.input - Required. An object with the model inputs + * @param {object} [options.wait] - Options for waiting for the prediction to finish + * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500 + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction + * @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed. + * @throws {Error} If the reference is invalid + * @throws {Error} If the prediction failed + * @returns {Promise} - Resolves with the output of running the model + */ + async run(ref, options, progress) { + const { wait, ...data } = options; + + const identifier = ModelVersionIdentifier.parse(ref); + + let prediction; + if (identifier.version) { + prediction = await this.predictions.create({ + ...data, + version: identifier.version, + }); + } else if (identifier.owner && identifier.name) { + prediction = await this.predictions.create({ + ...data, + model: `${identifier.owner}/${identifier.name}`, + }); + } else { + throw new Error("Invalid model version identifier"); + } + + // Call progress callback with the initial prediction object + if (progress) { + progress(prediction); + } + + const { signal } = options; + + prediction = await this.wait( + prediction, + wait || {}, + async (updatedPrediction) => { + // Call progress callback with the updated prediction object + if (progress) { + progress(updatedPrediction); + } + + if (signal && signal.aborted) { + await this.predictions.cancel(updatedPrediction.id); + return true; // stop polling + } + + return false; // continue polling + } + ); + + // Call progress callback with the completed prediction object + if (progress) { + progress(prediction); + } + + if (prediction.status === "failed") { + throw new Error(`Prediction failed: ${prediction.error}`); + } + + return prediction.output; + } + + /** + * Make a request to the Replicate API. + * + * @param {string} route - REST API endpoint path + * @param {object} options - Request parameters + * @param {string} [options.method] - HTTP method. Defaults to GET + * @param {object} [options.params] - Query parameters + * @param {object|Headers} [options.headers] - HTTP headers + * @param {object} [options.data] - Body parameters + * @returns {Promise} - Resolves with the response object + * @throws {ApiError} If the request failed + */ + async request(route, options) { + const { auth, baseUrl, userAgent } = this; + + let url; + if (route instanceof URL) { + url = route; + } else { + url = new URL( + route.startsWith("/") ? route.slice(1) : route, + baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/` + ); + } + + const { method = "GET", params = {}, data } = options; + + for (const [key, value] of Object.entries(params)) { + url.searchParams.append(key, value); + } + + const headers = {}; + if (auth) { + headers["Authorization"] = `Token ${auth}`; + } + headers["Content-Type"] = "application/json"; + headers["User-Agent"] = userAgent; + if (options.headers) { + for (const [key, value] of Object.entries(options.headers)) { + headers[key] = value; + } + } + + const init = { + method, + headers, + body: data ? JSON.stringify(data) : undefined, + }; + + const shouldRetry = + method === "GET" + ? (response) => response.status === 429 || response.status >= 500 + : (response) => response.status === 429; + + // Workaround to fix `TypeError: Illegal invocation` error in Cloudflare Workers + // https://github.com/replicate/replicate-javascript/issues/134 + const _fetch = this.fetch; // eslint-disable-line no-underscore-dangle + const response = await withAutomaticRetries(async () => _fetch(url, init), { + shouldRetry, + }); + + if (!response.ok) { + const request = new Request(url, init); + const responseText = await response.text(); + throw new ApiError( + `Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`, + request, + response + ); + } + + return response; + } + + /** + * Stream a model and wait for its output. + * + * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" + * @param {object} options + * @param {object} options.input - Required. An object with the model inputs + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction + * @throws {Error} If the prediction failed + * @yields {ServerSentEvent} Each streamed event from the prediction + */ + async *stream(ref, options) { + const { wait, ...data } = options; + + const identifier = ModelVersionIdentifier.parse(ref); + + let prediction; + if (identifier.version) { + prediction = await this.predictions.create({ + ...data, + version: identifier.version, + stream: true, + }); + } else if (identifier.owner && identifier.name) { + prediction = await this.predictions.create({ + ...data, + model: `${identifier.owner}/${identifier.name}`, + stream: true, + }); + } else { + throw new Error("Invalid model version identifier"); + } + + if (prediction.urls && prediction.urls.stream) { + const { signal } = options; + const stream = new Stream(prediction.urls.stream, { signal }); + yield* stream; + } else { + throw new Error("Prediction does not support streaming"); + } + } + + /** + * Paginate through a list of results. + * + * @generator + * @example + * for await (const page of replicate.paginate(replicate.predictions.list) { + * console.log(page); + * } + * @param {Function} endpoint - Function that returns a promise for the next page of results + * @yields {object[]} Each page of results + */ + async *paginate(endpoint) { + const response = await endpoint(); + yield response.results; + if (response.next) { + const nextPage = () => + this.request(response.next, { method: "GET" }).then((r) => r.json()); + yield* this.paginate(nextPage); + } + } + + /** + * Wait for a prediction to finish. + * + * If the prediction has already finished, + * this function returns immediately. + * Otherwise, it polls the API until the prediction finishes. + * + * @async + * @param {object} prediction - Prediction object + * @param {object} options - Options + * @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 500 + * @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling. + * @throws {Error} If the prediction doesn't complete within the maximum number of attempts + * @throws {Error} If the prediction failed + * @returns {Promise} Resolves with the completed prediction object + */ + async wait(prediction, options, stop) { + const { id } = prediction; + if (!id) { + throw new Error("Invalid prediction"); + } + + if ( + prediction.status === "succeeded" || + prediction.status === "failed" || + prediction.status === "canceled" + ) { + return prediction; + } + + // eslint-disable-next-line no-promise-executor-return + const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); + + const interval = (options && options.interval) || 500; + + let updatedPrediction = await this.predictions.get(id); + + while ( + updatedPrediction.status !== "succeeded" && + updatedPrediction.status !== "failed" && + updatedPrediction.status !== "canceled" + ) { + /* eslint-disable no-await-in-loop */ + if (stop && (await stop(updatedPrediction)) === true) { + break; + } + + await sleep(interval); + updatedPrediction = await this.predictions.get(prediction.id); + /* eslint-enable no-await-in-loop */ + } + + if (updatedPrediction.status === "failed") { + throw new Error(`Prediction failed: ${updatedPrediction.error}`); + } + + return updatedPrediction; + } +}; diff --git a/package.json b/package.json index 24a088b..a0219de 100644 --- a/package.json +++ b/package.json @@ -17,7 +17,7 @@ "check": "tsc", "format": "biome format . --write", "lint": "biome lint .", - "test": "jest" + "test": "REPLICATE_API_TOKEN=test-token jest" }, "optionalDependencies": { "readable-stream": ">=4.0.0"