Skip to content

Commit

Permalink
cache signature runner
Browse files Browse the repository at this point in the history
  • Loading branch information
duy-maimanh committed Nov 2, 2023
1 parent bddb8a9 commit 58c7e72
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions lib/src/interpreter.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import 'tensor.dart';
/// TensorFlowLite interpreter for running inference on a model.
class Interpreter {
final Pointer<TfLiteInterpreter> _interpreter;
final Map<String, Pointer<TfLiteSignatureRunner>> _signatureRunners = {};
bool _deleted = false;
bool _allocated = false;
int _lastNativeInferenceDurationMicroSeconds = 0;
Expand Down Expand Up @@ -327,6 +328,10 @@ class Interpreter {
void resetVariableTensors() {
checkState(_deleted,
message: 'Should not access delegate after it has been closed.');
// loop through signatureRunners and delete them
_signatureRunners.forEach((key, value) {
_deleteSignatureRunner(value);
});
tfliteBinding.TfLiteInterpreterResetVariableTensors(_interpreter);
}

Expand Down Expand Up @@ -413,12 +418,12 @@ class Interpreter {

final Pointer<TfLiteSignatureRunner> signatureRunner;

try {
// TODO: Should we cache the signature runner?
// check if signature key exists
if (!_signatureRunners.containsKey(signatureKey)) {
signatureRunner = getSignatureRunner(signatureKey);
_allocated = false;
} catch (e) {
throw ArgumentError('Input error: Signature key is not valid.');
_signatureRunners[signatureKey] = signatureRunner;
} else {
signatureRunner = _signatureRunners[signatureKey]!;
}

inputs.forEach((key, value) {
Expand Down Expand Up @@ -454,10 +459,17 @@ class Interpreter {
}

/// Delete signature runner. Should be called before interpreter is deleted.
void deleteSignatureRunner(Pointer<TfLiteSignatureRunner> signatureRunner) {
void _deleteSignatureRunner(Pointer<TfLiteSignatureRunner> signatureRunner) {
tfliteBinding.TfLiteSignatureRunnerDelete(signatureRunner);
}

void deleteSignatureRunner(String signatureKey) {
if (_signatureRunners.containsKey(signatureKey)) {
_deleteSignatureRunner(_signatureRunners[signatureKey]!);
_signatureRunners.remove(signatureKey);
}
}

/// Returns the address to the interpreter
int get address => _interpreter.address;

Expand Down

0 comments on commit 58c7e72

Please sign in to comment.