Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let users join federated training mid session #741

Merged
merged 98 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 84 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
eb5b533
Fix merge conflicts
JulienVig Aug 12, 2024
17ec537
Simplify isValidContribution flow
JulienVig Aug 6, 2024
dc43227
Enable server to send latest global model to stale participants
JulienVig Aug 6, 2024
7044836
Make client rely on the aggregator's round rather than the trainer's …
JulienVig Aug 6, 2024
dd162fd
Rm unused server logs
JulienVig Aug 7, 2024
b25bfc9
Prevent overflowing text for narrow screens
JulienVig Aug 7, 2024
6718fc9
Enable multiple tries for clients to receive a server update
JulienVig Aug 7, 2024
67ecb81
Fix node ans browser ws API incompatibility
JulienVig Aug 7, 2024
f87b60e
Make client disconnect gracefully when leaving training session
JulienVig Aug 7, 2024
3d1791a
Add a flag with min participant threshold when client connects to server
JulienVig Aug 12, 2024
6f450c8
Refactor server router following Express conventions
JulienVig Aug 12, 2024
7d20de2
Rename 'feai' and 'deai' routes into 'federated' and 'decentralized'
JulienVig Aug 12, 2024
f8f07ca
Split controller abstraction following MVC pattern
JulienVig Aug 12, 2024
0a2f5b6
Create one training controller per task rather than one for all tasks
JulienVig Aug 12, 2024
10088a8
Add a space before DISCOllaboratives
JulienVig Aug 13, 2024
15c2730
Make info toaster less aggressive
JulienVig Aug 13, 2024
fd554bd
Enable waiting for more participants when below a minimum number of c…
JulienVig Aug 14, 2024
6fdf569
Rename Base class and files into FederatedClient and DecentralizedClient
JulienVig Aug 14, 2024
713ab20
Rename Base abstract class into Client
JulienVig Aug 14, 2024
e323bc9
Let client and trainer update the training status through the Logger …
JulienVig Aug 14, 2024
0800e97
Allow IconCard to grow with fill-space prop
JulienVig Aug 14, 2024
67a561b
Display training status while training
JulienVig Aug 14, 2024
f77f3eb
Rename controller files
JulienVig Aug 14, 2024
7af9cdc
Make client use MockLogger by default
JulienVig Aug 14, 2024
32081bb
Update tests w new training buttons
JulienVig Aug 14, 2024
a86f7b8
federated training now requires at least 2 users
JulienVig Aug 14, 2024
6c2d625
Replace console by debug statements
JulienVig Aug 19, 2024
aec8cc1
Update debug level
JulienVig Aug 19, 2024
ad023cd
Add error context
JulienVig Aug 19, 2024
b6e1709
Rm commented line
JulienVig Aug 19, 2024
c9ed1e8
simplify boolean prop
JulienVig Aug 19, 2024
788f6e4
Rm ClientDisconnected message logic
JulienVig Aug 19, 2024
aabd22b
Rm unused handle arguments
JulienVig Aug 19, 2024
a836481
Make waitingForMoreParticipants a getter rather than an attribute
JulienVig Aug 19, 2024
4fbcb0a
Add more error info
JulienVig Aug 19, 2024
26dd742
Use trainingInformation minNbOfParticipant for both federated and dec…
JulienVig Aug 19, 2024
7f9d374
Use immutable map and minNbOfParticipant
JulienVig Aug 19, 2024
e29f7bc
clean and inline code
JulienVig Aug 19, 2024
6599b09
remove unnecessary process.nextTick
JulienVig Aug 19, 2024
d4f8f43
Clarify process.nextTick usage
JulienVig Aug 19, 2024
11c35c1
Add "router" suffix to routes filenames
JulienVig Aug 20, 2024
92b51a6
Refactor TasksAndModels into TaskInitializer and add doc
JulienVig Aug 20, 2024
c3e7082
Add minNbOfParticipants field to trainingInformation
JulienVig Aug 20, 2024
90e761a
Add comment for #718
JulienVig Aug 20, 2024
0f1f41a
Simplify federated communication logic
JulienVig Aug 20, 2024
65e9869
Rename MockLogger to DummyLogger
JulienVig Aug 20, 2024
6a6cf8a
Add default value for fillSpace
JulienVig Aug 20, 2024
b6c8c64
Renaming Local to LocalClient
JulienVig Aug 20, 2024
9569b42
Make Disco and Client emit RoundStatus
JulienVig Aug 20, 2024
fd92d19
Rm dummy logger
JulienVig Aug 20, 2024
a8a1c69
Rm unused setStatus from logger
JulienVig Aug 20, 2024
446777b
Add 'Retrieving peers information' round status to decentralized
JulienVig Aug 20, 2024
08ea626
Rename cleanupTrainingSessionFn to cleanupDisco
JulienVig Aug 20, 2024
cadb160
Keep models encoded rather than decoding them to re-encode them right…
JulienVig Aug 21, 2024
ad01291
Export EncodedModel and EncodedWeights
JulienVig Aug 21, 2024
117c972
Use EncodedModel import
JulienVig Aug 21, 2024
4caa656
Simplify import
JulienVig Aug 21, 2024
e4ba98e
split NewNodeInfo between federated and decentralized to enable clien…
JulienVig Aug 21, 2024
3ad5b2e
Rm initTask
JulienVig Aug 21, 2024
30c1314
send latest weight to clients joining task
JulienVig Aug 21, 2024
192ef04
overwrite working model with model fetched by client.connect
JulienVig Aug 21, 2024
cd3902e
Use Disco constructor
JulienVig Aug 21, 2024
a4f2149
simplify disco cleanup
JulienVig Aug 21, 2024
b08a93a
Fix round increment
JulienVig Aug 28, 2024
5e48c80
Fix debug statement
JulienVig Aug 28, 2024
6952bbe
Improve debug statements
JulienVig Aug 28, 2024
33959cb
Merge with develop
JulienVig Aug 28, 2024
0edbf04
Expect client.connect to return the latest model
JulienVig Aug 29, 2024
57fb179
clean code and add doc
JulienVig Aug 29, 2024
fd3c0f0
Save server round and display waiting status immediately upon receiving
JulienVig Aug 29, 2024
1a7a514
Test client status emission
JulienVig Aug 29, 2024
c9c40c7
Send nb of participants upon client connection
JulienVig Aug 29, 2024
10ae9fd
Clarify client event cycle
JulienVig Aug 29, 2024
ad4c2b3
Fix typo
JulienVig Aug 29, 2024
804ed49
Fix disco object undefined during cleanup
JulienVig Aug 29, 2024
a5a40a6
increase beforeEach timeout
JulienVig Aug 29, 2024
3b62c55
Increase wait time for status to update
JulienVig Aug 29, 2024
6a40184
Increase wait time for status to update
JulienVig Aug 29, 2024
f06dbae
Debug github action test
JulienVig Aug 29, 2024
9d902ae
Show debug statements in github actions
JulienVig Aug 29, 2024
5864334
debug github actions test
JulienVig Aug 29, 2024
44ce301
Reject contribution when not enough participants
JulienVig Aug 29, 2024
21314a3
Decrease wait time for status update
JulienVig Aug 29, 2024
97b35c0
Don't aggregate when below task minNbOfParticipants
JulienVig Aug 29, 2024
699645b
Rm verbose debug
JulienVig Sep 2, 2024
b236578
Fix attribute increment
JulienVig Sep 2, 2024
94bfd7c
Rm unused import
JulienVig Sep 2, 2024
a262b34
Rm useless assertion
JulienVig Sep 2, 2024
25d9dfb
Clean chai expects
JulienVig Sep 2, 2024
f02df76
make emit a private method
JulienVig Sep 2, 2024
9f926a8
Expose a task getter
JulienVig Sep 2, 2024
fac2708
Rely on task initializer set rather than duplicate it in router
JulienVig Sep 2, 2024
03c989a
Export EventEmitter
JulienVig Sep 2, 2024
4285832
Support async callbacks
JulienVig Sep 2, 2024
98f1df5
Rm Digest feature
JulienVig Sep 2, 2024
954b6e1
Fix event emitter callbacks
JulienVig Sep 2, 2024
6ea8dbf
Rm promise type callbacks
JulienVig Sep 2, 2024
e6f09e1
Prevent duplicate round increment
JulienVig Sep 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint-test-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ jobs:
cache: npm
- run: npm ci
- run: npm --workspace={discojs,discojs-node} run build
- run: npm --workspace=server test
- run: DEBUG='server:controllers:federated|discojs:client:federated' npm --workspace=server test
JulienVig marked this conversation as resolved.
Show resolved Hide resolved

test-webapp:
needs: [build-lib, build-lib-web, download-datasets]
Expand Down
2 changes: 1 addition & 1 deletion cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async function runUser(
const trainingScheme = task.trainingInformation.scheme
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = await Disco.fromTask(task, client, { scheme: trainingScheme });
const disco = new Disco(task, client, { scheme: trainingScheme });

const logs = List(await arrayFromAsync(disco.trainByRound(data)));
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish
Expand Down
11 changes: 8 additions & 3 deletions discojs-web/src/memory/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
* The working model is loaded from IndexedDB for training (model.fit) only.
*/
import { Map } from 'immutable'
import createDebug from "debug"
import * as tf from '@tensorflow/tfjs'

import type { Model, ModelInfo, ModelSource } from '@epfml/discojs'
import { Memory, models } from '@epfml/discojs'

const debug = createDebug('discojs-web:memory')

export class IndexedDB extends Memory {
override getModelMemoryPath (source: ModelSource): string {
if (typeof source === 'string') {
Expand Down Expand Up @@ -100,7 +103,8 @@ export class IndexedDB extends Memory {
modelInfo['tensorBackend'] = 'gpt'
includeOptimizer = false // true raises an error
} else {
throw new Error('unknown model type')
debug('unknown working model type %o', model)
throw new Error(`unknown model type while updating working model`)
}
const indexedDBURL = this.getModelMemoryPath(modelInfo)
await model.extract().save(indexedDBURL, { includeOptimizer })
Expand Down Expand Up @@ -137,8 +141,9 @@ export class IndexedDB extends Memory {
} else if (model instanceof models.GPT) {
modelInfo['tensorBackend'] = 'gpt'
includeOptimizer = false // true raises an error
} else {
throw new Error('unknown model type')
} else {
debug('unknown saved model type %o', model)
throw new Error('unknown model type while saving model')
}
const indexedDBURL = this.getModelMemoryPath(modelInfo)
await model.extract().save(indexedDBURL, { includeOptimizer })
Expand Down
38 changes: 33 additions & 5 deletions discojs/src/aggregator/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,34 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {

/**
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
* The aggregation round is increased whenever a new global model is obtained and local models are updated.
* Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
* which requires multiple steps to obtain a global model)
* The contribution will be aggregated during the next aggregation step.
* @param nodeId The node's id
* @param contribution The node's contribution
* @param round For which aggregation round the contribution was made
* @param communicationRound For which communication round the contribution was made
* @param round aggregation round of the contribution was made
* @param communicationRound communication round the contribution was made within the aggregation round
* @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected
*/
abstract add (nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean

/**
* Evaluates whether a given participant contribution can be used in the current aggregation round
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
*/
isValidContribution(nodeId: client.NodeID, round: number): boolean {
if (!this.nodes.has(nodeId)) {
debug("Contribution rejected because node id is not registered")
return false;
}
if (!this.isWithinRoundCutoff(round)) {
debug(`Contribution rejected because round ${round} is not within round cutoff`)
return false;
}
return true
}

/**
* Performs an aggregation step over the received node contributions.
* Must store the aggregation's result in the aggregator's result promise.
Expand All @@ -101,16 +121,16 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
log (step: AggregationStep, from?: client.NodeID): void {
switch (step) {
case AggregationStep.ADD:
debug(`adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
debug(`Adding contribution from node ${from ?? '"unknown"'} for aggregation round ${this.round} and communication round ${this.communicationRound}`);
break
case AggregationStep.UPDATE:
if (from === undefined) {
return
}
debug(`updating contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`)
debug(`Updating contribution from node ${from} for aggregation round ${this.round} and communication round ${this.communicationRound}`)
break
case AggregationStep.AGGREGATE:
debug(`buffer full, aggregating weights for round (${this.communicationRound}, ${this.round})`)
debug(`Buffer is full. Aggregating weights for round aggregation round ${this.round} and communication round ${this.communicationRound}`)
break
default: {
const _: never = step
Expand All @@ -134,6 +154,14 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
return false
}

/**
* Remove a node's id from the set of active nodes.
* @param nodeId The node to be removed
*/
removeNode (nodeId: client.NodeID): void{
this._nodes = this._nodes.delete(nodeId)
}

/**
* Overwrites the current set of active nodes with the given one. A node represents
* an active neighbor peer/client within the network, whom we are communicating with
Expand Down
17 changes: 12 additions & 5 deletions discojs/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ThresholdType = 'relative' | 'absolute'
export class MeanAggregator extends Aggregator<WeightsContainer> {
readonly #threshold: number;
readonly #thresholdType: ThresholdType;
#minNbOfParticipants: number | undefined;

/**
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
Expand Down Expand Up @@ -77,6 +78,11 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {

/** Checks whether the contributions buffer is full. */
override isFull(): boolean {
// Make sure that we are over the minimum number of participants
// if specified
if (this.#minNbOfParticipants !== undefined &&
this.nodes.size < this.#minNbOfParticipants) return false

const thresholdValue =
this.#thresholdType == 'relative'
? this.#threshold * this.nodes.size
Expand All @@ -85,6 +91,10 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
}

set minNbOfParticipants(minNbOfParticipants: number) {
this.#minNbOfParticipants = minNbOfParticipants
}

override add(
nodeId: client.NodeID,
contribution: WeightsContainer,
Expand All @@ -94,11 +104,7 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
if (currentContributions !== 0)
throw new Error("only a single communication round");

if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
if (!this.nodes.has(nodeId)) debug(`contribution rejected because node ${nodeId} is not registered`);
if (!this.isWithinRoundCutoff(round)) debug(`contribution rejected because round ${round} is not within cutoff`);
return false;
}
if (!this.isValidContribution(nodeId, round)) return false

this.log(
this.contributions.hasIn([0, nodeId])
Expand All @@ -122,6 +128,7 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
this.log(AggregationStep.AGGREGATE);

const result = aggregation.avg(currentContributions.values());
this._round = this._round++
JulienVig marked this conversation as resolved.
Show resolved Hide resolved
this.emit('aggregation', result);
}

Expand Down
3 changes: 1 addition & 2 deletions discojs/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
throw new Error("requires communication round to be 0 or 1");
}

if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
return false;
if (!this.isValidContribution(nodeId, round)) return false

this.log(
this.contributions.hasIn([communicationRound, nodeId])
Expand Down
45 changes: 17 additions & 28 deletions discojs/src/client/base.ts → discojs/src/client/client.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import axios from 'axios'

import type { Model, Task, WeightsContainer } from '../index.js'
import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js'
import { serialization } from '../index.js'
import type { NodeID } from './types.js'
import type { EventConnection } from './event_connection.js'
import type { Aggregator } from '../aggregator/index.js'
import { EventEmitter } from '../utils/event_emitter.js'

/**
* Main, abstract, class representing a Disco client in a network, which handles
* communication with other nodes, be it peers or a server.
*/
export abstract class Base {
export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
/**
* Own ID provided by the network's server.
*/
Expand All @@ -25,24 +26,21 @@ export abstract class Base {
protected aggregationResult?: Promise<WeightsContainer>

constructor (
/**
* The network server's URL to connect to.
*/
public readonly url: URL,
/**
* The client's corresponding task.
*/
public readonly task: Task,
/**
* The client's aggregator.
*/
public readonly aggregator: Aggregator
) {}
public readonly url: URL, // The network server's URL to connect to
public readonly task: Task, // The client's corresponding task
public readonly aggregator: Aggregator,
) {
super()
}

/**
* Handles the connection process from the client to any sort of network server.
* This method is overriden by the federated and decentralized clients
* By default, it fetches and returns the server's base model
*/
async connect (): Promise<void> {}
async connect(): Promise<Model> {
return this.getLatestModel()
}

/**
* Handles the disconnection process of the client from any sort of network server.
Expand All @@ -66,24 +64,15 @@ export abstract class Base {

/**
* Communication callback called at the beginning of every training round.
* @param _weights The most recent local weight updates
* @param _round The current training round
*/
async onRoundBeginCommunication(
_weights: WeightsContainer,
_round: number,
): Promise<void> {}
abstract onRoundBeginCommunication(): Promise<void>;

/**
* Communication callback called the end of every training round.
* @param _weights The most recent local weight updates
* @param _round The current training round
* @param weights The local weight update resulting for the current local training round
* @returns aggregated weights or the local weights upon error
*/
abstract onRoundEndCommunication(
_weights: WeightsContainer,
_round: number,
): Promise<WeightsContainer>;
abstract onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;

// Number of contributors to a collaborative session
// If decentralized, it should be the number of peers
Expand Down
Loading