Skip to content

Commit

Permalink
Add modelInfo, spaceInfo, datasetInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
coyotte508 committed Oct 5, 2024
1 parent 22e6d4d commit 70ae58f
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 13 deletions.
4 changes: 3 additions & 1 deletion packages/hub/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ For some of the calls, you need to create an account and generate an [access tok
Learn how to find free models using the hub package in this [interactive tutorial](https://scrimba.com/scrim/c7BbVPcd?pl=pkVnrP7uP).

```ts
import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI } from "@huggingface/hub";
import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI, modelInfo, listMOdels } from "@huggingface/hub";
import type { RepoDesignation } from "@huggingface/hub";

const repo: RepoDesignation = { type: "model", name: "myname/some-model" };
Expand All @@ -41,6 +41,8 @@ for await (const model of listModels({search: {owner: username}, accessToken: "h
console.log("My model:", model);
}

const specificModel = await modelInfo({name: "openai-community/gpt2"});

await createRepo({ repo, accessToken: "hf_...", license: "mit" });

await uploadFiles({
Expand Down
19 changes: 19 additions & 0 deletions packages/hub/src/lib/dataset-info.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { describe, expect, it } from "vitest";
import { datasetInfo } from "./dataset-info";

describe("datasetInfo", () => {
it("should return the dataset info", async () => {
const info = await datasetInfo({
name: "nyu-mll/glue",
});
expect(info).toEqual({
id: "621ffdd236468d709f181e3f",
downloads: expect.any(Number),
gated: false,
name: "nyu-mll/glue",
updatedAt: expect.any(Date),
likes: expect.any(Number),
private: false,
});
});
});
59 changes: 59 additions & 0 deletions packages/hub/src/lib/dataset-info.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiDatasetInfo } from "../types/api/api-dataset";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { pick } from "../utils/pick";
import { type DATASET_EXPANDABLE_KEYS, DATASET_EXPAND_KEYS, type DatasetEntry } from "./list-datasets";

export async function datasetInfo<
const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never,
>(
params: {
name: string;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<DatasetEntry & Pick<ApiDatasetInfo, T>> {
const accessToken = params && checkCredentials(params);

const search = new URLSearchParams([
...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();

const response = await (params.fetch || fetch)(
`${params?.hubUrl || HUB_URL}/api/datasets/${params.name}?${search.toString()}`,
{
headers: {
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
Accepts: "application/json",
},
}
);

if (!response.ok) {
createApiError(response);
}

const data = await response.json();

return {
...(params?.additionalFields && pick(data, params.additionalFields)),
id: data._id,
name: data.id,
private: data.private,
downloads: data.downloads,
likes: data.likes,
gated: data.gated,
updatedAt: new Date(data.lastModified),
} as DatasetEntry & Pick<ApiDatasetInfo, T>;
}
3 changes: 3 additions & 0 deletions packages/hub/src/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export * from "./cache-management";
export * from "./commit";
export * from "./count-commits";
export * from "./create-repo";
export * from "./dataset-info";
export * from "./delete-file";
export * from "./delete-files";
export * from "./delete-repo";
Expand All @@ -13,9 +14,11 @@ export * from "./list-datasets";
export * from "./list-files";
export * from "./list-models";
export * from "./list-spaces";
export * from "./model-info";
export * from "./oauth-handle-redirect";
export * from "./oauth-login-url";
export * from "./parse-safetensors-metadata";
export * from "./space-info";
export * from "./upload-file";
export * from "./upload-files";
export * from "./upload-files-with-progress";
Expand Down
8 changes: 4 additions & 4 deletions packages/hub/src/lib/list-datasets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

const EXPAND_KEYS = [
export const DATASET_EXPAND_KEYS = [
"private",
"downloads",
"gated",
"likes",
"lastModified",
] as const satisfies readonly (keyof ApiDatasetInfo)[];

const EXPANDABLE_KEYS = [
export const DATASET_EXPANDABLE_KEYS = [
"author",
"cardData",
"citation",
Expand Down Expand Up @@ -45,7 +45,7 @@ export interface DatasetEntry {
}

export async function* listDatasets<
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never,
>(
params?: {
search?: {
Expand Down Expand Up @@ -77,7 +77,7 @@ export async function* listDatasets<
...(params?.search?.query ? { search: params.search.query } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : "");
Expand Down
8 changes: 4 additions & 4 deletions packages/hub/src/lib/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

const EXPAND_KEYS = [
export const MODEL_EXPAND_KEYS = [
"pipeline_tag",
"private",
"gated",
Expand All @@ -15,7 +15,7 @@ const EXPAND_KEYS = [
"lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];

const EXPANDABLE_KEYS = [
export const MODEL_EXPANDABLE_KEYS = [
"author",
"cardData",
"config",
Expand Down Expand Up @@ -51,7 +51,7 @@ export interface ModelEntry {
}

export async function* listModels<
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
>(
params?: {
search?: {
Expand Down Expand Up @@ -85,7 +85,7 @@ export async function* listModels<
...(params?.search?.query ? { search: params.search.query } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;
Expand Down
15 changes: 11 additions & 4 deletions packages/hub/src/lib/list-spaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"] as const satisfies readonly (keyof ApiSpaceInfo)[];
const EXPANDABLE_KEYS = [
export const SPACE_EXPAND_KEYS = [
"sdk",
"likes",
"private",
"lastModified",
] as const satisfies readonly (keyof ApiSpaceInfo)[];
export const SPACE_EXPANDABLE_KEYS = [
"author",
"cardData",
"datasets",
Expand Down Expand Up @@ -37,7 +42,7 @@ export interface SpaceEntry {
}

export async function* listSpaces<
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
const T extends Exclude<(typeof SPACE_EXPANDABLE_KEYS)[number], (typeof SPACE_EXPAND_KEYS)[number]> = never,
>(
params?: {
search?: {
Expand Down Expand Up @@ -67,7 +72,9 @@ export async function* listSpaces<
...(params?.search?.query ? { search: params.search.query } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...[...EXPAND_KEYS, ...(params?.additionalFields ?? [])].map((val) => ["expand", val] satisfies [string, string]),
...[...SPACE_EXPAND_KEYS, ...(params?.additionalFields ?? [])].map(
(val) => ["expand", val] satisfies [string, string]
),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`;

Expand Down
20 changes: 20 additions & 0 deletions packages/hub/src/lib/model-info.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { describe, expect, it } from "vitest";
import { modelInfo } from "./model-info";

describe("modelInfo", () => {
it("should return the model info", async () => {
const info = await modelInfo({
name: "openai-community/gpt2",
});
expect(info).toEqual({
id: "621ffdc036468d709f17434d",
downloads: expect.any(Number),
gated: false,
name: "openai-community/gpt2",
updatedAt: expect.any(Date),
likes: expect.any(Number),
task: "text-generation",
private: false,
});
});
});
60 changes: 60 additions & 0 deletions packages/hub/src/lib/model-info.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { pick } from "../utils/pick";
import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";

export async function modelInfo<
const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPANDABLE_KEYS)[number]> = never,
>(
params: {
name: string;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<ModelEntry & Pick<ApiModelInfo, T>> {
const accessToken = params && checkCredentials(params);

const search = new URLSearchParams([
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();

const response = await (params.fetch || fetch)(
`${params?.hubUrl || HUB_URL}/api/models/${params.name}?${search.toString()}`,
{
headers: {
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
Accepts: "application/json",
},
}
);

if (!response.ok) {
createApiError(response);
}

const data = await response.json();

return {
...(params?.additionalFields && pick(data, params.additionalFields)),
id: data._id,
name: data.id,
private: data.private,
task: data.pipeline_tag,
downloads: data.downloads,
gated: data.gated,
likes: data.likes,
updatedAt: new Date(data.lastModified),
} as ModelEntry & Pick<ApiModelInfo, T>;
}
18 changes: 18 additions & 0 deletions packages/hub/src/lib/space-info.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { describe, expect, it } from "vitest";
import { spaceInfo } from "./space-info";

describe("spaceInfo", () => {
it("should return the space info", async () => {
const info = await spaceInfo({
name: "huggingfacejs/client-side-oauth",
});
expect(info).toEqual({
id: "659835e689010f9c7aed608d",
name: "huggingfacejs/client-side-oauth",
updatedAt: expect.any(Date),
likes: expect.any(Number),
private: false,
sdk: "static",
});
});
});
59 changes: 59 additions & 0 deletions packages/hub/src/lib/space-info.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiSpaceInfo } from "../types/api/api-space";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { pick } from "../utils/pick";
import type { SPACE_EXPANDABLE_KEYS, SpaceEntry } from "./list-spaces";
import { SPACE_EXPAND_KEYS } from "./list-spaces";

export async function spaceInfo<
const T extends Exclude<(typeof SPACE_EXPANDABLE_KEYS)[number], (typeof SPACE_EXPAND_KEYS)[number]> = never,
>(
params: {
name: string;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<SpaceEntry & Pick<ApiSpaceInfo, T>> {
const accessToken = params && checkCredentials(params);

const search = new URLSearchParams([
...SPACE_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();

const response = await (params.fetch || fetch)(
`${params?.hubUrl || HUB_URL}/api/spaces/${params.name}?${search.toString()}`,
{
headers: {
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
Accepts: "application/json",
},
}
);

if (!response.ok) {
createApiError(response);
}

const data = await response.json();

return {
...(params?.additionalFields && pick(data, params.additionalFields)),
id: data._id,
name: data.id,
sdk: data.sdk,
likes: data.likes,
private: data.private,
updatedAt: new Date(data.lastModified),
} as SpaceEntry & Pick<ApiSpaceInfo, T>;
}

0 comments on commit 70ae58f

Please sign in to comment.