diff --git a/compiler/circle-interpreter-cffi-test/CMakeLists.txt b/compiler/circle-interpreter-cffi-test/CMakeLists.txt new file mode 100644 index 00000000000..f2704baaeb5 --- /dev/null +++ b/compiler/circle-interpreter-cffi-test/CMakeLists.txt @@ -0,0 +1,17 @@ +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_12_1") +set(TEST_LIST_FILE "test.lst") + +get_target_property(ARTIFACTS_PATH testDataGenerator BINARY_DIR) + +add_test( + NAME circle_interpreter_cffi_test + COMMAND ${VIRTUALENV}/bin/python infer.py + --lib_path $ + --test_list ${TEST_LIST_FILE} + --artifact_dir ${ARTIFACTS_PATH} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) diff --git a/compiler/circle-interpreter-cffi-test/README.md b/compiler/circle-interpreter-cffi-test/README.md new file mode 100644 index 00000000000..ec6f921d1be --- /dev/null +++ b/compiler/circle-interpreter-cffi-test/README.md @@ -0,0 +1,11 @@ +# circle-interpreter-cffi-test + +The `circle_interpereter_cffi` library wrapped with CFFI is designed to expose an existing `luci-interpreter` class to Python. It simplifies the integration of the class by creating a Python-compatible interface using CFFI. + +`circle-interpreter-cffi-test` ensures that the Python bindings for the C++ library correctly. Specifically, it verifies that: + +1. The CFFI-wrapped library can succesfully load the circle model. +2. Inputs passed from Python are correctly interpreted by the `luci-interpreter`. +3. The output generated by the interpreter matches the expected results. + +This test provides confidence that `luci-interpter`, when accessed through the Python interface, produces the same results as the original implementation. diff --git a/compiler/circle-interpreter-cffi-test/infer.py b/compiler/circle-interpreter-cffi-test/infer.py new file mode 100644 index 00000000000..b4e2aee8bc2 --- /dev/null +++ b/compiler/circle-interpreter-cffi-test/infer.py @@ -0,0 +1,112 @@ +import argparse +import h5py +import numpy as np +from pathlib import Path +import re +import sys + +############ Managing paths for the artifacts required by the test. + + +def extract_test_args(s): + p = re.compile('eval\\((.*)\\)') + result = p.search(s) + return result.group(1) + + +parser = argparse.ArgumentParser() +parser.add_argument('--lib_path', type=str, required=True) +parser.add_argument('--test_list', type=str, required=True) +parser.add_argument('--artifact_dir', type=str, required=True) +args = parser.parse_args() + +with open(args.test_list) as f: + contents = [line.rstrip() for line in f] +# remove newline and comments. +eval_lines = [line for line in contents if line.startswith('eval(')] +test_args = [extract_test_args(line) for line in eval_lines] +test_models = [Path(args.artifact_dir) / f'{arg}.circle' for arg in test_args] +input_data = [ + Path(args.artifact_dir) / f'{arg}.opt/metadata/tc/input.h5' for arg in test_args +] +expected_output_data = [ + Path(args.artifact_dir) / f'{arg}.opt/metadata/tc/expected.h5' for arg in test_args +] + +############ CFFI test + +from cffi import FFI + +ffi = FFI() +ffi.cdef(""" + typedef struct InterpreterWrapper InterpreterWrapper; + + const char *get_last_error(void); + void clear_last_error(void); + InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size); + void Interpreter_delete(InterpreterWrapper *intp); + void Interpreter_interpret(InterpreterWrapper *intp); + void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size); + void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size); +""") +C = ffi.dlopen(args.lib_path) + + +def check_for_errors(): + error_message = ffi.string(C.get_last_error()).decode('utf-8') + if error_message: + C.clear_last_error() + raise RuntimeError(f'C++ Exception: {error_message}') + + +def error_checked(func): + """ + Decorator to wrap functions with error checking. + """ + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + check_for_errors() + return result + + return wrapper + + +Interpreter_new = error_checked(C.Interpreter_new) +Interpreter_delete = error_checked(C.Interpreter_delete) +Interpreter_interpret = error_checked(C.Interpreter_interpret) +Interpreter_writeInputTensor = error_checked(C.Interpreter_writeInputTensor) +Interpreter_readOutputTensor = error_checked(C.Interpreter_readOutputTensor) + +for idx, model_path in enumerate(test_models): + with open(model_path, "rb") as f: + model_data = ffi.from_buffer(bytearray(f.read())) + + try: + intp = Interpreter_new(model_data, len(model_data)) + + # Set inputs + h5 = h5py.File(input_data[idx]) + input_values = h5.get('value') + input_num = len(input_values) + for input_idx in range(input_num): + arr = np.array(input_values.get(str(input_idx))) + c_arr = ffi.from_buffer(arr) + Interpreter_writeInputTensor(intp, input_idx, c_arr, arr.nbytes) + # Do inference + Interpreter_interpret(intp) + # Check outputs + h5 = h5py.File(expected_output_data[idx]) + output_values = h5.get('value') + output_num = len(output_values) + for output_idx in range(output_num): + arr = np.array(output_values.get(str(output_idx))) + result = np.empty(arr.shape, dtype=arr.dtype) + Interpreter_readOutputTensor(intp, output_idx, ffi.from_buffer(result), + arr.nbytes) + if not np.allclose(result, arr): + raise RuntimeError("Wrong outputs") + + Interpreter_delete(intp) + except RuntimeError as e: + print(e) + sys.exit(-1) diff --git a/compiler/circle-interpreter-cffi-test/requires.cmake b/compiler/circle-interpreter-cffi-test/requires.cmake new file mode 100644 index 00000000000..8d8585b470d --- /dev/null +++ b/compiler/circle-interpreter-cffi-test/requires.cmake @@ -0,0 +1,2 @@ +require("common-artifacts") +require("circle-interpreter") diff --git a/compiler/circle-interpreter-cffi-test/test.lst b/compiler/circle-interpreter-cffi-test/test.lst new file mode 100644 index 00000000000..97ec610ad5f --- /dev/null +++ b/compiler/circle-interpreter-cffi-test/test.lst @@ -0,0 +1,17 @@ +eval(Add_000) +eval(Add_U8_000) +eval(AveragePool2D_000) +eval(Concatenation_000) +eval(Conv2D_000) +eval(Conv2D_001) +eval(Conv2D_002) +eval(DepthwiseConv2D_000) +eval(FullyConnected_000) +eval(FullyConnected_001) +eval(MaxPool2D_000) +eval(Mul_000) +eval(Pad_000) +eval(Reshape_000) +eval(Reshape_001) +eval(Reshape_002) +eval(Softmax_000)