diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts index e1294ad35..7e72bd501 100644 --- a/discojs/src/models/tfjs.ts +++ b/discojs/src/models/tfjs.ts @@ -16,7 +16,7 @@ import { EpochLogs } from './logs.js' type Serialized = [D, tf.io.ModelArtifacts]; /** TensorFlow JavaScript model with standard training */ -export class TFJS extends Model { +export class TFJS extends Model { /** Wrap the given trainable model */ constructor ( public readonly datatype: D, @@ -168,7 +168,7 @@ export class TFJS extends Model { return ret } - static async deserialize([ + static async deserialize([ datatype, artifacts, ]: Serialized): Promise> { @@ -259,23 +259,6 @@ export class TFJS extends Model { ys: tf.stack(b.map(([_, output]) => tf.tensor1d([output])).toArray()), })); } - case "text": { - // cast as typescript doesn't reduce generic type - const b = batch as Batched; - - return { - xs: tf.stack( - b.map(([line]) => tf.tensor1d(line.toArray())).toArray(), - ), - ys: tf.stack( - b - .map(([line, next]) => - tf.oneHot(line.shift().push(next).toArray(), outputSize), - ) - .toArray(), - ), - }; - } } const _: never = this.datatype; @@ -311,14 +294,6 @@ export class TFJS extends Model { ), ); } - case "text": { - // cast as typescript doesn't reduce generic type - const b = batch as Batched; - - return tf.stack( - b.map((line) => tf.tensor1d(line.toArray())).toArray(), - ) - } } const _: never = this.datatype; diff --git a/discojs/src/serialization/model.spec.ts b/discojs/src/serialization/model.spec.ts index 7e8df69b1..39b4f13e2 100644 --- a/discojs/src/serialization/model.spec.ts +++ b/discojs/src/serialization/model.spec.ts @@ -35,7 +35,9 @@ describe('serialization', () => { const decoded = await serialization.model.decode(encoded) expect(decoded).to.be.an.instanceof(models.TFJS); - expect((decoded as models.TFJS).datatype).to.equal("image") + expect((decoded as models.TFJS<"image" | "tabular">).datatype).to.equal( + "image", + ); assert.sameDeepOrderedMembers( await getRawWeights(model), await getRawWeights(decoded) diff --git a/discojs/src/serialization/model.ts b/discojs/src/serialization/model.ts index 6a120bd6f..e7f63fbc3 100644 --- a/discojs/src/serialization/model.ts +++ b/discojs/src/serialization/model.ts @@ -61,11 +61,10 @@ export async function decode(encoded: unknown): Promise> { ); const [rawDatatype, rawModel] = raw.slice(1) as unknown[]; - let datatype: DataType; + let datatype; switch (rawDatatype) { case "image": case "tabular": - case "text": datatype = rawDatatype; break; default: diff --git a/webapp/src/components/task_creation_form/TaskForm.vue b/webapp/src/components/task_creation_form/TaskForm.vue index 1368efb96..30332b09d 100644 --- a/webapp/src/components/task_creation_form/TaskForm.vue +++ b/webapp/src/components/task_creation_form/TaskForm.vue @@ -279,10 +279,20 @@ const onSubmit = async (rawTask: any): Promise => { let model try { - model = new models.TFJS( - task.trainingInformation.dataType, - await tf.loadLayersModel(tf.io.browserFiles(modelFiles.value.toArray())), - ); + switch (task.trainingInformation.dataType) { + case "image": + case "tabular": + model = new models.TFJS( + task.trainingInformation.dataType, + await tf.loadLayersModel( + tf.io.browserFiles(modelFiles.value.toArray()), + ), + ); + break; + case "text": + toaster.error("Currently no support of TFJS text model"); + return; + } } catch (e) { debug("while loading model:%o", e); toaster.error('Model loading failed');