Skip to content

Commit

Permalink
fix save and load for models with gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardoleao052 committed Oct 21, 2024
1 parent 9cc399e commit 9b3e830
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
27 changes: 26 additions & 1 deletion dist/js-pytorch-browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -1913,8 +1913,33 @@ export class MSELoss extends Module {
return loss;
}
}
/**
* Saves the model to a JSON file.
* @param {Module} model - Model to be saved in JSON file.
* @param {string} file - JSON file.
*/
function save(model, file) {
const data = JSON.stringify(model);
/**
* Filters object, returning 'null' instead of 'value' for certain keys.
* @param {object} obj - Objects with keys and values that we have to filter.
* @returns {object} Filtered object.
*/
function recursiveReplacer(obj){
let result = {};
for (var x in obj) {
if (x !== "forwardKernel" && x !== "backwardKernelA" && x !== "backwardKernelB" && x !== "gpu") {
if (typeof obj[x] === 'object' && !Array.isArray(obj[x])) {
result[x] = recursiveReplacer(obj[x]);
} else {
result[x] = obj[x];
}
} else {
result[x] = null;
}
}
return result
}
const data = JSON.stringify(recursiveReplacer(model));
fs.writeFileSync(file, data);
}
function load(model, file) {
Expand Down
22 changes: 21 additions & 1 deletion src/layers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,27 @@ export class MSELoss extends Module {
* @param {string} file - JSON file.
*/
export function save(model: Module, file: string) {
const data = JSON.stringify(model);
/**
* Filters object, returning 'null' instead of 'value' for certain keys.
* @param {object} obj - Objects with keys and values that we have to filter.
* @returns {object} Filtered object.
*/
function recursiveReplacer(obj: { [key: string]: any; }): { [key: string]: any; }{
let result: { [key: string]: any; } = {};
for (var x in obj) {
if (x !== "forwardKernel" && x !== "backwardKernelA" && x !== "backwardKernelB" && x !== "gpu") {
if (typeof obj[x] === 'object' && !Array.isArray(obj[x])) {
result[x] = recursiveReplacer(obj[x]);
} else {
result[x] = obj[x];
}
} else {
result[x] = null;
}
}
return result
}
const data = JSON.stringify(recursiveReplacer(model));
fs.writeFileSync(file, data);
}

Expand Down

0 comments on commit 9b3e830

Please sign in to comment.