From 6b26aa4879cca4e450e565734a2fb4fb315f3da9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 00:37:16 +0000 Subject: [PATCH 01/18] Test IR roundtrip with model zoo --- tools/ir/model_zoo_test/model_zoo_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tools/ir/model_zoo_test/model_zoo_test.py diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py new file mode 100644 index 000000000..e69de29bb From d5108244612f675a707839e36e863731453cce94 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 00:47:18 +0000 Subject: [PATCH 02/18] Create test script --- tools/ir/model_zoo_test/model_zoo_test.py | 68 +++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index e69de29bb..dcdde0968 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -0,0 +1,68 @@ +"""Test IR roundtrip with ONNX model zoo.""" + +from __future__ import annotations + +import gc +import sys +import tempfile +import time + +import onnx +from onnx import hub + +import onnxscript.testing +from onnxscript import ir + + +def test_model(model_info: hub.ModelInfo) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + hub.set_dir(temp_dir) + model_name = model_info.model + model = hub.load(model_name) + assert model is not None + onnx.checker.check_model(model) + # Fix the missing graph name of some test models + model.graph.name = "main_graph" + + # Profile the serialization and deserialization process + ir_model = ir.serde.deserialize_model(model) + serialized = ir.serde.serialize_model(ir_model) + onnxscript.testing.assert_onnx_proto_equal(serialized, model) + onnx.checker.check_model(serialized) + + +def main(): + model_list = hub.list_models() + print(f"=== Testing IR on {len(model_list)} models ===") + + # run checker on each model + failed_models = [] + failed_messages = [] + for model_info in model_list: + start = time.time() + model_name = model_info.model + model_path = model_info.model_path + print(f"-----------------Testing: {model_name} @ {model_path}-----------------") + try: + test_model(model_info) + print(f"[PASS]: {model_name} roundtrip test passed.") + except Exception as e: + print(f"[FAIL]: {e}") + failed_models.append(model_name) + failed_messages.append((model_name, e)) + end = time.time() + print(f"--------------Time used: {end - start} secs-------------") + # enable gc collection to prevent MemoryError by loading too many large models + gc.collect() + + if len(failed_models) == 0: + print(f"{len(model_list)} models have been checked.") + else: + print(f"In all {len(model_list)} models, {len(failed_models)} models failed") + for model_name, error in failed_messages: + print(f"{model_name} failed because: {error}") + sys.exit(1) + + +if __name__ == "__main__": + main() From 80f3fd0ab2bf71b05690c39ba15339c5fee58f95 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 00:51:14 +0000 Subject: [PATCH 03/18] format --- tools/ir/model_zoo_test/model_zoo_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index dcdde0968..dd4a21adf 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -42,7 +42,7 @@ def main(): start = time.time() model_name = model_info.model model_path = model_info.model_path - print(f"-----------------Testing: {model_name} @ {model_path}-----------------") + print(f"----Testing: {model_name} @ {model_path}----") try: test_model(model_info) print(f"[PASS]: {model_name} roundtrip test passed.") From f73358bffc601e9b00c618f64cbd99738eb5ea83 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:02:41 +0000 Subject: [PATCH 04/18] multiprocess --- tools/ir/model_zoo_test/model_zoo_test.py | 66 +++++++++++++++++------ 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index dd4a21adf..f90ca3429 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -2,7 +2,9 @@ from __future__ import annotations +import argparse import gc +import multiprocessing.pool import sys import tempfile import time @@ -14,7 +16,7 @@ from onnxscript import ir -def test_model(model_info: hub.ModelInfo) -> None: +def test_model(model_info: hub.ModelInfo) -> float: with tempfile.TemporaryDirectory() as temp_dir: hub.set_dir(temp_dir) model_name = model_info.model @@ -25,36 +27,66 @@ def test_model(model_info: hub.ModelInfo) -> None: model.graph.name = "main_graph" # Profile the serialization and deserialization process + start = time.time() ir_model = ir.serde.deserialize_model(model) serialized = ir.serde.serialize_model(ir_model) + end = time.time() onnxscript.testing.assert_onnx_proto_equal(serialized, model) onnx.checker.check_model(serialized) + return end - start + + +def run_one_test(model_info: hub.ModelInfo) -> tuple[str, Exception | None]: + start = time.time() + model_name = model_info.model + model_path = model_info.model_path + message = f"----Testing: {model_name} @ {model_path}----" + try: + time_passed = test_model(model_info) + message += green(f"\n[PASS]: {model_name} roundtrip test passed.") + except Exception as e: + time_passed = -1 + message += red(f"\n[FAIL]: {e}") + else: + e = None + end = time.time() + message += f"\n-------Time used: {end - start} secs, roundtrip: {time_passed} secs -------" + print(message, flush=True) + # enable gc collection to prevent MemoryError by loading too many large models + gc.collect() + return model_name, e + + +def green(text: str) -> str: + return f"\033[32m{text}\033[0m" + + +def red(text: str) -> str: + return f"\033[31m{text}\033[0m" def main(): + parser = argparse.ArgumentParser(description="Test IR roundtrip with ONNX model zoo.") + parser.add_argument( + "--jobs", + type=int, + default=1, + help="Number of parallel jobs to run. Default is 1.", + ) + args = parser.parse_args() + model_list = hub.list_models() print(f"=== Testing IR on {len(model_list)} models ===") # run checker on each model failed_models = [] failed_messages = [] - for model_info in model_list: - start = time.time() - model_name = model_info.model - model_path = model_info.model_path - print(f"----Testing: {model_name} @ {model_path}----") - try: - test_model(model_info) - print(f"[PASS]: {model_name} roundtrip test passed.") - except Exception as e: - print(f"[FAIL]: {e}") + # Use multi-threading to speed up the testing process + results = multiprocessing.pool.ThreadPool(args.jobs).map(run_one_test, model_list) + for model_name, error in results: + if error is not None: failed_models.append(model_name) - failed_messages.append((model_name, e)) - end = time.time() - print(f"--------------Time used: {end - start} secs-------------") - # enable gc collection to prevent MemoryError by loading too many large models - gc.collect() - + failed_messages.append((model_name, error)) if len(failed_models) == 0: print(f"{len(model_list)} models have been checked.") else: From 0ea28241f6eb0a36db932fa2e36f392e10e8aab3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:04:43 +0000 Subject: [PATCH 05/18] tempdir --- tools/ir/model_zoo_test/model_zoo_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index f90ca3429..ed7ba6fee 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -17,10 +17,8 @@ def test_model(model_info: hub.ModelInfo) -> float: - with tempfile.TemporaryDirectory() as temp_dir: - hub.set_dir(temp_dir) - model_name = model_info.model - model = hub.load(model_name) + model_name = model_info.model + model = hub.load(model_name) assert model is not None onnx.checker.check_model(model) # Fix the missing graph name of some test models @@ -82,7 +80,9 @@ def main(): failed_models = [] failed_messages = [] # Use multi-threading to speed up the testing process - results = multiprocessing.pool.ThreadPool(args.jobs).map(run_one_test, model_list) + with tempfile.TemporaryDirectory() as temp_dir: + hub.set_dir(temp_dir) + results = multiprocessing.pool.ThreadPool(args.jobs).map(run_one_test, model_list) for model_name, error in results: if error is not None: failed_models.append(model_name) From 6ff330bd76afe1611bb697387eb8eb2d5a1311a9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:05:49 +0000 Subject: [PATCH 06/18] stdout --- tools/ir/model_zoo_test/model_zoo_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index ed7ba6fee..c67170051 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse +import contextlib import gc import multiprocessing.pool import sys @@ -18,7 +19,8 @@ def test_model(model_info: hub.ModelInfo) -> float: model_name = model_info.model - model = hub.load(model_name) + with contextlib.redirect_stdout(None): + model = hub.load(model_name) assert model is not None onnx.checker.check_model(model) # Fix the missing graph name of some test models From 8b7f852f5b9e8edbfd7ba2055a8ed40cd15c28c4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:10:58 +0000 Subject: [PATCH 07/18] more error ,sg --- tools/ir/model_zoo_test/model_zoo_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index c67170051..a63065175 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -9,6 +9,7 @@ import sys import tempfile import time +import traceback import onnx from onnx import hub @@ -44,11 +45,13 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, Exception | None]: try: time_passed = test_model(model_info) message += green(f"\n[PASS]: {model_name} roundtrip test passed.") - except Exception as e: + except Exception as e: # noqa: F841 time_passed = -1 - message += red(f"\n[FAIL]: {e}") + error = traceback.format_exc() + message += red(f"\n[FAIL]: {error}") else: e = None + error = None end = time.time() message += f"\n-------Time used: {end - start} secs, roundtrip: {time_passed} secs -------" print(message, flush=True) @@ -81,10 +84,10 @@ def main(): # run checker on each model failed_models = [] failed_messages = [] - # Use multi-threading to speed up the testing process + # Use multi-processing to speed up the testing process with tempfile.TemporaryDirectory() as temp_dir: hub.set_dir(temp_dir) - results = multiprocessing.pool.ThreadPool(args.jobs).map(run_one_test, model_list) + results = multiprocessing.pool.Pool(args.jobs).map(run_one_test, model_list) for model_name, error in results: if error is not None: failed_models.append(model_name) From 16afc3cca60d18984eade74d951b73708294759a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:12:19 +0000 Subject: [PATCH 08/18] Usage --- tools/ir/model_zoo_test/model_zoo_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index a63065175..89a716417 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -1,4 +1,8 @@ -"""Test IR roundtrip with ONNX model zoo.""" +"""Test IR roundtrip with ONNX model zoo. + +Usage: + python model_zoo_test.py --jobs 8 +""" from __future__ import annotations From 4a4f89926bd46b38a388267be3e8831986759698 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:13:57 +0000 Subject: [PATCH 09/18] fix --- tools/ir/model_zoo_test/model_zoo_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 89a716417..2fdaa8fc0 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -49,19 +49,19 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, Exception | None]: try: time_passed = test_model(model_info) message += green(f"\n[PASS]: {model_name} roundtrip test passed.") - except Exception as e: # noqa: F841 + except Exception as e: time_passed = -1 - error = traceback.format_exc() - message += red(f"\n[FAIL]: {error}") + stack_trace = traceback.format_exc() + message += red(f"\n[FAIL]: {stack_trace}") + error = e else: - e = None error = None end = time.time() message += f"\n-------Time used: {end - start} secs, roundtrip: {time_passed} secs -------" print(message, flush=True) # enable gc collection to prevent MemoryError by loading too many large models gc.collect() - return model_name, e + return model_name, error def green(text: str) -> str: From b467698b664d7726743a591d399728aad2de49fa Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:17:46 +0000 Subject: [PATCH 10/18] W0718 --- tools/ir/model_zoo_test/model_zoo_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 2fdaa8fc0..fc582776c 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -49,7 +49,7 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, Exception | None]: try: time_passed = test_model(model_info) message += green(f"\n[PASS]: {model_name} roundtrip test passed.") - except Exception as e: + except Exception as e: # noqa: W0718 time_passed = -1 stack_trace = traceback.format_exc() message += red(f"\n[FAIL]: {stack_trace}") From 50b9b8985d0226563c4ab07d9b71b318a1d9b7a3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:19:23 +0000 Subject: [PATCH 11/18] format --- tools/ir/model_zoo_test/model_zoo_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index fc582776c..8b198babc 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -97,11 +97,11 @@ def main(): failed_models.append(model_name) failed_messages.append((model_name, error)) if len(failed_models) == 0: - print(f"{len(model_list)} models have been checked.") + print(green(f"{len(model_list)} models have been checked.")) else: - print(f"In all {len(model_list)} models, {len(failed_models)} models failed") + print(red(f"In all {len(model_list)} models, {len(failed_models)} models failed")) for model_name, error in failed_messages: - print(f"{model_name} failed because: {error}") + print(f"{red(model_name)} failed because: {error}\n") sys.exit(1) From 5f3d6accff07f3bb69c8b0f526570ee2265815d5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:22:02 +0000 Subject: [PATCH 12/18] format --- tools/ir/model_zoo_test/model_zoo_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 8b198babc..7c0a06027 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -41,7 +41,7 @@ def test_model(model_info: hub.ModelInfo) -> float: return end - start -def run_one_test(model_info: hub.ModelInfo) -> tuple[str, Exception | None]: +def run_one_test(model_info: hub.ModelInfo) -> tuple[str, str | None]: start = time.time() model_name = model_info.model model_path = model_info.model_path @@ -51,9 +51,8 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, Exception | None]: message += green(f"\n[PASS]: {model_name} roundtrip test passed.") except Exception as e: # noqa: W0718 time_passed = -1 - stack_trace = traceback.format_exc() - message += red(f"\n[FAIL]: {stack_trace}") - error = e + error = traceback.format_exc() + message += red(f"\n[FAIL]: {e}") else: error = None end = time.time() @@ -100,8 +99,9 @@ def main(): print(green(f"{len(model_list)} models have been checked.")) else: print(red(f"In all {len(model_list)} models, {len(failed_models)} models failed")) - for model_name, error in failed_messages: - print(f"{red(model_name)} failed because: {error}\n") + for i, (model_name, error) in enumerate(failed_messages): + + print(f"[{i} / {len(failed_models)}] {red(model_name)} failed because: {error}\n") sys.exit(1) From 8c6b67e758a429e582d06494015e2e734f65e927 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:24:08 +0000 Subject: [PATCH 13/18] lint --- tools/ir/model_zoo_test/model_zoo_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 7c0a06027..976bc7404 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -100,7 +100,6 @@ def main(): else: print(red(f"In all {len(model_list)} models, {len(failed_models)} models failed")) for i, (model_name, error) in enumerate(failed_messages): - print(f"[{i} / {len(failed_models)}] {red(model_name)} failed because: {error}\n") sys.exit(1) From 847eafb1b5fef53ef8a43bdf2351735dee9575b6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 01:31:01 +0000 Subject: [PATCH 14/18] # pylint: disable=broad-exception-caught --- tools/ir/model_zoo_test/model_zoo_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 976bc7404..4e7550c3a 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -49,7 +49,7 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, str | None]: try: time_passed = test_model(model_info) message += green(f"\n[PASS]: {model_name} roundtrip test passed.") - except Exception as e: # noqa: W0718 + except Exception as e: # pylint: disable=broad-exception-caught time_passed = -1 error = traceback.format_exc() message += red(f"\n[FAIL]: {e}") From defea1fe7bb073066f1eca2fddcf3462f3a43bd7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 20:27:17 +0000 Subject: [PATCH 15/18] tempfile --- tools/ir/model_zoo_test/model_zoo_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 4e7550c3a..187bcabf2 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -24,7 +24,10 @@ def test_model(model_info: hub.ModelInfo) -> float: model_name = model_info.model - with contextlib.redirect_stdout(None): + with tempfile.TemporaryDirectory() as temp_dir, contextlib.redirect_stdout(None): + # For parallel testing, this must be in a separate process because hub.set_dir + # is not thread-safe. + hub.set_dir(temp_dir) model = hub.load(model_name) assert model is not None onnx.checker.check_model(model) @@ -88,14 +91,12 @@ def main(): failed_models = [] failed_messages = [] # Use multi-processing to speed up the testing process - with tempfile.TemporaryDirectory() as temp_dir: - hub.set_dir(temp_dir) - results = multiprocessing.pool.Pool(args.jobs).map(run_one_test, model_list) + results = multiprocessing.pool.Pool(args.jobs).map(run_one_test, model_list) for model_name, error in results: if error is not None: failed_models.append(model_name) failed_messages.append((model_name, error)) - if len(failed_models) == 0: + if not failed_models: print(green(f"{len(model_list)} models have been checked.")) else: print(red(f"In all {len(model_list)} models, {len(failed_models)} models failed")) From 0239720ae157e236d6bdcaaf0a160846b8287791 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 20:45:26 +0000 Subject: [PATCH 16/18] snap --- tools/ir/model_zoo_test/model_zoo_test.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 187bcabf2..65e6c5e10 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -20,6 +20,7 @@ import onnxscript.testing from onnxscript import ir +import rich.progress def test_model(model_info: hub.ModelInfo) -> float: @@ -48,7 +49,7 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, str | None]: start = time.time() model_name = model_info.model model_path = model_info.model_path - message = f"----Testing: {model_name} @ {model_path}----" + message = f"\n----Testing: {model_name} @ {model_path}----" try: time_passed = test_model(model_info) message += green(f"\n[PASS]: {model_name} roundtrip test passed.") @@ -59,7 +60,7 @@ def run_one_test(model_info: hub.ModelInfo) -> tuple[str, str | None]: else: error = None end = time.time() - message += f"\n-------Time used: {end - start} secs, roundtrip: {time_passed} secs -------" + message += f"\n[Time]: {end - start} secs, roundtrip: {time_passed} secs" print(message, flush=True) # enable gc collection to prevent MemoryError by loading too many large models gc.collect() @@ -91,7 +92,14 @@ def main(): failed_models = [] failed_messages = [] # Use multi-processing to speed up the testing process - results = multiprocessing.pool.Pool(args.jobs).map(run_one_test, model_list) + from tqdm import tqdm + + with multiprocessing.pool.Pool(args.jobs) as pool: + results = list( + rich.progress.track( + pool.imap_unordered(run_one_test, model_list), "Testing...", total=len(model_list) + ) + ) for model_name, error in results: if error is not None: failed_models.append(model_name) @@ -101,7 +109,7 @@ def main(): else: print(red(f"In all {len(model_list)} models, {len(failed_models)} models failed")) for i, (model_name, error) in enumerate(failed_messages): - print(f"[{i} / {len(failed_models)}] {red(model_name)} failed because: {error}\n") + print(f"[{i} / {len(failed_models)}] {red(model_name)} failed because: {error}") sys.exit(1) From dc429d63908b32eb3d7516d3a7b50ee29297d7af Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 20:47:16 +0000 Subject: [PATCH 17/18] tqdm --- tools/ir/model_zoo_test/model_zoo_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 65e6c5e10..8a8d0a27e 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -16,11 +16,11 @@ import traceback import onnx +import tqdm from onnx import hub import onnxscript.testing from onnxscript import ir -import rich.progress def test_model(model_info: hub.ModelInfo) -> float: @@ -92,12 +92,12 @@ def main(): failed_models = [] failed_messages = [] # Use multi-processing to speed up the testing process - from tqdm import tqdm - with multiprocessing.pool.Pool(args.jobs) as pool: results = list( - rich.progress.track( - pool.imap_unordered(run_one_test, model_list), "Testing...", total=len(model_list) + tqdm.tqdm( + pool.imap_unordered(run_one_test, model_list), + "Testing...", + total=len(model_list), ) ) for model_name, error in results: From fe9a8c3470861b906b53593a23335f3d9a6fcc49 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 May 2024 20:50:40 +0000 Subject: [PATCH 18/18] filter --- tools/ir/model_zoo_test/model_zoo_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index 8a8d0a27e..de3410a49 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -77,6 +77,12 @@ def red(text: str) -> str: def main(): parser = argparse.ArgumentParser(description="Test IR roundtrip with ONNX model zoo.") + parser.add_argument( + "-k", + type=str, + default=None, + help="Keyword to filter the models. Default is None.", + ) parser.add_argument( "--jobs", type=int, @@ -86,6 +92,10 @@ def main(): args = parser.parse_args() model_list = hub.list_models() + if args.k: + # Filter the models by name + name = args.k.lower() + model_list = [model for model in model_list if name in model.model.lower()] print(f"=== Testing IR on {len(model_list)} models ===") # run checker on each model