Skip to content

Commit

Permalink
Tighten up the typings across the board
Browse files Browse the repository at this point in the history
  • Loading branch information
aron committed Jan 16, 2024
1 parent 8d080f8 commit 1195d8a
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 49 deletions.
31 changes: 17 additions & 14 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ require("./lib/types");
* @deprecated use exported Replicate class instead
*/
class DeprecatedReplicate extends ReplicateClass {
/** @deprecated Use `import { Replicate } from "replicate";` instead */
/**
* @deprecated Use `import { Replicate } from "replicate";` instead
* @param {ConstructorParameters<typeof ReplicateClass>[0]=} options
*/
// biome-ignore lint/complexity/noUselessConstructor: exists for the tsdoc comment
constructor(...args) {
super(...args);
constructor(options) {
super(options);
}
}

Expand Down Expand Up @@ -77,20 +80,20 @@ module.exports = replicate;
/**
* @typedef {import("./lib/replicate")} Replicate
* @typedef {import("./lib/error")} ApiError
* @typedef {typeof import("./lib/types").Collection} Collection
* @typedef {typeof import("./lib/types").ModelVersion} ModelVersion
* @typedef {typeof import("./lib/types").Hardware} Hardware
* @typedef {typeof import("./lib/types").Model} Model
* @typedef {typeof import("./lib/types").Prediction} Prediction
* @typedef {typeof import("./lib/types").Training} Training
* @typedef {typeof import("./lib/types").ServerSentEvent} ServerSentEvent
* @typedef {typeof import("./lib/types").Status} Status
* @typedef {typeof import("./lib/types").Visibility} Visibility
* @typedef {typeof import("./lib/types").WebhookEventType} WebhookEventType
* @typedef {import("./lib/types").Collection} Collection
* @typedef {import("./lib/types").ModelVersion} ModelVersion
* @typedef {import("./lib/types").Hardware} Hardware
* @typedef {import("./lib/types").Model} Model
* @typedef {import("./lib/types").Prediction} Prediction
* @typedef {import("./lib/types").Training} Training
* @typedef {import("./lib/types").ServerSentEvent} ServerSentEvent
* @typedef {import("./lib/types").Status} Status
* @typedef {import("./lib/types").Visibility} Visibility
* @typedef {import("./lib/types").WebhookEventType} WebhookEventType
*/

/**
* @template T
* @typedef {typeof import("./lib/types").Page} Page
* @typedef {import("./lib/types").Page<T>} Page
*/

8 changes: 2 additions & 6 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {
});
const client = createClient();
const prediction = await client.predictions.create({
model: "foo",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
Expand All @@ -238,7 +237,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {

const client = createClient();
await client.predictions.create({
model: "foo",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
Expand All @@ -252,7 +250,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {
await expect(async () => {
const client = createClient();
await client.predictions.create({
model: "foo",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
Expand All @@ -278,7 +275,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {

const client = createClient();
await client.predictions.create({
model: "foo",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
Expand Down Expand Up @@ -307,7 +303,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {
});
const client = createClient();
const prediction = await client.predictions.create({
model: "foo",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
Expand All @@ -329,7 +324,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {
const client = createClient();
await expect(
client.predictions.create({
model: "foo",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
Expand Down Expand Up @@ -973,10 +967,12 @@ function testInstance(createClient: (opts?: object) => Replicate) {
const client = createClient();
const options = { input: { text: "Hello, world!" } };

// @ts-expect-error
await expect(client.run("owner:abc123", options)).rejects.toThrow();

await expect(client.run("/model:abc123", options)).rejects.toThrow();

// @ts-expect-error
await expect(client.run(":abc123", options)).rejects.toThrow();
});

Expand Down
2 changes: 1 addition & 1 deletion integration/typescript/types.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ApiError, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate";
import { ApiError, Collection, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate";

// NOTE: We export the constants to avoid unused varaible issues.

Expand Down
6 changes: 6 additions & 0 deletions lib/collections.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/** @typedef {import("./types").Collection} Collection */
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/

/**
* Fetch a model collection
*
Expand Down
4 changes: 3 additions & 1 deletion lib/deployments.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
/** @typedef {import("./types").Prediction} Prediction */

/**
* Create a new prediction with a deployment
*
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {unknown} options.input - Required. An object with the model inputs
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
Expand Down
1 change: 1 addition & 0 deletions lib/hardware.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/** @typedef {import("./types").Hardware} Hardware */
/**
* List hardware
*
Expand Down
12 changes: 6 additions & 6 deletions lib/identifier.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
* A reference to a model version in the format `owner/name` or `owner/name:version`.
*/
class ModelVersionIdentifier {
/*
* @param {string} Required. The model owner.
* @param {string} Required. The model name.
* @param {string} The model version.
/**
* @param {string} owner Required. The model owner.
* @param {string} name Required. The model name.
* @param {string | null=} version The model version.
*/
constructor(owner, name, version = null) {
this.owner = owner;
this.name = name;
this.version = version;
}

/*
/**
* Parse a reference to a model version
*
* @param {string}
* @param {string} ref
* @returns {ModelVersionIdentifier}
* @throws {Error} If the reference is invalid.
*/
Expand Down
11 changes: 10 additions & 1 deletion lib/models.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
/** @typedef {import("./types").Model} Model */
/** @typedef {import("./types").ModelVersion} ModelVersion */
/** @typedef {import("./types").Prediction} Prediction */
/** @typedef {import("./types").Visibility} Visibility */
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/

/**
* Get information about a model
*
Expand Down Expand Up @@ -69,7 +78,7 @@ async function listModels() {
* @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
* @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization.
* @param {object} options
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
* @param {Visibility} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
* @param {string} options.description - A description of the model.
* @param {string=} options.github_url - A URL for the model's source code on GitHub.
Expand Down
30 changes: 23 additions & 7 deletions lib/predictions.js
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/
/** @typedef {import("./types").Prediction} Prediction */

/**
* @typedef {Object} BasePredictionOptions
* @property {unknown} input - Required. An object with the model inputs
* @property {string} [webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @property {string[]} [webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @property {boolean} [stream] - Whether to stream the prediction output. Defaults to false
*
* @typedef {Object} ModelPredictionOptions
* @property {string} model The model name (for official models)
* @property {never=} version
*
* @typedef {Object} VersionPredictionOptions
* @property {string} version The model version
* @property {never=} model
*/

/**
* Create a new prediction
*
* @param {object} options
* @param {string=} options.model - The model (for official models)
* @param {string=} options.version - The model version.
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @param {BasePredictionOptions & (ModelPredictionOptions | VersionPredictionOptions)} options
* @returns {Promise<Prediction>} Resolves with the created prediction
*/
async function createPrediction(options) {
Expand Down
39 changes: 31 additions & 8 deletions lib/replicate.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
/**
* @template T
* @typedef {import("./types").Page<T>} Page
*/

/** @typedef {import("./types").Prediction} Prediction */
/** @typedef {import("./types").WebhookEventType} WebhookEventType */

const ApiError = require("./error");
const ModelVersionIdentifier = require("./identifier");
const { Stream } = require("./stream");
Expand Down Expand Up @@ -49,34 +57,45 @@ module.exports = class Replicate {
* const input = {text: 'Hello, world!'}
* const output = await replicate.run(model, { input });
*
* @param {Object={}} options - Configuration options for the client
* @param {Object} [options] - Configuration options for the client
* @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
* @param {string} [options.userAgent] - Identifier of your app
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
*/
constructor(options = {}) {
/** @type {string} */
this.auth = options.auth || process.env.REPLICATE_API_TOKEN;

/** @type {string} */
this.userAgent =
options.userAgent || `replicate-javascript/${packageJSON.version}`;

/** @type {string} */
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";

/** @type {fetch} */
this.fetch = options.fetch || globalThis.fetch;

/** @type {collections} */
this.collections = {
list: collections.list.bind(this),
get: collections.get.bind(this),
};

/** @type {deployments} */
this.deployments = {
predictions: {
create: deployments.predictions.create.bind(this),
},
};

/** @type {hardware} */
this.hardware = {
list: hardware.list.bind(this),
};

/** @type {models} */
this.models = {
get: models.get.bind(this),
list: models.list.bind(this),
Expand All @@ -87,13 +106,15 @@ module.exports = class Replicate {
},
};

/** @type {predictions} */
this.predictions = {
create: predictions.create.bind(this),
get: predictions.get.bind(this),
cancel: predictions.cancel.bind(this),
list: predictions.list.bind(this),
};

/** @type {trainings} */
this.trainings = {
create: trainings.create.bind(this),
get: trainings.get.bind(this),
Expand All @@ -105,18 +126,18 @@ module.exports = class Replicate {
/**
* Run a model and wait for its output.
*
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {`${string}/${string}` | `${string}/${string}:${string}`} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {object} [options.wait] - Options for waiting for the prediction to finish
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
* @throws {Error} If the reference is invalid
* @throws {Error} If the prediction failed
* @returns {Promise<object>} - Resolves with the output of running the model
* @returns {Promise<Prediction>} - Resolves with the output of running the model
*/
async run(ref, options, progress) {
const { wait, ...data } = options;
Expand Down Expand Up @@ -252,7 +273,7 @@ module.exports = class Replicate {
/**
* Stream a model and wait for its output.
*
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
Expand Down Expand Up @@ -300,8 +321,10 @@ module.exports = class Replicate {
* for await (const page of replicate.paginate(replicate.predictions.list) {
* console.log(page);
* }
* @param {Function} endpoint - Function that returns a promise for the next page of results
* @yields {object[]} Each page of results
* @template T
* @param {() => Promise<Page<T>>} endpoint - Function that returns a promise for the next page of results
* @yields {T[]} Each page of results
* @returns {AsyncGenerator<T[], void, unknown>}
*/
async *paginate(endpoint) {
const response = await endpoint();
Expand All @@ -327,7 +350,7 @@ module.exports = class Replicate {
* @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling.
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
* @throws {Error} If the prediction failed
* @returns {Promise<object>} Resolves with the completed prediction object
* @returns {Promise<Prediction>} Resolves with the completed prediction object
*/
async wait(prediction, options, stop) {
const { id } = prediction;
Expand Down
10 changes: 9 additions & 1 deletion lib/stream.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Attempt to use readable-stream if available, attempt to use the built-in stream module.
/** @type {import("stream").Readable} */
let Readable;
try {
Readable = require("readable-stream").Readable;
Expand Down Expand Up @@ -49,7 +50,7 @@ class Stream extends Readable {
* Create a new stream of server-sent events.
*
* @param {string} url The URL to connect to.
* @param {object} options The fetch options.
* @param {RequestInit=} options The fetch options.
*/
constructor(url, options) {
if (!Readable) {
Expand All @@ -63,11 +64,18 @@ class Stream extends Readable {
this.options = options;

this.event = null;

/** @type {unknown[]} */
this.data = [];

/** @type {string | null} */
this.lastEventId = null;

/** @type {number | null} */
this.retry = null;
}

/** @param {string=} line */
decode(line) {
if (!line) {
if (!this.event && !this.data.length && !this.lastEventId) {
Expand Down
Loading

0 comments on commit 1195d8a

Please sign in to comment.