Skip to content

Commit

Permalink
Merge pull request #17 from flatironinstitute/fix/memory-leaks
Browse files Browse the repository at this point in the history
Stop leaking memory if sampling fails
  • Loading branch information
WardBrian authored May 16, 2024
2 parents 92c4cca + f80f0c7 commit 70e820c
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions gui/src/app/tinystan/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ export default class StanModel {
throw new Error(err_msg);
}

private encodeInits(inits: string | StanVariableInputs | string[] | StanVariableInputs[]): cstr {
private encodeInits(
inits: string | StanVariableInputs | string[] | StanVariableInputs[],
): cstr {
if (Array.isArray(inits)) {
return this.encodeString(
inits.map(i => string_safe_jsonify(i)).join(this.sep),
Expand All @@ -170,10 +172,22 @@ export default class StanModel {
}
}

/**
* withModel serves as something akin to a context manager in
* Python. It accepts the arguments needed to construct a model
* (data and seed) and a callback.
*
* The callback takes in the model and a deferredFree function.
* The memory for the allocated model and any pointers which are "registered"
* by calling deferredFree will be cleaned up when the callback completes,
* regardless of if this is a normal return or an exception.
*
* The result of the callback is then returned or re-thrown.
*/
private withModel<T>(
data: string | StanVariableInputs,
seed: number,
f: (model: model_ptr) => T,
f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T,
): T {
const data_ptr = this.encodeString(string_safe_jsonify(data));
const err_ptr = this.m._malloc(4);
Expand All @@ -184,9 +198,14 @@ export default class StanModel {
this.handleError(err_ptr);
}
this.m._free(err_ptr);

const ptrs: (ptr | cstr)[] = [];
const deferredFree = (p: ptr | cstr) => ptrs.push(p);

try {
return f(model);
return f(model, deferredFree);
} finally {
ptrs.forEach(p => this.m._free(p));
this.m._tinystan_destroy_model(model);
}
}
Expand Down Expand Up @@ -234,7 +253,7 @@ export default class StanModel {
seed_ = Math.floor(Math.random() * (2 ^ 32));
}

return this.withModel(data, seed_, model => {
return this.withModel(data, seed_, (model, deferredFree) => {
// Get the parameter names
const rawParamNames = this.m.UTF8ToString(
this.m._tinystan_model_param_names(model),
Expand Down Expand Up @@ -265,17 +284,23 @@ export default class StanModel {
num_chains * free_params * Float64Array.BYTES_PER_ELEMENT,
);
}
deferredFree(metric_out);

const inits_ptr = this.encodeInits(inits);
deferredFree(inits_ptr);

// Allocate memory for the output
const n_draws =
num_chains * (save_warmup ? num_samples + num_warmup : num_samples);
const n_out = n_draws * n_params;

// Allocate memory for the output
const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);
deferredFree(out_ptr);

const inits_ptr = this.encodeInits(inits);
const err_ptr = this.m._malloc(4);
deferredFree(err_ptr);

// Sample from the model
const err_ptr = this.m._malloc(4);
const result = this.m._tinystan_sample(
model,
num_chains,
Expand Down Expand Up @@ -310,8 +335,6 @@ export default class StanModel {
if (result != 0) {
this.handleError(err_ptr);
}
this.m._free(err_ptr);
this.m._free(inits_ptr);

const out_buffer = this.m.HEAPF64.subarray(
out_ptr / Float64Array.BYTES_PER_ELEMENT,
Expand All @@ -323,9 +346,6 @@ export default class StanModel {
Array.from({ length: n_draws }, (_, j) => out_buffer[i + n_params * j]),
);

// Clean up
this.m._free(out_ptr);

let metric_array: number[][] | number[][][] | null = null;

if (save_metric) {
Expand Down Expand Up @@ -361,7 +381,6 @@ export default class StanModel {
);
}
}
this.m._free(metric_out);

return { paramNames, draws, metric: metric_array };
});
Expand Down

0 comments on commit 70e820c

Please sign in to comment.