Skip to content

Commit

Permalink
Move tensorflow lite python calls to ai-edge-litert.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 680610668
  • Loading branch information
pak-laura authored and copybara-github committed Oct 9, 2024
1 parent 7d1be69 commit c627902
Show file tree
Hide file tree
Showing 4 changed files with 798 additions and 607 deletions.
4 changes: 3 additions & 1 deletion chirp/projects/zoo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import tensorflow.compat.v1 as tf1
import tensorflow_hub as hub

from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


def model_class_map() -> dict[str, Any]:
"""Get the mapping of model keys to classes."""
Expand Down Expand Up @@ -419,7 +421,7 @@ def from_config(cls, config: config_dict.ConfigDict) -> 'BirdNet':
with tempfile.NamedTemporaryFile() as tmpf:
model_file = epath.Path(config.model_path)
model_file.copy(tmpf.name, overwrite=True)
model = tf.lite.Interpreter(
model = tfl_interpreter.Interpreter(
tmpf.name, num_threads=config.num_tflite_threads
)
model.allocate_tensors()
Expand Down
5 changes: 3 additions & 2 deletions chirp/train_tests/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from absl.testing import absltest
from absl.testing import parameterized
from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


class FrontendTest(parameterized.TestCase):
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_tflite_stft_export(
tflite_float_model = converter.convert()

# Use the converted TFLite model.
interpreter = tf.lite.Interpreter(model_content=tflite_float_model)
interpreter = tfl_interpreter.Interpreter(model_content=tflite_float_model)
interpreter.allocate_tensors()
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
Expand Down Expand Up @@ -247,7 +248,7 @@ def test_simple_melspec(self):
tflite_float_model = converter.convert()

# Use the converted TFLite model.
interpreter = tf.lite.Interpreter(model_content=tflite_float_model)
interpreter = tfl_interpreter.Interpreter(model_content=tflite_float_model)
interpreter.allocate_tensors()
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
Expand Down
Loading

0 comments on commit c627902

Please sign in to comment.