*: framework agnostic loaders #682

Merged: 24 commits, Aug 22, 2024
2 changes: 1 addition & 1 deletion cli/README.md
@@ -31,7 +31,7 @@ npm -w cli start -- --help # or -h
The CLI can be used on several pre-defined tasks: titanic, simple-face and CIFAR10. To understand
how to add a new task, have a look at [TASK.md](../docs/TASK.md).

Once a new task has been defined in `discojs`, it can be loaded in [data.ts](./src/data.ts) as it is already implemented for current tasks. There are currently [multiple classes](../discojs-node/src/dataset/data_loader) you can use to load data using Node.js and preprocess data: ImageLoader, TabularLoader and TextLoader.
Once a new task has been defined in `discojs`, it can be loaded in [data.ts](./src/data.ts), as is already implemented for the current tasks. There are currently [multiple functions](../discojs-node/src/loaders) you can use to load and preprocess data with Node.js: loadImagesInDir, loadCSV and loadText.
Once a function to load data has been added, make sure to extend `getTaskData` in `data.ts`, which matches each task with its respective data loading function (see the sketch just after this diff).

Finally, add the task as a CLI argument to the `supportedTasks` Map in [args.ts](./src/args.ts).
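As a concrete illustration of the workflow the README describes, here is a hedged sketch of a `getTaskData` extension using the new loaders, modeled on this PR's `cli/src/data.ts` (shown further below); the task id `my_csv_task` and the CSV filename are hypothetical:

```ts
import path from "node:path";

import type { Task, TypedLabeledDataset } from "@epfml/discojs";
import { loadCSV } from "@epfml/discojs-node";

export async function getTaskData(task: Task): Promise<TypedLabeledDataset> {
  switch (task.id) {
    // Hypothetical tabular task: tag the dataset with its modality ("tabular")
    // and back it with a CSV file, mirroring the titanic case in this PR.
    case "my_csv_task":
      return ["tabular", loadCSV(path.join("..", "datasets", "my_data.csv"))];
    default:
      throw new Error(`Data loader for ${task.id} not implemented.`);
  }
}
```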
66 changes: 46 additions & 20 deletions cli/src/benchmark_gpt.ts
@@ -1,9 +1,9 @@
import { parse } from "ts-command-line-args";
import type * as tf from "@tensorflow/tfjs"
import * as tf from "@tensorflow/tfjs"
import { AutoTokenizer } from "@xenova/transformers";

import type { Task } from '@epfml/discojs'
import { fetchTasks, data, models, async_iterator, defaultTasks } from "@epfml/discojs";
import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
import { fetchTasks, models, async_iterator, defaultTasks, processing } from "@epfml/discojs";
import { loadModelFromDisk, loadText } from '@epfml/discojs-node'

import { Server } from "server";

@@ -41,6 +41,15 @@ const args = { ...defaultArgs, ...parsedArgs }
* Benchmark results are reported in https://github.com/epfml/disco/pull/659
*/

function intoTFGenerator<T extends tf.TensorContainer>(
  iter: AsyncIterable<T>,
): tf.data.Dataset<T> {
  // @ts-expect-error generator
  return tf.data.generator(async function* () {
    yield* iter;
  });
}

async function main(args: Required<CLIArguments>): Promise<void> {
const { inference: benchmarkInference, modelType,
contextLength, batchSize, modelPath } = args
@@ -54,6 +63,10 @@ async function main(args: Required<CLIArguments>): Promise<void> {
const task = tasks.get('llm_task')
if (task === undefined) { throw new Error('task not found') }

const tokenizerName = task.trainingInformation.tokenizer
if (typeof tokenizerName !== 'string') throw Error('no tokenizer name specified in the task training information')
const tokenizer = await AutoTokenizer.from_pretrained(tokenizerName)

/**
* Training benchmark
*/
@@ -74,12 +87,35 @@ async function main(args: Required<CLIArguments>): Promise<void> {
// to make sure the dataset is batched and tokenized correctly
task.trainingInformation.batchSize = batchSize
task.trainingInformation.maxSequenceLength = contextLength
const dataset = await loadWikitextData(task)
const preprocessedDataset = dataset.train.preprocess().batch()
.dataset as tf.data.Dataset<{
xs: tf.Tensor2D;
ys: tf.Tensor3D;
}>;
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')

const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + 1
// TODO will be easier when preprocessing is redone
const preprocessedDataset = intoTFGenerator(
  dataset
    .map((line) =>
      processing.tokenizeAndLeftPad(line, tokenizer, maxLength),
    )
    .batch(batchSize)
    .map((batch) =>
      tf.tidy(() => ({
        xs: tf.tensor2d(
          batch.map((tokens) => tokens.slice(0, -1)).toArray(),
        ),
        ys: tf.stack(
          batch
            .map(
              (tokens) =>
                tf.oneHot(
                  tokens.slice(1),
                  tokenizer.model.vocab.length + 1,
                ) as tf.Tensor2D,
            )
            .toArray(),
        ) as tf.Tensor3D,
      })),
    ),
);

// Init and train the model
const model = new models.GPT(config)
@@ -101,8 +137,6 @@ async function main(args: Required<CLIArguments>): Promise<void> {
if (!(model instanceof models.GPT)){
throw new Error("Loaded model isn't a GPT model")
}
// Retrieve the tokenizer used during training
const tokenizer = await models.getTaskTokenizer(task)

// Benchmark parameters
const prompt = 'The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,'
@@ -125,13 +159,5 @@ async function main(args: Required<CLIArguments>): Promise<void> {
})
}

async function loadWikitextData (task: Task): Promise<data.DataSplit> {
  const loader = new NodeTextLoader(task)
  const dataSplit: data.DataSplit = {
    train: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.train.tokens', {shuffle: true}), task)
  }
  return dataSplit
}

// You can run this example with "npm start" from this folder
main(args).catch(console.error)
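A note on the reworked preprocessing above: each tokenized, left-padded line becomes a next-token-prediction pair, where `xs` holds every token but the last and `ys` one-hot encodes every token but the first. A minimal standalone sketch with hypothetical values (a 6-token sequence, vocabulary size 8):

```ts
import * as tf from "@tensorflow/tfjs";

// Hypothetical token sequence; in the benchmark above, sequences come from
// processing.tokenizeAndLeftPad and have length maxSequenceLength.
const tokens = [5, 1, 4, 2, 7, 3];
const vocabSize = 8; // stand-in for tokenizer.model.vocab.length + 1

// Inputs: every token except the last.
const xs = tf.tensor2d([tokens.slice(0, -1)]); // shape [1, 5]
// Targets: every token except the first, one-hot encoded.
const ys = tf.stack([tf.oneHot(tokens.slice(1), vocabSize)]); // shape [1, 5, 8]
```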
7 changes: 5 additions & 2 deletions cli/src/cli.ts
@@ -1,7 +1,10 @@
// speed things up TODO how to avoid the need to import it
import "@tensorflow/tfjs-node"

import { List, Range } from 'immutable'
import fs from 'node:fs/promises'

import type { data, RoundLogs, Task, TaskProvider } from '@epfml/discojs'
import type { RoundLogs, Task, TaskProvider, TypedLabeledDataset } from '@epfml/discojs'
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
import { Server } from 'server'

@@ -18,7 +21,7 @@ async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
async function runUser(
  task: Task,
  url: URL,
  data: data.DataSplit,
  data: TypedLabeledDataset,
): Promise<List<RoundLogs>> {
  const trainingScheme = task.trainingInformation.scheme
  const aggregator = aggregators.getAggregator(task)
102 changes: 40 additions & 62 deletions cli/src/data.ts
@@ -1,76 +1,54 @@
import { Range, Repeat } from 'immutable'
import fs from 'node:fs/promises'
import path from 'node:path'
import path from "node:path";

import type { Task, data } from '@epfml/discojs'
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
import type { Dataset, Image, Task, TypedLabeledDataset } from "@epfml/discojs";
import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function simplefaceData (task: Task): Promise<data.DataSplit> {
  const dir = '../datasets/simple_face/'
  const youngFolder = dir + 'child/'
  const adultFolder = dir + 'adult/'
async function loadSimpleFaceData(): Promise<Dataset<[Image, string]>> {
  const folder = path.join("..", "datasets", "simple_face");

  const youngFiles = (await fs.readdir(youngFolder)).map(file => path.join(youngFolder, file))
  const adultFiles = (await fs.readdir(adultFolder)).map(file => path.join(adultFolder, file))
  const images = youngFiles.concat(adultFiles)
  const [adults, childs]: Dataset<[Image, string]>[] = [
    (await loadImagesInDir(path.join(folder, "adult"))).zip(Repeat("adult")),
    (await loadImagesInDir(path.join(folder, "child"))).zip(Repeat("child")),
  ];

  const youngLabels = youngFiles.map(_ => 'child')
  const oldLabels = adultFiles.map(_ => 'adult')
  const labels = youngLabels.concat(oldLabels)
  return await new NodeImageLoader(task).loadAll(images, { labels })
  return adults.chain(childs);
}

async function cifar10Data (cifar10: Task): Promise<data.DataSplit> {
  const dir = '../datasets/CIFAR10/'
  const files = (await fs.readdir(dir)).map((file) => path.join(dir, file))
  const labels = Repeat('airplane', 24).toArray() // TODO read labels in csv
  return await new NodeImageLoader(cifar10).loadAll(files, { labels })
}

async function lusCovidData (lusCovid: Task): Promise<data.DataSplit> {
  const dir = '../datasets/lus_covid/'
  const covid_pos = dir + 'COVID+'
  const covid_neg = dir + 'COVID-'
  const files_pos = (await fs.readdir(covid_pos)).filter(file => file.endsWith('.png')).map(file => path.join(covid_pos, file))
  const label_pos = Range(0, files_pos.length).map(_ => 'COVID-Positive')

  const files_neg = (await fs.readdir(covid_neg)).filter(file => file.endsWith('.png')).map(file => path.join(covid_neg, file))
  const label_neg = Range(0, files_neg.length).map(_ => 'COVID-Negative')

  const files = files_pos.concat(files_neg)
  const labels = label_pos.concat(label_neg).toArray()

  const dataConfig = { labels, shuffle: true, validationSplit: 0.1, channels: 3 }
  return await new NodeImageLoader(lusCovid).loadAll(files, dataConfig)
}
async function loadLusCovidData(): Promise<Dataset<[Image, string]>> {
  const folder = path.join("..", "datasets", "lus_covid");

async function titanicData (titanic: Task): Promise<data.DataSplit> {
  const dir = '../datasets/titanic_train.csv'
  const [positive, negative]: Dataset<[Image, string]>[] = [
    (await loadImagesInDir(path.join(folder, "COVID+"))).zip(
      Repeat("COVID-Positive"),
    ),
    (await loadImagesInDir(path.join(folder, "COVID-"))).zip(
      Repeat("COVID-Negative"),
    ),
  ];

  const data = await (new NodeTabularLoader(titanic, ',').loadAll(
    ['file://'.concat(dir)],
    {
      features: titanic.trainingInformation?.inputColumns,
      labels: titanic.trainingInformation?.outputColumns,
      shuffle: false
    }
  ))
  return data
  return positive.chain(negative);
}

export async function getTaskData (task: Task): Promise<data.DataSplit> {
export async function getTaskData(task: Task): Promise<TypedLabeledDataset> {
  switch (task.id) {
    case 'simple_face':
      return await simplefaceData(task)
    case 'titanic':
      return await titanicData(task)
    case 'cifar10':
      return await cifar10Data(task)
    case 'lus_covid':
      return await lusCovidData(task)
    case 'YOUR CUSTOM TASK HERE':
      throw new Error('YOUR CUSTOM FUNCTION HERE')
case "simple_face":
return ["image", await loadSimpleFaceData()];
case "titanic":
return [
"tabular",
loadCSV(path.join("..", "datasets", "titanic_train.csv")),
];
case "cifar10":
return [
"image",
(await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))).zip(
Repeat("cat"),
),
];
case "lus_covid":
return ["image", await loadLusCovidData()];
default:
throw new Error(`Data loader for ${task.id} not implemented.`)
throw new Error(`Data loader for ${task.id} not implemented.`);
}
}
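A note on the labeling pattern used throughout the new `data.ts`: immutable.js's `Repeat(label)` produces an unbounded lazy sequence, so zipping it with a finite image dataset pairs every image with that class label, and `chain` concatenates the per-class datasets into one. A hedged sketch of the same pattern under a hypothetical `datasets/pets/{cat,dog}` layout:

```ts
import path from "node:path";

import type { Dataset, Image } from "@epfml/discojs";
import { loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

// Hypothetical two-class image task: label each directory's images by class,
// then concatenate the two labeled datasets.
async function loadPetsData(): Promise<Dataset<[Image, string]>> {
  const folder = path.join("..", "datasets", "pets");

  const [cats, dogs]: Dataset<[Image, string]>[] = [
    (await loadImagesInDir(path.join(folder, "cat"))).zip(Repeat("cat")),
    (await loadImagesInDir(path.join(folder, "dog"))).zip(Repeat("dog")),
  ];

  return cats.chain(dogs);
}
```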
5 changes: 4 additions & 1 deletion discojs-node/package.json
@@ -21,11 +21,14 @@
"dependencies": {
"@epfml/discojs": "*",
"@koush/wrtc": "0.5",
"@tensorflow/tfjs-node": "4"
"@tensorflow/tfjs-node": "4",
"csv-parse": "5",
"sharp": "0.33"
},
"devDependencies": {
"@types/node": "22",
"nodemon": "3",
"tmp-promise": "3",
"ts-node": "10"
}
}