add arm64 bfloat16 fastmath mode option for transformers benchmarking script #19294

Merged (1 commit) on Feb 12, 2024
13 changes: 13 additions & 0 deletions onnxruntime/python/tools/transformers/benchmark.py
@@ -36,6 +36,8 @@
python benchmark.py -e torchscript onnxruntime -p "int8" -o
Run OnnxRuntime with the ROCM provider and graph optimization script:
python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm
Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support:
python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm

It is recommended to use run_benchmark.sh to launch benchmark.
"""
@@ -106,6 +108,7 @@ def run_onnxruntime(
use_raw_attention_mask,
model_fusion_statistics,
model_source,
enable_arm64_bfloat16_fastmath_mlas_gemm,
args,
):
import onnxruntime
@@ -209,6 +212,7 @@ def run_onnxruntime(
enable_all_optimization=True,
num_threads=num_threads,
verbose=verbose,
enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm,
)
if ort_session is None:
continue
@@ -764,6 +768,14 @@ def parse_arguments():
help="Manually set the model's layer number",
)

parser.add_argument(
"--enable_arm64_bfloat16_fastmath_mlas_gemm",
required=False,
action="store_true",
help="Enable bfloat16 MLAS GEMM kernels on aarch64. Supported only by the CPU execution provider.",
)
parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False)
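
The argument added above follows the standard argparse boolean-flag pattern. A minimal standalone sketch (note that `action="store_true"` already defaults the value to `False`, so the `set_defaults` call is redundant but harmless, mirroring the script's style):

```python
import argparse

# Standalone sketch of the flag added in this PR; store_true already
# defaults the attribute to False, so set_defaults is belt-and-braces.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_arm64_bfloat16_fastmath_mlas_gemm",
    required=False,
    action="store_true",
    help="Enable bfloat16 MLAS GEMM kernels on aarch64 (CPU EP only).",
)
parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False)

# Omitting the flag yields False; passing it yields True.
print(parser.parse_args([]).enable_arm64_bfloat16_fastmath_mlas_gemm)
print(parser.parse_args(
    ["--enable_arm64_bfloat16_fastmath_mlas_gemm"]
).enable_arm64_bfloat16_fastmath_mlas_gemm)
```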

FusionOptions.add_arguments(parser)

args = parser.parse_args()
@@ -909,6 +921,7 @@ def main():
use_raw_attention_mask,
model_fusion_statistics,
args.model_source,
args.enable_arm64_bfloat16_fastmath_mlas_gemm,
args,
)
except Exception:
4 changes: 4 additions & 0 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -85,6 +85,7 @@ def create_onnxruntime_session(
num_threads=-1,
enable_profiling=False,
verbose=False,
enable_mlas_gemm_fastmath_arm64_bfloat16=False,
provider_options={}, # map execution provider name to its option # noqa: B006
):
session = None
@@ -136,6 +137,9 @@
if provider_options:
providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]

if enable_mlas_gemm_fastmath_arm64_bfloat16:
sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")

session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
except Exception:
logger.error("Exception", exc_info=True)
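
The helper's new branch reduces to setting a single ORT session-config entry. A minimal sketch using a hypothetical `session_config_entries` helper (not part of the PR) to show the mapping from the CLI flag to the config key; the real code sets the entry via `sess_options.add_session_config_entry` on an `onnxruntime.SessionOptions` object, as in the diff above:

```python
# Hypothetical helper (not in the PR) mirroring the branch added to
# create_onnxruntime_session: the flag maps to one session-config entry.
def session_config_entries(enable_mlas_gemm_fastmath_arm64_bfloat16=False):
    entries = {}
    if enable_mlas_gemm_fastmath_arm64_bfloat16:
        # Key taken verbatim from the diff; ORT config values are strings.
        entries["mlas.enable_gemm_fastmath_arm64_bfloat16"] = "1"
    return entries

print(session_config_entries(False))  # {}
print(session_config_entries(True))   # {'mlas.enable_gemm_fastmath_arm64_bfloat16': '1'}
```

Keeping the mode off by default matches the PR: bfloat16 fastmath trades accuracy for speed and only helps on aarch64 parts with bfloat16 support.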
9 changes: 8 additions & 1 deletion onnxruntime/python/tools/transformers/run_benchmark.sh
100644 → 100755
@@ -1,4 +1,4 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
@@ -34,6 +34,9 @@
run_cpu_fp32=false
run_cpu_int8=false

# Set this to true to enable bfloat16 fastmath gemm kernels on aarch64 platforms with bfloat16 support
arm64_bfloat16_fastmath_mode=false

average_over=1000
# CPU takes longer time to run, only run 100 inferences to get average latency.
if [ "$run_cpu_fp32" = true ] || [ "$run_cpu_int8" = true ]; then
@@ -63,7 +66,7 @@
# export CUDA_VISIBLE_DEVICES=1

# This script will generate a logs file with a list of commands used in tests.
echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" >> benchmark.log
echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test arm64_bfloat16_fastmath_mode=$arm64_bfloat16_fastmath_mode" >> benchmark.log

# Set it to false to skip testing. You can use it to dry run this script with the log file.
run_tests=true
@@ -127,11 +130,15 @@
benchmark_options="$benchmark_options --force_num_layers $layer_number"
fi

if [ "$arm64_bfloat16_fastmath_mode" = true ] ; then
benchmark_options="$benchmark_options --enable_arm64_bfloat16_fastmath_mlas_gemm"
fi

# -------------------------------------------
run_one_test() {
if [ "$run_ort" = true ] ; then
echo python $benchmark_script -m $1 $onnx_export_options $2 $3 $4 >> benchmark.log

echo python $benchmark_script -m $1 $benchmark_options $2 $3 $4 -i $input_counts >> benchmark.log

if [ "$run_tests" = true ] ; then
python $benchmark_script -m $1 $onnx_export_options $2 $3 $4
python $benchmark_script -m $1 $benchmark_options $2 $3 $4 -i $input_counts