Skip to content

Commit

Permalink
💥 Simpler credentials passing around
Browse files Browse the repository at this point in the history
  • Loading branch information
coyotte508 committed Sep 20, 2024
1 parent 2b93cef commit 32204a5
Show file tree
Hide file tree
Showing 30 changed files with 460 additions and 446 deletions.
17 changes: 8 additions & 9 deletions packages/hub/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,21 @@ Learn how to find free models using the hub package in this [interactive tutoria

```ts
import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI } from "@huggingface/hub";
import type { RepoDesignation, Credentials } from "@huggingface/hub";
import type { RepoDesignation } from "@huggingface/hub";

const repo: RepoDesignation = { type: "model", name: "myname/some-model" };
const credentials: Credentials = { accessToken: "hf_..." };

const {name: username} = await whoAmI({credentials});
const {name: username} = await whoAmI({accessToken: "hf_..."});

for await (const model of listModels({search: {owner: username}, credentials})) {
for await (const model of listModels({search: {owner: username}, accessToken: "hf_..."})) {
console.log("My model:", model);
}

await createRepo({ repo, credentials, license: "mit" });
await createRepo({ repo, accessToken: "hf_...", license: "mit" });

await uploadFiles({
repo,
credentials,
accessToken: "hf_...",
files: [
// path + blob content
{
Expand All @@ -70,23 +69,23 @@ await uploadFiles({

for await (const progressEvent of await uploadFilesWithProgress({
repo,
credentials,
accessToken: "hf_...",
files: [
...
],
})) {
console.log(progressEvent);
}

await deleteFile({repo, credentials, path: "myfile.bin"});
await deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"});

await (await downloadFile({ repo, path: "README.md" })).text();

for await (const fileInfo of listFiles({repo})) {
console.log(fileInfo);
}

await deleteRepo({ repo, credentials });
await deleteRepo({ repo, accessToken: "hf_..." });
```

## OAuth Login
Expand Down
16 changes: 4 additions & 12 deletions packages/hub/src/lib/commit.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ describe("commit", () => {
};

await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo,
license: "mit",
Expand All @@ -50,9 +48,7 @@ describe("commit", () => {
await commit({
repo,
title: "Some commit",
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
operations: [
{
Expand Down Expand Up @@ -135,9 +131,7 @@ size ${lfsContent.length}
};

await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
repo,
hubUrl: TEST_HUB_URL,
});
Expand All @@ -163,9 +157,7 @@ size ${lfsContent.length}
);
await commit({
repo,
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
title: "upload model",
operations,
Expand Down
16 changes: 8 additions & 8 deletions packages/hub/src/lib/commit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import type {
ApiPreuploadRequest,
ApiPreuploadResponse,
} from "../types/api/api-commit";
import type { Credentials, RepoDesignation } from "../types/public";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { chunk } from "../utils/chunk";
import { promisesQueue } from "../utils/promisesQueue";
Expand Down Expand Up @@ -54,12 +54,11 @@ type CommitBlob = Omit<CommitFile, "content"> & { content: Blob };
export type CommitOperation = CommitDeletedEntry | CommitFile /* | CommitRenameFile */;
type CommitBlobOperation = Exclude<CommitOperation, CommitFile> | CommitBlob;

export interface CommitParams {
export type CommitParams = {
title: string;
description?: string;
repo: RepoDesignation;
operations: CommitOperation[];
credentials?: Credentials;
/** @default "main" */
branch?: string;
/**
Expand All @@ -82,7 +81,8 @@ export interface CommitParams {
*/
fetch?: typeof fetch;
abortSignal?: AbortSignal;
}
// Credentials are optional due to custom fetch functions or cookie auth
} & Partial<CredentialsParams>;

export interface CommitOutput {
pullRequestUrl?: string;
Expand Down Expand Up @@ -121,7 +121,7 @@ export type CommitProgressEvent =
* Can be exposed later to offer fine-tuned progress info
*/
export async function* commitIter(params: CommitParams): AsyncGenerator<CommitProgressEvent, CommitOutput> {
checkCredentials(params.credentials);
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);
yield { event: "phase", phase: "preuploading" };

Expand Down Expand Up @@ -189,7 +189,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
Expand Down Expand Up @@ -263,7 +263,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
Accept: "application/vnd.git-lfs+json",
"Content-Type": "application/vnd.git-lfs+json",
},
Expand Down Expand Up @@ -468,7 +468,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
"Content-Type": "application/x-ndjson",
},
body: [
Expand Down
27 changes: 14 additions & 13 deletions packages/hub/src/lib/count-commits.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { Credentials, RepoDesignation } from "../types/public";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function countCommits(params: {
credentials?: Credentials;
repo: RepoDesignation;
/**
* Revision to list commits from. Defaults to the default branch.
*/
revision?: string;
hubUrl?: string;
fetch?: typeof fetch;
}): Promise<number> {
checkCredentials(params.credentials);
export async function countCommits(
params: {
repo: RepoDesignation;
/**
* Revision to list commits from. Defaults to the default branch.
*/
revision?: string;
hubUrl?: string;
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<number> {
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);

// Could upgrade to 1000 commits per page
Expand All @@ -23,7 +24,7 @@ export async function countCommits(params: {
}?limit=1`;

const res: Response = await (params.fetch ?? fetch)(url, {
headers: params.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : {},
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
});

if (!res.ok) {
Expand Down
12 changes: 3 additions & 9 deletions packages/hub/src/lib/create-repo.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ describe("createRepo", () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
repo: {
name: repoName,
type: "model",
Expand Down Expand Up @@ -62,9 +60,7 @@ describe("createRepo", () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo: repoName,
files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
Expand All @@ -88,9 +84,7 @@ describe("createRepo", () => {
const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo: repoName,
files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
Expand Down
41 changes: 21 additions & 20 deletions packages/hub/src/lib/create-repo.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiCreateRepoPayload } from "../types/api/api-create-repo";
import type { Credentials, RepoDesignation, SpaceSdk } from "../types/public";
import type { CredentialsParams, RepoDesignation, SpaceSdk } from "../types/public";
import { base64FromBytes } from "../utils/base64FromBytes";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function createRepo(params: {
repo: RepoDesignation;
credentials: Credentials;
private?: boolean;
license?: string;
/**
* Only a few lightweight files are supported at repo creation
*/
files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<{ repoUrl: string }> {
checkCredentials(params.credentials);
export async function createRepo(
params: {
repo: RepoDesignation;
private?: boolean;
license?: string;
/**
* Only a few lightweight files are supported at repo creation
*/
files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & CredentialsParams
): Promise<{ repoUrl: string }> {
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);
const [namespace, repoName] = repoId.name.split("/");

Expand Down Expand Up @@ -61,7 +62,7 @@ export async function createRepo(params: {
: undefined,
} satisfies ApiCreateRepoPayload),
headers: {
Authorization: `Bearer ${params.credentials.accessToken}`,
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
Expand Down
9 changes: 3 additions & 6 deletions packages/hub/src/lib/delete-file.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ describe("deleteFile", () => {
it("should delete a file", async () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
const repo = { type: "model", name: repoName } satisfies RepoId;
const credentials = {
accessToken: TEST_ACCESS_TOKEN,
};

try {
const result = await createRepo({
credentials,
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo,
files: [
Expand All @@ -39,7 +36,7 @@ describe("deleteFile", () => {

assert.strictEqual(await content?.text(), "file1");

await deleteFile({ path: "file1", repo, credentials, hubUrl: TEST_HUB_URL });
await deleteFile({ path: "file1", repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });

content = await downloadFile({
repo,
Expand All @@ -59,7 +56,7 @@ describe("deleteFile", () => {
} finally {
await deleteRepo({
repo,
credentials,
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
});
}
Expand Down
29 changes: 15 additions & 14 deletions packages/hub/src/lib/delete-file.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import type { Credentials } from "../types/public";
import type { CredentialsParams } from "../types/public";
import type { CommitOutput, CommitParams } from "./commit";
import { commit } from "./commit";

export function deleteFile(params: {
credentials: Credentials;
repo: CommitParams["repo"];
path: string;
commitTitle?: CommitParams["title"];
commitDescription?: CommitParams["description"];
hubUrl?: CommitParams["hubUrl"];
fetch?: CommitParams["fetch"];
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
}): Promise<CommitOutput> {
export function deleteFile(
params: {
repo: CommitParams["repo"];
path: string;
commitTitle?: CommitParams["title"];
commitDescription?: CommitParams["description"];
hubUrl?: CommitParams["hubUrl"];
fetch?: CommitParams["fetch"];
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
} & CredentialsParams
): Promise<CommitOutput> {
return commit({
credentials: params.credentials,
...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
repo: params.repo,
operations: [
{
Expand Down
Loading

0 comments on commit 32204a5

Please sign in to comment.