diff --git a/gui/src/app/tinystan/index.ts b/gui/src/app/tinystan/index.ts index 053ba895..cda270e5 100644 --- a/gui/src/app/tinystan/index.ts +++ b/gui/src/app/tinystan/index.ts @@ -160,7 +160,9 @@ export default class StanModel { throw new Error(err_msg); } - private encodeInits(inits: string | StanVariableInputs | string[] | StanVariableInputs[]): cstr { + private encodeInits( + inits: string | StanVariableInputs | string[] | StanVariableInputs[], + ): cstr { if (Array.isArray(inits)) { return this.encodeString( inits.map(i => string_safe_jsonify(i)).join(this.sep), @@ -173,7 +175,7 @@ export default class StanModel { private withModel( data: string | StanVariableInputs, seed: number, - f: (model: model_ptr) => T, + f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T, ): T { const data_ptr = this.encodeString(string_safe_jsonify(data)); const err_ptr = this.m._malloc(4); @@ -184,9 +186,14 @@ export default class StanModel { this.handleError(err_ptr); } this.m._free(err_ptr); + + const ptrs: (ptr | cstr)[] = []; + const deferredFree = (p: ptr | cstr) => ptrs.push(p); + try { - return f(model); + return f(model, deferredFree); } finally { + ptrs.forEach(p => this.m._free(p)); this.m._tinystan_destroy_model(model); } } @@ -234,7 +241,7 @@ export default class StanModel { seed_ = Math.floor(Math.random() * (2 ^ 32)); } - return this.withModel(data, seed_, model => { + return this.withModel(data, seed_, (model, deferredFree) => { // Get the parameter names const rawParamNames = this.m.UTF8ToString( this.m._tinystan_model_param_names(model), @@ -265,17 +272,23 @@ export default class StanModel { num_chains * free_params * Float64Array.BYTES_PER_ELEMENT, ); } + deferredFree(metric_out); + + const inits_ptr = this.encodeInits(inits); + deferredFree(inits_ptr); - // Allocate memory for the output const n_draws = num_chains * (save_warmup ? num_samples + num_warmup : num_samples); const n_out = n_draws * n_params; + + // Allocate memory for the output const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT); + deferredFree(out_ptr); - const inits_ptr = this.encodeInits(inits); + const err_ptr = this.m._malloc(4); + deferredFree(err_ptr); // Sample from the model - const err_ptr = this.m._malloc(4); const result = this.m._tinystan_sample( model, num_chains, @@ -310,8 +323,6 @@ export default class StanModel { if (result != 0) { this.handleError(err_ptr); } - this.m._free(err_ptr); - this.m._free(inits_ptr); const out_buffer = this.m.HEAPF64.subarray( out_ptr / Float64Array.BYTES_PER_ELEMENT, @@ -323,9 +334,6 @@ export default class StanModel { Array.from({ length: n_draws }, (_, j) => out_buffer[i + n_params * j]), ); - // Clean up - this.m._free(out_ptr); - let metric_array: number[][] | number[][][] | null = null; if (save_metric) { @@ -361,7 +369,6 @@ export default class StanModel { ); } } - this.m._free(metric_out); return { paramNames, draws, metric: metric_array }; });