diff --git a/packages/hub/package.json b/packages/hub/package.json index 3015df1dc..0f08adb56 100644 --- a/packages/hub/package.json +++ b/packages/hub/package.json @@ -18,7 +18,7 @@ } }, "browser": { - "./src/utils/sha256-node.ts": false, + "./src/utils/sha-node.ts": false, "./src/utils/FileBlob.ts": false, "./dist/index.js": "./dist/browser/index.js", "./dist/index.mjs": "./dist/browser/index.mjs" diff --git a/packages/hub/src/lib/commit.spec.ts b/packages/hub/src/lib/commit.spec.ts index 28d46b924..de4b2810d 100644 --- a/packages/hub/src/lib/commit.spec.ts +++ b/packages/hub/src/lib/commit.spec.ts @@ -2,7 +2,7 @@ import { assert, it, describe } from "vitest"; import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts"; import type { RepoId } from "../types/public"; -import type { CommitFile } from "./commit"; +import type { CommitFile, CommitOperation } from "./commit"; import { commit } from "./commit"; import { createRepo } from "./create-repo"; import { deleteRepo } from "./delete-repo"; @@ -127,6 +127,97 @@ size ${lfsContent.length} } }, 60_000); + it("should not commit if nothing to add/update", async function () { + const tokenizerJsonUrl = new URL( + "https://huggingface.co/spaces/aschen/push-model-from-web/raw/main/mobilenet/model.json" + ); + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo: RepoId = { + name: repoName, + type: "model", + }; + + await createRepo({ + credentials: { + accessToken: TEST_ACCESS_TOKEN, + }, + hubUrl: TEST_HUB_URL, + repo, + license: "mit", + }); + + try { + const readme1 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL }); + assert.strictEqual(readme1?.status, 200); + + const nodeOperation: CommitFile[] = isFrontend + ? 
[] + : [ + { + operation: "addOrUpdate", + path: "tsconfig.json", + content: (await import("node:url")).pathToFileURL("./tsconfig.json") as URL, + }, + ]; + + const operations: CommitOperation[] = [ + { + operation: "addOrUpdate", + content: new Blob(["This is me"]), + path: "test.txt", + }, + { + operation: "addOrUpdate", + content: new Blob([lfsContent]), + path: "test.lfs.txt", + }, + ...nodeOperation, + { + operation: "addOrUpdate", + content: tokenizerJsonUrl, + path: "lamaral.json", + }, + ]; + + const firstOutput = await commit({ + repo, + title: "Preparation commit", + credentials: { + accessToken: TEST_ACCESS_TOKEN, + }, + hubUrl: TEST_HUB_URL, + operations, + // To test web workers in the front-end + useWebWorkers: { minSize: 5_000 }, + }); + + const secondOutput = await commit({ + repo, + title: "Empty commit", + credentials: { + accessToken: TEST_ACCESS_TOKEN, + }, + hubUrl: TEST_HUB_URL, + operations, + // To test web workers in the front-end + useWebWorkers: { minSize: 5_000 }, + }); + + assert.deepStrictEqual(firstOutput.commit, secondOutput.commit); + assert.strictEqual(secondOutput.hookOutput, "Nothing to commit"); + + const currentRes: Response = await fetch(`${TEST_HUB_URL}/api/${repo.type}s/${repo.name}`); + const current = await currentRes.json(); + assert.strictEqual(secondOutput.commit.oid, current.sha); + } finally { + await deleteRepo({ + repo, + hubUrl: TEST_HUB_URL, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + }); + } + }, 60_000); + it("should commit a full repo from HF with web urls", async function () { const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; const repo: RepoId = { diff --git a/packages/hub/src/lib/commit.ts b/packages/hub/src/lib/commit.ts index a66af74c1..45ff66059 100644 --- a/packages/hub/src/lib/commit.ts +++ b/packages/hub/src/lib/commit.ts @@ -15,6 +15,7 @@ import { checkCredentials } from "../utils/checkCredentials"; import { chunk } from "../utils/chunk"; import { promisesQueue } from 
"../utils/promisesQueue"; import { promisesQueueStreaming } from "../utils/promisesQueueStreaming"; +import { sha1 } from "../utils/sha1"; import { sha256 } from "../utils/sha256"; import { toRepoId } from "../utils/toRepoId"; import { WebBlob } from "../utils/WebBlob"; @@ -125,8 +126,6 @@ export async function* commitIter(params: CommitParams): AsyncGenerator(); - const abortController = new AbortController(); const abortSignal = abortController.signal; @@ -168,6 +167,10 @@ export async function* commitIter(params: CommitParams): AsyncGenerator op.path === ".gitattributes")?.content; + let currentCommitOid: string | undefined; + const deleteOps = allOperations.filter((op): op is CommitDeletedEntry => !isFileOperation(op)); + const fileOpsToCommit: (CommitBlob & + ({ uploadMode: "lfs"; sha: string } | { uploadMode: "regular"; newSha: string | null }))[] = []; for (const operations of chunk(allOperations.filter(isFileOperation), 100)) { const payload: ApiPreuploadRequest = { gitAttributes: gitAttributes && (await gitAttributes.text()), @@ -202,44 +205,53 @@ export async function* commitIter(params: CommitParams): AsyncGenerator( + (yieldCallback, returnCallback, rejectCallack) => { + return promisesQueue( + json.files.map((file) => async () => { + const op = operations.find((op) => op.path === file.path); + if (!op) { + return; // this should never happen, server-side we always return one entry per operation + } - for (const file of json.files) { - if (file.uploadMode === "lfs") { - lfsShas.set(file.path, null); + const iterator = + file.uploadMode === "lfs" + ? 
sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal }) + : sha1(new Blob([`blob ${op.content.size}\0`, await op.content.arrayBuffer()]), { abortSignal }); // https://git-scm.com/book/en/v2/Git-Internals-Git-Objects + + let res: IteratorResult; + do { + res = await iterator.next(); + if (!res.done) { + yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, state: "hashing" }); + } + } while (!res.done); + + if (res.value === null || res.value !== file.oid) { + fileOpsToCommit.push({ + ...op, + uploadMode: file.uploadMode, + sha: res.value, + } as (typeof fileOpsToCommit)[number]); + // ^^ we know `newSha` can only be `null` in the case of `sha1` as expected, but TS doesn't + } + }), + CONCURRENT_SHAS + ).then(returnCallback, rejectCallack); } - } + ); + + abortSignal?.throwIfAborted(); } yield { event: "phase", phase: "uploadingLargeFiles" }; - for (const operations of chunk( - allOperations.filter(isFileOperation).filter((op) => lfsShas.has(op.path)), + for (const lfsOpsToCommit of chunk( + fileOpsToCommit.filter((op): op is CommitBlob & { uploadMode: "lfs"; sha: string } => op.uploadMode === "lfs"), 100 )) { - const shas = yield* eventToGenerator< - { event: "fileProgress"; state: "hashing"; path: string; progress: number }, - string[] - >((yieldCallback, returnCallback, rejectCallack) => { - return promisesQueue( - operations.map((op) => async () => { - const iterator = sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal: abortSignal }); - let res: IteratorResult; - do { - res = await iterator.next(); - if (!res.done) { - yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, state: "hashing" }); - } - } while (!res.done); - const sha = res.value; - lfsShas.set(op.path, res.value); - return sha; - }), - CONCURRENT_SHAS - ).then(returnCallback, rejectCallack); - }); - - abortSignal?.throwIfAborted(); - const payload: ApiLfsBatchRequest = { operation: "upload", // multipart is a custom 
protocol for HF @@ -250,8 +262,8 @@ export async function* commitIter(params: CommitParams): AsyncGenerator ({ - oid: shas[i], + objects: lfsOpsToCommit.map((op) => ({ + oid: op.sha, size: op.content.size, })), }; @@ -279,7 +291,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator [shas[i], op])); + const shaToOperation = new Map(lfsOpsToCommit.map((op) => [op.sha, op])); yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) => { return promisesQueueStreaming( @@ -293,9 +305,9 @@ export async function* commitIter(params: CommitParams): AsyncGenerator( async (yieldCallback, returnCallback, rejectCallback) => (params.fetch ?? fetch)( @@ -481,17 +521,16 @@ export async function* commitIter(params: CommitParams): AsyncGenerator { + allOpsToCommit.map((operation) => { if (isFileOperation(operation)) { - const sha = lfsShas.get(operation.path); - if (sha) { + if (operation.uploadMode === "lfs") { return { key: "lfsFile", value: { path: operation.path, algo: "sha256", size: operation.content.size, - oid: sha, + oid: operation.sha, } satisfies ApiCommitLfsFile, }; } @@ -509,8 +548,8 @@ export async function* commitIter(params: CommitParams): AsyncGenerator { // For now, we display equal progress for all files // We could compute the progress based on the size of `convertOperationToNdJson` for each of the files instead - for (const op of allOperations) { - if (isFileOperation(op) && !lfsShas.has(op.path)) { + for (const op of allOpsToCommit) { + if (isFileOperation(op) && op.uploadMode === "regular") { yieldCallback({ event: "fileProgress", path: op.path, diff --git a/packages/hub/src/lib/upload-files-with-progress.spec.ts b/packages/hub/src/lib/upload-files-with-progress.spec.ts index bc082cd89..b0d619f5b 100644 --- a/packages/hub/src/lib/upload-files-with-progress.spec.ts +++ b/packages/hub/src/lib/upload-files-with-progress.spec.ts @@ -70,6 +70,7 @@ describe("uploadFilesWithProgress", () => { // 
assert(intermediateUploadEvents.length > 0, "There should be at least one intermediate upload event"); // } progressEvents = progressEvents.filter((e) => e.event !== "fileProgress" || e.progress === 0 || e.progress === 1); + const hashingEvents = progressEvents.splice(1, 6); assert.deepStrictEqual(progressEvents, [ { @@ -84,29 +85,57 @@ describe("uploadFilesWithProgress", () => { event: "fileProgress", path: "test.lfs.txt", progress: 0, - state: "hashing", + state: "uploading", }, { event: "fileProgress", path: "test.lfs.txt", progress: 1, + state: "uploading", + }, + { + event: "phase", + phase: "committing", + }, + ]); + + // order can vary due to concurrent processing + assert.includeDeepMembers(hashingEvents, [ + { + event: "fileProgress", + path: "file1", + progress: 0, state: "hashing", }, { event: "fileProgress", - path: "test.lfs.txt", + path: "file1", + progress: 1, + state: "hashing", + }, + { + event: "fileProgress", + path: "config.json", progress: 0, - state: "uploading", + state: "hashing", }, { event: "fileProgress", - path: "test.lfs.txt", + path: "config.json", progress: 1, - state: "uploading", + state: "hashing", }, { - event: "phase", - phase: "committing", + event: "fileProgress", + path: "test.lfs.txt", + progress: 0, + state: "hashing", + }, + { + event: "fileProgress", + path: "test.lfs.txt", + progress: 1, + state: "hashing", }, ]); diff --git a/packages/hub/src/types/api/api-commit.ts b/packages/hub/src/types/api/api-commit.ts index 404ad5494..2d2f97616 100644 --- a/packages/hub/src/types/api/api-commit.ts +++ b/packages/hub/src/types/api/api-commit.ts @@ -127,9 +127,18 @@ export interface ApiPreuploadRequest { } export interface ApiPreuploadResponse { + /** + * Most recent commit oid for target rev + */ + commitOid: string; files: Array<{ path: string; uploadMode: "lfs" | "regular"; + /** + * The oid of the blob if it already exists in the repository + * in case the blob is an lfs file, it'll be the lfs file's sha256 + */ + oid?: string; 
}>; } diff --git a/packages/hub/src/utils/sha256-node.ts b/packages/hub/src/utils/sha-node.ts similarity index 55% rename from packages/hub/src/utils/sha256-node.ts rename to packages/hub/src/utils/sha-node.ts index b068d1a21..7eb3cac75 100644 --- a/packages/hub/src/utils/sha256-node.ts +++ b/packages/hub/src/utils/sha-node.ts @@ -8,19 +8,38 @@ export async function* sha256Node( abortSignal?: AbortSignal; } ): AsyncGenerator { - const sha256Stream = createHash("sha256"); + return yield* shaNode("sha256", buffer, opts); +} + +export async function* sha1Node( + buffer: ArrayBuffer | Blob, + opts?: { + abortSignal?: AbortSignal; + } +): AsyncGenerator { + return yield* shaNode("sha1", buffer, opts); +} + +async function* shaNode( + algorithm: "sha1" | "sha256", + buffer: ArrayBuffer | Blob, + opts?: { + abortSignal?: AbortSignal; + } +): AsyncGenerator { + const shaStream = createHash(algorithm); const size = buffer instanceof Blob ? buffer.size : buffer.byteLength; let done = 0; const readable = buffer instanceof Blob ? 
Readable.fromWeb(buffer.stream() as ReadableStream) : Readable.from(Buffer.from(buffer)); for await (const buffer of readable) { - sha256Stream.update(buffer); + shaStream.update(buffer); done += buffer.length; yield done / size; opts?.abortSignal?.throwIfAborted(); } - return sha256Stream.digest("hex"); + return shaStream.digest("hex"); } diff --git a/packages/hub/src/utils/sha1.ts b/packages/hub/src/utils/sha1.ts new file mode 100644 index 000000000..5e16fd5d7 --- /dev/null +++ b/packages/hub/src/utils/sha1.ts @@ -0,0 +1,29 @@ +import { hexFromBytes } from "./hexFromBytes"; +import { isFrontend } from "./isFrontend"; + +export async function* sha1(buffer: Blob, opts?: { abortSignal?: AbortSignal }): AsyncGenerator { + yield 0; + + if (globalThis.crypto?.subtle) { + const res = hexFromBytes( + new Uint8Array(await globalThis.crypto.subtle.digest("SHA-1", await buffer.arrayBuffer())) + ); + + yield 1; + return res; + } + + if (isFrontend) { + yield 1; + return null; // unsupported if we're here + } + + if (!cryptoModule) { + cryptoModule = await import("./sha-node"); + } + + return yield* cryptoModule.sha1Node(buffer, { abortSignal: opts?.abortSignal }); +} + +// eslint-disable-next-line @typescript-eslint/consistent-type-imports +let cryptoModule: typeof import("./sha-node"); diff --git a/packages/hub/src/utils/sha256.ts b/packages/hub/src/utils/sha256.ts index d432909b3..6bc2a40ae 100644 --- a/packages/hub/src/utils/sha256.ts +++ b/packages/hub/src/utils/sha256.ts @@ -154,13 +154,13 @@ export async function* sha256( } if (!cryptoModule) { - cryptoModule = await import("./sha256-node"); + cryptoModule = await import("./sha-node"); } return yield* cryptoModule.sha256Node(buffer, { abortSignal: opts?.abortSignal }); } // eslint-disable-next-line @typescript-eslint/consistent-type-imports -let cryptoModule: typeof import("./sha256-node"); +let cryptoModule: typeof import("./sha-node"); // eslint-disable-next-line @typescript-eslint/consistent-type-imports let 
wasmModule: typeof import("../vendor/hash-wasm/sha256-wrapper");