Skip to content

Commit

Permalink
discojs/model/tfjs: drop text support
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Oct 4, 2024
1 parent 9f10e3e commit 0ac6c29
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 34 deletions.
29 changes: 2 additions & 27 deletions discojs/src/models/tfjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { EpochLogs } from './logs.js'
type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];

/** TensorFlow JavaScript model with standard training */
export class TFJS<D extends DataType> extends Model<D> {
export class TFJS<D extends "image" | "tabular"> extends Model<D> {
/** Wrap the given trainable model */
constructor (
public readonly datatype: D,
Expand Down Expand Up @@ -168,7 +168,7 @@ export class TFJS<D extends DataType> extends Model<D> {
return ret
}

static async deserialize<D extends DataType>([
static async deserialize<D extends "image" | "tabular">([
datatype,
artifacts,
]: Serialized<D>): Promise<TFJS<D>> {
Expand Down Expand Up @@ -259,23 +259,6 @@ export class TFJS<D extends DataType> extends Model<D> {
ys: tf.stack(b.map(([_, output]) => tf.tensor1d([output])).toArray()),
}));
}
case "text": {
// cast as typescript doesn't reduce generic type
const b = batch as Batched<ModelEncoded["text"]>;

return {
xs: tf.stack(
b.map(([line]) => tf.tensor1d(line.toArray())).toArray(),
),
ys: tf.stack(
b
.map(([line, next]) =>
tf.oneHot(line.shift().push(next).toArray(), outputSize),
)
.toArray(),
),
};
}
}

const _: never = this.datatype;
Expand Down Expand Up @@ -311,14 +294,6 @@ export class TFJS<D extends DataType> extends Model<D> {
),
);
}
case "text": {
// cast as typescript doesn't reduce generic type
const b = batch as Batched<ModelEncoded["text"][0]>;

return tf.stack(
b.map((line) => tf.tensor1d(line.toArray())).toArray(),
)
}
}

const _: never = this.datatype;
Expand Down
4 changes: 3 additions & 1 deletion discojs/src/serialization/model.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ describe('serialization', () => {
const decoded = await serialization.model.decode(encoded)

expect(decoded).to.be.an.instanceof(models.TFJS);
expect((decoded as models.TFJS<DataType>).datatype).to.equal("image")
expect((decoded as models.TFJS<"image" | "tabular">).datatype).to.equal(
"image",
);
assert.sameDeepOrderedMembers(
await getRawWeights(model),
await getRawWeights(decoded)
Expand Down
3 changes: 1 addition & 2 deletions discojs/src/serialization/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ export async function decode(encoded: unknown): Promise<Model<DataType>> {
);
const [rawDatatype, rawModel] = raw.slice(1) as unknown[];

let datatype: DataType;
let datatype;
switch (rawDatatype) {
case "image":
case "tabular":
case "text":
datatype = rawDatatype;
break;
default:
Expand Down
18 changes: 14 additions & 4 deletions webapp/src/components/task_creation_form/TaskForm.vue
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,20 @@ const onSubmit = async (rawTask: any): Promise<void> => {
let model
try {
model = new models.TFJS(
task.trainingInformation.dataType,
await tf.loadLayersModel(tf.io.browserFiles(modelFiles.value.toArray())),
);
switch (task.trainingInformation.dataType) {
case "image":
case "tabular":
model = new models.TFJS(
task.trainingInformation.dataType,
await tf.loadLayersModel(
tf.io.browserFiles(modelFiles.value.toArray()),
),
);
break;
case "text":
toaster.error("Currently no support of TFJS text model");
return;
}
} catch (e) {
debug("while loading model:%o", e);
toaster.error('Model loading failed');
Expand Down

0 comments on commit 0ac6c29

Please sign in to comment.