Skip to content

Commit

Permalink
Stop leaking memory if sampling fails
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed May 16, 2024
1 parent 92c4cca commit 141d768
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions gui/src/app/tinystan/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ export default class StanModel {
throw new Error(err_msg);
}

private encodeInits(inits: string | StanVariableInputs | string[] | StanVariableInputs[]): cstr {
private encodeInits(
inits: string | StanVariableInputs | string[] | StanVariableInputs[],
): cstr {
if (Array.isArray(inits)) {
return this.encodeString(
inits.map(i => string_safe_jsonify(i)).join(this.sep),
Expand All @@ -173,7 +175,7 @@ export default class StanModel {
private withModel<T>(
data: string | StanVariableInputs,
seed: number,
f: (model: model_ptr) => T,
f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T,
): T {
const data_ptr = this.encodeString(string_safe_jsonify(data));
const err_ptr = this.m._malloc(4);
Expand All @@ -184,9 +186,14 @@ export default class StanModel {
this.handleError(err_ptr);
}
this.m._free(err_ptr);

const ptrs: (ptr | cstr)[] = [];
const deferredFree = (p: ptr | cstr) => ptrs.push(p);

try {
return f(model);
return f(model, deferredFree);
} finally {
ptrs.forEach(p => this.m._free(p));
this.m._tinystan_destroy_model(model);
}
}
Expand Down Expand Up @@ -234,7 +241,7 @@ export default class StanModel {
seed_ = Math.floor(Math.random() * (2 ^ 32));
}

return this.withModel(data, seed_, model => {
return this.withModel(data, seed_, (model, deferredFree) => {
// Get the parameter names
const rawParamNames = this.m.UTF8ToString(
this.m._tinystan_model_param_names(model),
Expand Down Expand Up @@ -265,17 +272,23 @@ export default class StanModel {
num_chains * free_params * Float64Array.BYTES_PER_ELEMENT,
);
}
deferredFree(metric_out);

const inits_ptr = this.encodeInits(inits);
deferredFree(inits_ptr);

// Allocate memory for the output
const n_draws =
num_chains * (save_warmup ? num_samples + num_warmup : num_samples);
const n_out = n_draws * n_params;

// Allocate memory for the output
const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);
deferredFree(out_ptr);

const inits_ptr = this.encodeInits(inits);
const err_ptr = this.m._malloc(4);
deferredFree(err_ptr);

// Sample from the model
const err_ptr = this.m._malloc(4);
const result = this.m._tinystan_sample(
model,
num_chains,
Expand Down Expand Up @@ -310,8 +323,6 @@ export default class StanModel {
if (result != 0) {
this.handleError(err_ptr);
}
this.m._free(err_ptr);
this.m._free(inits_ptr);

const out_buffer = this.m.HEAPF64.subarray(
out_ptr / Float64Array.BYTES_PER_ELEMENT,
Expand All @@ -323,9 +334,6 @@ export default class StanModel {
Array.from({ length: n_draws }, (_, j) => out_buffer[i + n_params * j]),
);

// Clean up
this.m._free(out_ptr);

let metric_array: number[][] | number[][][] | null = null;

if (save_metric) {
Expand Down Expand Up @@ -361,7 +369,6 @@ export default class StanModel {
);
}
}
this.m._free(metric_out);

return { paramNames, draws, metric: metric_array };
});
Expand Down

0 comments on commit 141d768

Please sign in to comment.