diff --git a/dist/js-pytorch-browser.js b/dist/js-pytorch-browser.js index 5d2b6d8..d851377 100644 --- a/dist/js-pytorch-browser.js +++ b/dist/js-pytorch-browser.js @@ -1913,8 +1913,33 @@ export class MSELoss extends Module { return loss; } } +/** + * Saves the model to a JSON file. + * @param {Module} model - Model to be saved in JSON file. + * @param {string} file - JSON file. + */ function save(model, file) { - const data = JSON.stringify(model); + /** + * Filters object, returning 'null' instead of 'value' for certain keys. + * @param {object} obj - Objects with keys and values that we have to filter. + * @returns {object} Filtered object. + */ + function recursiveReplacer(obj){ + let result = {}; + for (var x in obj) { + if (x !== "forwardKernel" && x !== "backwardKernelA" && x !== "backwardKernelB" && x !== "gpu") { + if (typeof obj[x] === 'object' && !Array.isArray(obj[x])) { + result[x] = recursiveReplacer(obj[x]); + } else { + result[x] = obj[x]; + } + } else { + result[x] = null; + } + } + return result + } + const data = JSON.stringify(recursiveReplacer(model)); fs.writeFileSync(file, data); } function load(model, file) { diff --git a/src/layers.ts b/src/layers.ts index 1ed5361..6c1f9c4 100644 --- a/src/layers.ts +++ b/src/layers.ts @@ -580,7 +580,27 @@ export class MSELoss extends Module { * @param {string} file - JSON file. */ export function save(model: Module, file: string) { - const data = JSON.stringify(model); + /** + * Filters object, returning 'null' instead of 'value' for certain keys. + * @param {object} obj - Objects with keys and values that we have to filter. + * @returns {object} Filtered object. + */ + function recursiveReplacer(obj: { [key: string]: any; }): { [key: string]: any; }{ + let result: { [key: string]: any; } = {}; + for (var x in obj) { + if (x !== "forwardKernel" && x !== "backwardKernelA" && x !== "backwardKernelB" && x !== "gpu") { + if (typeof obj[x] === 'object' && !Array.isArray(obj[x])) { + result[x] = recursiveReplacer(obj[x]); + } else { + result[x] = obj[x]; + } + } else { + result[x] = null; + } + } + return result + } + const data = JSON.stringify(recursiveReplacer(model)); fs.writeFileSync(file, data); }