Skip to content

Commit

Permalink
Generate TypeScript definitions from source
Browse files Browse the repository at this point in the history
  • Loading branch information
aron committed Jan 23, 2024
1 parent 6dd5b21 commit cac159f
Show file tree
Hide file tree
Showing 19 changed files with 323 additions and 284 deletions.
226 changes: 0 additions & 226 deletions index.d.ts

This file was deleted.

70 changes: 60 additions & 10 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,59 @@ class Replicate {
/**
* Create a new Replicate API client instance.
*
* @param {object} options - Configuration options for the client
* @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
* @param {string} options.userAgent - Identifier of your app
* @example
* // Create a new Replicate API client instance
* const Replicate = require("replicate");
* const replicate = new Replicate({
* // get your token from https://replicate.com/account
* auth: process.env.REPLICATE_API_TOKEN,
* userAgent: "my-app/1.2.3"
* });
*
* // Run a model and await the result:
* const model = 'owner/model:version-id'
* const input = {text: 'Hello, world!'}
* const output = await replicate.run(model, { input });
*
* @param {Object} [options] - Configuration options for the client
* @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
* @param {string} [options.userAgent] - Identifier of your app
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
*/
constructor(options = {}) {
/** @type {string} */
this.auth = options.auth || process.env.REPLICATE_API_TOKEN;

/** @type {string} */
this.userAgent =
options.userAgent || `replicate-javascript/${packageJSON.version}`;

/** @type {string} */
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";

/** @type {fetch} */
this.fetch = options.fetch || globalThis.fetch;

/** @type {collections} */
this.collections = {
list: collections.list.bind(this),
get: collections.get.bind(this),
};

/** @type {deployments} */
this.deployments = {
predictions: {
create: deployments.predictions.create.bind(this),
},
};

/** @type {hardware} */
this.hardware = {
list: hardware.list.bind(this),
};

/** @type {models} */
this.models = {
get: models.get.bind(this),
list: models.list.bind(this),
Expand All @@ -72,13 +97,15 @@ class Replicate {
},
};

/** @type {predictions} */
this.predictions = {
create: predictions.create.bind(this),
get: predictions.get.bind(this),
cancel: predictions.cancel.bind(this),
list: predictions.list.bind(this),
};

/** @type {trainings} */
this.trainings = {
create: trainings.create.bind(this),
get: trainings.get.bind(this),
Expand All @@ -90,18 +117,18 @@ class Replicate {
/**
* Run a model and wait for its output.
*
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {`${string}/${string}` | `${string}/${string}:${string}`} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {object} [options.wait] - Options for waiting for the prediction to finish
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
* @throws {Error} If the reference is invalid
* @throws {Error} If the prediction failed
* @returns {Promise<object>} - Resolves with the output of running the model
* @returns {Promise<Prediction>} - Resolves with the output of running the model
*/
async run(ref, options, progress) {
const { wait, ...data } = options;
Expand Down Expand Up @@ -237,7 +264,7 @@ class Replicate {
/**
* Stream a model and wait for its output.
*
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
Expand Down Expand Up @@ -285,8 +312,10 @@ class Replicate {
* for await (const page of replicate.paginate(replicate.predictions.list) {
* console.log(page);
* }
* @param {Function} endpoint - Function that returns a promise for the next page of results
* @yields {object[]} Each page of results
* @template T
* @param {() => Promise<Page<T>>} endpoint - Function that returns a promise for the next page of results
* @yields {T[]} Each page of results
* @returns {AsyncGenerator<T[], void, unknown>}
*/
async *paginate(endpoint) {
const response = await endpoint();
Expand All @@ -312,7 +341,7 @@ class Replicate {
* @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling.
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
* @throws {Error} If the prediction failed
* @returns {Promise<object>} Resolves with the completed prediction object
* @returns {Promise<Prediction>} Resolves with the completed prediction object
*/
async wait(prediction, options, stop) {
const { id } = prediction;
Expand Down Expand Up @@ -359,3 +388,24 @@ class Replicate {
}

module.exports = Replicate;

// - Type Definitions

/**
* @typedef {import("./lib/error")} ApiError
* @typedef {import("./lib/types").Collection} Collection
* @typedef {import("./lib/types").ModelVersion} ModelVersion
* @typedef {import("./lib/types").Hardware} Hardware
* @typedef {import("./lib/types").Model} Model
* @typedef {import("./lib/types").Prediction} Prediction
* @typedef {import("./lib/types").Training} Training
* @typedef {import("./lib/types").ServerSentEvent} ServerSentEvent
* @typedef {import("./lib/types").Status} Status
* @typedef {import("./lib/types").Visibility} Visibility
* @typedef {import("./lib/types").WebhookEventType} WebhookEventType
*/

/**
* @template T
* @typedef {import("./lib/types").Page<T>} Page
*/
3 changes: 1 addition & 2 deletions index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { expect, jest, test } from "@jest/globals";
import Replicate, { ApiError, Model, Prediction } from "replicate";
import Replicate, { ApiError, Model, Prediction } from "./";
import nock from "nock";
import fetch from "cross-fetch";

Expand Down Expand Up @@ -838,7 +838,6 @@ describe("Replicate client", () => {
});

test("Calls the correct API routes for a model", async () => {
const firstPollingRequest = true;

nock(BASE_URL)
.post("/models/replicate/hello-world/predictions")
Expand Down
Loading

0 comments on commit cac159f

Please sign in to comment.