From f2c62b3452a936c282ae501d4e1fe25f592a8365 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 24 Nov 2023 09:41:40 -0500
Subject: [PATCH] Fix bisection yaml

---
 .../userbenchmark-a100-bisection.yml |  2 +-
 userbenchmark/test_bench/run.py      | 41 +++++++++++++++----
 2 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/userbenchmark-a100-bisection.yml b/.github/workflows/userbenchmark-a100-bisection.yml
index 8ca7b473de..b7fb3fdc67 100644
--- a/.github/workflows/userbenchmark-a100-bisection.yml
+++ b/.github/workflows/userbenchmark-a100-bisection.yml
@@ -23,7 +23,7 @@ jobs:
       CONDA_ENV: "bisection-ci-a100"
       PLATFORM_NAME: "gcp_a100"
       SETUP_SCRIPT: "/workspace/setup_instance.sh"
-      BISECT_WORKDIR: ".userbenchmark/${{ github.env.userbenchmark }}/bisection"
+      BISECT_WORKDIR: ".userbenchmark/${{ github.event.inputs.userbenchmark }}/bisection"
       AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
       AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
     if: ${{ github.repository_owner == 'pytorch' }}
diff --git a/userbenchmark/test_bench/run.py b/userbenchmark/test_bench/run.py
index cd81195075..f1f361b76e 100644
--- a/userbenchmark/test_bench/run.py
+++ b/userbenchmark/test_bench/run.py
@@ -7,6 +7,9 @@
 import json
 import os
 import shutil
+import yaml
+import re
+import ast
 import numpy
 
 from typing import List, Dict, Optional, Any, Union
@@ -25,6 +28,18 @@ def config_to_str(config: TorchBenchModelConfig) -> str:
                    f" bs={config.batch_size}, extra_args={config.extra_args}"
     return metrics_base
 
+def str_to_config(metric_name: str) -> TorchBenchModelConfig:
+    regex = "model=(.*), test=(.*), device=(.*), bs=(.*), extra_args=(.*), metric=(.*)"
+    model, test, device, batch_size, extra_args, _metric = re.match(regex, metric_name).groups()
+    extra_args = ast.literal_eval(extra_args)
+    return TorchBenchModelConfig(
+        name=model,
+        test=test,
+        device=device,
+        batch_size=batch_size,
+        extra_args=extra_args,
+    )
+
 def generate_model_configs(devices: List[str], tests: List[str], batch_sizes: List[str], model_names: List[str], extra_args: List[str]) -> List[TorchBenchModelConfig]:
     """Use the default batch size and default mode."""
     if not model_names:
@@ -40,6 +55,12 @@ def generate_model_configs(devices: List[str], tests: List[str], batch_sizes: Li
     ) for device, test, batch_size, model_name in cfgs]
     return result
 
+def generate_model_configs_from_bisect_yaml(bisect_yaml: str) -> List[TorchBenchModelConfig]:
+    with open(bisect_yaml, "r") as fp:
+        bisect = yaml.safe_load(fp)
+    result = list(map(lambda x: str_to_config(x), bisect["details"].keys()))
+    return result
+
 def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Path) -> List[TorchBenchModelConfig]:
     result = []
     for config in configs:
@@ -118,6 +139,7 @@ def parse_known_args(args):
     default_device = "cuda" if "cuda" in list_devices() else "cpu"
     parser.add_argument(
         "models",
+        nargs="*",
         help="Name of models to run, split by comma.",
     )
     parser.add_argument("--device", "-d", default=default_device, help="Devices to run, splited by comma.")
@@ -132,15 +154,18 @@ def parse_known_args(args):
 
 def run(args: List[str]):
     args, extra_args = parse_known_args(args)
-    # If not specified, use the entire model set
-    if not args.models:
-        args.models = list_models()
+    if args.run_bisect:
+        configs = generate_model_configs_from_bisect_yaml(args.run_bisect)
+    else:
+        # If not specified, use the entire model set
+        if not args.models:
+            args.models = list_models()
+        devices = validate(parse_str_to_list(args.device), list_devices())
+        tests = validate(parse_str_to_list(args.test), list_tests())
+        batch_sizes = parse_str_to_list(args.bs)
+        models = validate(parse_str_to_list(args.models), list_models())
+        configs = generate_model_configs(devices, tests, batch_sizes, model_names=models, extra_args=extra_args)
     debug_output_dir = get_default_debug_output_dir(args.output) if args.debug else None
-    devices = validate(parse_str_to_list(args.device), list_devices())
-    tests = validate(parse_str_to_list(args.test), list_tests())
-    batch_sizes = parse_str_to_list(args.bs)
-    models = validate(parse_str_to_list(args.models), list_models())
-    configs = generate_model_configs(devices, tests, batch_sizes, model_names=models, extra_args=extra_args)
     configs = init_output_dir(configs, debug_output_dir) if debug_output_dir else configs
     results = {}
     try:
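
Note for reviewers (not part of the patch): str_to_config is the inverse of config_to_str above, and generate_model_configs_from_bisect_yaml applies it to each key of the bisection result's "details" mapping. The standalone sketch below shows that round trip on a sample metric name; the model name, metric value, and "details" layout are assumptions modeled on config_to_str's output format, not values taken from a real bisection run.

# Standalone sketch of the parsing done by str_to_config (sample data assumed).
import ast
import re

import yaml

# Hypothetical bisection summary: a top-level "details" mapping whose keys
# follow the "model=..., test=..., ..." format produced by config_to_str.
SAMPLE_BISECT_YAML = """
details:
  "model=BERT_pytorch, test=train, device=cuda, bs=None, extra_args=[], metric=latencies": 1.23
"""

REGEX = "model=(.*), test=(.*), device=(.*), bs=(.*), extra_args=(.*), metric=(.*)"

for metric_name in yaml.safe_load(SAMPLE_BISECT_YAML)["details"]:
    model, test, device, batch_size, extra_args, _metric = re.match(REGEX, metric_name).groups()
    extra_args = ast.literal_eval(extra_args)  # "[]" -> []
    # Note: batch_size comes back as the string "None", which str_to_config
    # passes through to TorchBenchModelConfig unchanged.
    print(model, test, device, batch_size, extra_args)

Although the regex groups are greedy, each delimiter literal (", test=", ", device=", and so on) occurs only once in a typical metric name emitted by config_to_str, so the match is unambiguous for the inputs this code is expected to see.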