-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[fm-equalize-value-py-test] Introduce tests (#14001)
This commit introduces fm-equalize-value-py-test to ONE. ONE-DCO-1.0-Signed-off-by: seongwoo <[email protected]>
- Loading branch information
Showing
6 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
if(NOT ENABLE_TEST) | ||
return() | ||
endif(NOT ENABLE_TEST) | ||
|
||
set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_12_1") | ||
set(TEST_LIST_FILE "test.lst") | ||
|
||
get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR) | ||
get_target_property(FM_EQUALIZE_BIN_PATH fm-equalize BINARY_DIR) | ||
get_target_property(FME_DETECT_BIN_PATH fme-detect BINARY_DIR) | ||
get_target_property(DALGONA_BIN_PATH dalgona BINARY_DIR) | ||
get_target_property(FME_APPLY_BIN_PATH fme-apply BINARY_DIR) | ||
set(FM_EQUALIZE_BIN "${FM_EQUALIZE_BIN_PATH}/fm-equalize") | ||
set(FME_DETECT_BIN "${FME_DETECT_BIN_PATH}/fme-detect") | ||
set(DALGONA_BIN "${DALGONA_BIN_PATH}/dalgona") | ||
set(FME_APPLY_BIN "${FME_APPLY_BIN_PATH}/fme-apply") | ||
|
||
macro(eval RECIPE) | ||
set(CIRCLE_FILE "${RECIPE}.circle") | ||
set(CIRCLE_PATH "${ARTIFACTS_BIN_PATH}/${CIRCLE_FILE}") | ||
set(OPT_CIRCLE_FILE "${RECIPE}.opt.circle") | ||
set(OPT_CIRCLE_OUTPUT_PATH "${CMAKE_CURRENT_BINARY_DIR}/${OPT_CIRCLE_FILE}") | ||
|
||
# Run circle2circle for fusing instance normalization. | ||
add_custom_command(OUTPUT ${OPT_CIRCLE_OUTPUT_PATH} | ||
COMMAND $<TARGET_FILE:circle2circle> --fuse_instnorm ${CIRCLE_PATH} ${OPT_CIRCLE_OUTPUT_PATH} | ||
DEPENDS $<TARGET_FILE:circle2circle> ${CIRCLE_PATH} | ||
COMMENT "Generate ${OPT_CIRCLE_FILE} for fusing instance normalization." | ||
) | ||
|
||
set(AFTER_CIRCLE_FILE "${RECIPE}.after.circle") | ||
set(AFETR_PATTERN_FILE "${RECIPE}.after.json") | ||
set(AFTER_CIRCLE_OUTPUT_PATH "${CMAKE_CURRENT_BINARY_DIR}/${AFTER_CIRCLE_FILE}") | ||
set(AFTER_CIRCLE_PATTERN_PATH "${CMAKE_CURRENT_BINARY_DIR}/${AFETR_PATTERN_FILE}") | ||
|
||
# Apply fm-equalize | ||
add_custom_command(OUTPUT ${AFTER_CIRCLE_OUTPUT_PATH} ${AFTER_CIRCLE_PATTERN_PATH} | ||
COMMAND ${VIRTUALENV}/bin/python ${FM_EQUALIZE_BIN} -i ${OPT_CIRCLE_OUTPUT_PATH} | ||
-o ${AFTER_CIRCLE_OUTPUT_PATH} -f ${AFTER_CIRCLE_PATTERN_PATH} | ||
--fme_detect ${FME_DETECT_BIN} --dalgona ${DALGONA_BIN} | ||
--fme_apply ${FME_APPLY_BIN} | ||
DEPENDS ${FM_EQUALIZE_BIN} ${OPT_CIRCLE_OUTPUT_PATH} ${FME_APPLY_BIN} ${FME_DETECT_BIN} ${DALGONA_BIN} | ||
COMMENT "Apply fm-equalize to ${OPT_CIRCLE_OUTPUT_PATH}" | ||
) | ||
|
||
# depends | ||
list(APPEND TEST_DEPS ${AFTER_CIRCLE_OUTPUT_PATH} ${AFTER_CIRCLE_PATTERN_PATH}) | ||
endmacro(eval) | ||
|
||
# Read "test.lst" | ||
include("test.lst") | ||
# Read "test.local.lst" if exists | ||
include("test.local.lst" OPTIONAL) | ||
|
||
add_custom_target(fm_equalize_value_py_test_files ALL DEPENDS ${TEST_DEPS}) | ||
add_dependencies(fm_equalize_value_py_test_files common_artifacts_deps) | ||
|
||
add_test(NAME fm_equalize_value_py_test | ||
COMMAND ${VIRTUALENV}/bin/python -m pytest -sv test_luci_eval.py | ||
--test_list ${TEST_LIST_FILE} | ||
--tflite_dir ${ARTIFACTS_BIN_PATH} | ||
--circle_dir ${CMAKE_CURRENT_BINARY_DIR} | ||
--luci_eval_driver $<TARGET_FILE:luci_eval_driver> | ||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# fm-equalize-value-py-test | ||
|
||
`fm-equalize-value-py-test` validates execution result values of original tflite model and | ||
circle model generated with _fm-equalize_. | ||
|
||
The test proceeds as follows: | ||
|
||
Step 0: Use tflite and circle file in 'common-artifacts' folder as the source model. | ||
- tflite file is used as to generate reference execution result | ||
- circle file is used as source of fm-equalize to apply | ||
|
||
Step 1: Run _fm-equalize_. | ||
- "modelfile.circle" -> fm-equalize -> "modelfile.after.circle" | ||
|
||
Step 2: Run TFLite interpreter and luci-interpreter for the source tflite and circle, respectively. | ||
(with the same input tensors filled with random values) | ||
- "modelfile.tflite" ------> TFLite interpreter -> Execution result 1 | ||
- "modelfile.after.circle" -> luci-interpreter ---> Execution result 2 | ||
|
||
Step 3: Compare the execution result 1 and 2. Test is PASSED if results are sames. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import re | ||
|
||
|
||
def extract_test_args(s): | ||
p = re.compile('eval\\((.*)\\)') | ||
result = p.search(s) | ||
return result.group(1) | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption("--test_list", action="store", help="Path to test list") | ||
parser.addoption("--tflite_dir", | ||
action="store", | ||
help="Directory including tflite file") | ||
parser.addoption("--circle_dir", | ||
action="store", | ||
help="Directory including circle file") | ||
parser.addoption("--luci_eval_driver", | ||
action="store", | ||
help="Path to luci eval driver") | ||
|
||
|
||
def pytest_generate_tests(metafunc): | ||
list_path = metafunc.config.getoption('test_list') | ||
tflite_dir = metafunc.config.getoption('tflite_dir') | ||
circle_dir = metafunc.config.getoption('circle_dir') | ||
eval_driver_path = metafunc.config.getoption('luci_eval_driver') | ||
if list_path is None: | ||
tests_default_tol = [] | ||
else: | ||
with open(list_path) as f: | ||
contents = [line.rstrip() for line in f] | ||
|
||
comment_removed = [line for line in contents if not line.startswith('#')] | ||
newline_removed = [line for line in comment_removed if line.startswith('eval(')] | ||
test_args = [extract_test_args(line) for line in newline_removed] | ||
# eval(TEST_NAME) | ||
tests_default_tol = [(arg.split()[0], tflite_dir, circle_dir, eval_driver_path) | ||
for arg in test_args] | ||
|
||
if 'test_name' in metafunc.fixturenames: | ||
metafunc.parametrize('test_name,tflite_dir,circle_dir,eval_driver_path', | ||
tests_default_tol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
require("common-artifacts") | ||
require("luci-eval-driver") | ||
require("circle2circle") | ||
require("fm-equalize") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# | ||
# Format: | ||
# eval(MODEL) | ||
# MODEL: tflite model file name in build/compiler/common-artifacts folder. | ||
# | ||
|
||
eval(Conv2D_007) | ||
eval(FullyConnected_010) | ||
eval(Net_Conv_TConv_000) | ||
eval(Net_DConv_Conv_000) | ||
eval(Net_InstNorm_Conv_000) | ||
eval(Net_Conv_Pad_000) | ||
|
||
# Values could be mismatch according to input ranges. | ||
# eval(Net_Conv_Gelu_000) | ||
# eval(Net_FullyConnected_Gelu_000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import subprocess | ||
import os | ||
|
||
|
||
def luci_eval_verify(test_name, | ||
tflite_dir, | ||
circle_dir, | ||
eval_driver, | ||
rtolf32=1e-5, | ||
atolf32=1e-5): | ||
tflite_model = os.path.join(tflite_dir, test_name + ".tflite") | ||
circle_model = os.path.join(circle_dir, test_name + ".after.circle") | ||
|
||
# NOTE reuse f32 value as int value too | ||
rtolint = int(rtolf32) | ||
atolint = int(atolf32) | ||
|
||
# Build TFLite interpreter. | ||
interpreter = tf.lite.Interpreter(tflite_model) | ||
interpreter.allocate_tensors() | ||
|
||
# Read SignatureDef and get output tensor id orders for remapping | ||
full_signatures = interpreter._get_full_signature_list() | ||
full_signatures_outputs_remap = None | ||
if full_signatures != None: | ||
signature_serving_default = full_signatures.get('serving_default', None) | ||
if signature_serving_default != None: | ||
signature_outputs = signature_serving_default['outputs'] | ||
|
||
full_signatures_outputs_remap = [] | ||
for index, (key, value) in enumerate(signature_outputs.items()): | ||
full_signatures_outputs_remap.append(value) | ||
|
||
# Generate random input data. | ||
num_inputs = len(interpreter.get_input_details()) | ||
for i in range(num_inputs): | ||
input_details = interpreter.get_input_details()[i] | ||
if input_details["dtype"] == np.float32: | ||
input_data = np.array(np.random.random_sample(input_details["shape"]), | ||
input_details["dtype"]) | ||
elif input_details["dtype"] == np.uint8: | ||
input_data = np.array(np.random.randint(0, 256, size=input_details["shape"]), | ||
input_details["dtype"]) | ||
elif input_details["dtype"] == np.int16: | ||
input_data = np.array(np.random.randint(0, 100, size=input_details["shape"]), | ||
input_details["dtype"]) | ||
elif input_details["dtype"] == np.int32: | ||
input_data = np.array(np.random.randint(0, 100, size=input_details["shape"]), | ||
input_details["dtype"]) | ||
elif input_details["dtype"] == np.int64: | ||
input_data = np.array(np.random.randint(0, 100, size=input_details["shape"]), | ||
input_details["dtype"]) | ||
elif input_details["dtype"] == np.bool_: | ||
input_data = np.array( | ||
np.random.choice(a=[True, False], size=input_details["shape"]), | ||
input_details["dtype"]) | ||
else: | ||
assert False, "Unsupported input dtype" | ||
|
||
interpreter.set_tensor(input_details["index"], input_data) | ||
input_data.tofile(circle_model + ".input" + str(i)) | ||
|
||
# Do inference | ||
interpreter.invoke() | ||
|
||
# Execute luci interpreter. | ||
subprocess.run([ | ||
eval_driver, circle_model, | ||
str(num_inputs), circle_model + ".input", circle_model + ".output" | ||
], | ||
check=True) | ||
|
||
# Compare the results. | ||
inpt_output_details = interpreter.get_output_details() | ||
for idx in range(len(inpt_output_details)): | ||
output_details = inpt_output_details[idx] | ||
output_data = np.fromfile(circle_model + ".output" + str(idx), | ||
output_details["dtype"]) | ||
shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r') | ||
output_shape = [int(i) for i in shape_file.read().split(',')] | ||
luci_output_data = np.reshape(output_data, output_shape) | ||
output_tensor = output_details["index"] | ||
if full_signatures_outputs_remap != None: | ||
output_tensor = full_signatures_outputs_remap[idx] | ||
intp_output_data = interpreter.get_tensor(output_tensor) | ||
err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model | ||
if output_details["dtype"] == np.uint8: | ||
assert np.allclose(luci_output_data, | ||
intp_output_data, | ||
rtol=rtolint, | ||
atol=atolint), err_msg | ||
elif output_details["dtype"] == np.float32: | ||
assert np.allclose(luci_output_data, | ||
intp_output_data, | ||
rtol=rtolf32, | ||
atol=atolf32), err_msg | ||
elif output_details["dtype"] == np.int64: | ||
assert np.allclose(luci_output_data, | ||
intp_output_data, | ||
rtol=rtolint, | ||
atol=atolint), err_msg | ||
elif output_details["dtype"] == np.int32: | ||
assert np.allclose(luci_output_data, | ||
intp_output_data, | ||
rtol=rtolint, | ||
atol=atolint), err_msg | ||
elif output_details["dtype"] == np.int16: | ||
assert np.allclose(luci_output_data, | ||
intp_output_data, | ||
rtol=rtolint, | ||
atol=atolint), err_msg | ||
elif output_details["dtype"] == np.bool_: | ||
assert np.allclose(luci_output_data, intp_output_data, rtol=0, | ||
atol=0), err_msg | ||
else: | ||
assert False, "Unsupported data type: " + output_details["dtype"] | ||
|
||
|
||
# arguments must be in sync with `conftest.py` | ||
def test_luci_eval(test_name: str, tflite_dir: str, circle_dir: str, | ||
eval_driver_path: str): | ||
luci_eval_verify(test_name, tflite_dir, circle_dir, eval_driver_path) |