-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #807 from epfml/654-improve-gpt-julien
Fix and rework GPT-TF.js
- Loading branch information
Showing
28 changed files
with
835 additions
and
353 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
name: record-cypress | ||
on: | ||
workflow_dispatch: | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
download-datasets: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
lfs: true | ||
submodules: true | ||
- uses: actions/cache@v4 | ||
with: | ||
path: datasets | ||
key: datasets-${{ hashFiles('datasets/**') }} | ||
- run: datasets/populate | ||
|
||
build-lib: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-node@v4 | ||
with: | ||
node-version-file: .nvmrc | ||
cache: npm | ||
- run: npm ci | ||
- run: npm --workspace=discojs run build | ||
|
||
build-lib-web: | ||
needs: build-lib | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-node@v4 | ||
with: | ||
node-version-file: .nvmrc | ||
cache: npm | ||
- run: npm ci | ||
- run: npm run --workspace=discojs build | ||
- run: npm run --workspace=discojs-web build | ||
|
||
record-test-webapp: | ||
needs: [build-lib, build-lib-web, download-datasets] | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
lfs: true | ||
submodules: true | ||
- uses: actions/cache@v4 | ||
with: | ||
path: datasets | ||
key: datasets-${{ hashFiles('datasets/**') }} | ||
- uses: actions/setup-node@v4 | ||
with: | ||
node-version-file: .nvmrc | ||
cache: npm | ||
- run: npm ci | ||
- run: npm --workspace={discojs,discojs-web} run build | ||
- run: npm --workspace=webapp run test:unit | ||
- uses: cypress-io/github-action@v6 | ||
with: | ||
working-directory: webapp | ||
install: false | ||
start: npm start | ||
wait-on: 'http://localhost:8081' # Waits for above | ||
# Records to Cypress Cloud | ||
# https://docs.cypress.io/guides/cloud/projects#Set-up-a-project-to-record | ||
record: true | ||
env: | ||
VITE_SERVER_URL: http://server | ||
CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import "@tensorflow/tfjs-node" | ||
import { AutoTokenizer } from "@xenova/transformers"; | ||
import { models, processing, Dataset } from "@epfml/discojs"; | ||
import { List } from "immutable"; | ||
|
||
async function main(): Promise<void> { | ||
const data = "Lorem ipsum dolor sit amet, consectetur adipis" | ||
const seed = 42 | ||
|
||
const config: models.GPTConfig = { | ||
modelType: 'gpt-nano', | ||
lr: 0.01, | ||
maxIter: 50, | ||
evaluateEvery:50, | ||
maxEvalBatches: 10, | ||
contextLength: 16, | ||
seed | ||
} | ||
|
||
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') | ||
|
||
const tokenDataset = new Dataset([data]) | ||
.map((text: string) => processing.tokenize(tokenizer, text)) | ||
.flatten() | ||
.batch(config.contextLength + 1, 1) | ||
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number]) | ||
.repeat() | ||
.batch(8); | ||
|
||
const model = new models.GPT(config) | ||
for await (const logs of model.train(tokenDataset, undefined)) { | ||
console.log(logs) | ||
} | ||
|
||
let tokens = processing.tokenize(tokenizer, "Lorem"); | ||
|
||
const maxNewTokens = 14 | ||
for (let n = 0; n < maxNewTokens; n++) { | ||
const next: number = (await model.predict( | ||
List.of(tokens), { seed }) | ||
).first(); | ||
tokens = tokens.push(next) | ||
} | ||
const generation = tokenizer.decode(tokens.toArray(), { skip_special_tokens: true }) | ||
console.log(generation) | ||
} | ||
|
||
// You can run this example with "npm run run_gpt" from this folder | ||
main().catch(console.error) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,26 @@ | ||
import * as fs from "node:fs/promises"; | ||
import * as readline from "node:readline/promises"; | ||
|
||
import createDebug from "debug"; | ||
import { createReadStream } from 'node:fs'; | ||
import { Dataset, Text } from "@epfml/discojs"; | ||
|
||
const debug = createDebug("discojs-node:loaders:text"); | ||
|
||
/** | ||
* Returns chunks of text. Use `minChunkSize` to ensure that | ||
* each chunk is bigger than the expected sequence length. | ||
* | ||
* @param path path to the text file to read | ||
* @returns a dataset of tokenized input and label sequences | ||
*/ | ||
export function load(path: string): Dataset<Text> { | ||
return new Dataset(async function* () { | ||
const input = (await fs.open(path)).createReadStream({ encoding: "utf8" }); | ||
// Create a stream to read the text file chunk by chunk | ||
const stream = createReadStream(path, { encoding: "utf8" }); | ||
for await (const chunk of stream) { | ||
if (typeof chunk !== 'string') | ||
throw new Error('Expected file stream to yield string') | ||
|
||
// `readline` is a bit overkill but seems standard | ||
// https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line | ||
yield* readline.createInterface({ input, crlfDelay: Infinity }); | ||
debug("yield chunk of length: %o", chunk.length); | ||
yield chunk | ||
} | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.