Skip to content

Commit

Permalink
Add a simple test
Browse files Browse the repository at this point in the history
Fix typo
  • Loading branch information
wschin committed Mar 5, 2024
1 parent 060e09c commit f896fb8
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ def run_distributed_cache_test(cwd, log):
run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_aggressive_cpu_fallback_test(cwd, log):
log.debug("Running: ORTModule Cache Test")

command = [
"python3",
"orttraining_test_aggressive_cpu_fallback.py.py",
]

env = {"ORT_AGGRESSIVE_CPU_FALLBACK", 1}
run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode()


def main():
args = parse_arguments()
cwd = args.cwd
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import tempfile
import unittest

import onnx
import onnxscript
from onnxscript.onnx_opset import opset18
from onnxscript.onnx_types import FLOAT, INT64

import onnxruntime


class TestAggressiveCpuFallback(unittest.TestCase):
def test_reshape_matmul(self):
@onnxscript.script(default_opset=opset18)
def foo(x: FLOAT[12], w: FLOAT[6, 2], dim0: INT64[1], dim1: INT64[1]):
# This should be computed by CPU but is placed
# on CUDA (i.e., all inputs and outputs are GPU tensors).
dim2 = dim1 + 1
# Same as `dim2 = dim1 + 1`. Another GPU node.
dim3 = dim2 - 1
# Same as `dim2 = dim1 + 1`. Another GPU node.
new_shape = opset18.Concat(dim0, dim3, axis=0)
# A memcpy node will be inserted to copy GPU output
# to CPU since Reshape's 2nd input is a CPU tensor
# per schema definition.
#
# Use ORT_AGGRESSIVE_CPU_FALLBACK=1 to
# 1. remove memcpy node.
# 2. fallback all computation above this line to CPU.
new_x = opset18.Reshape(x, new_shape)
y = opset18.MatMul(new_x, w)
return y

model = foo.to_model_proto()

temp_file_name = tempfile.mktemp(prefix="cpu_fallback_test", suffix=".onnx")

Check failure

Code scanning / CodeQL

Insecure temporary file High test

Call to deprecated function tempfile.mktemp may be insecure.
session_options = onnxruntime.SessionOptions()
session_options.optimized_model_filepath = temp_file_name
# This call should trigger GetCpuPreferredNodes and then GetShapeRelatedNodes
# when environment variable ORT_AGGRESSIVE_CPU_FALLBACK=1 is set.
# As a result, no memcopy node should be observed in optimized graph.
#
# See comments inside `foo`.
onnxruntime.InferenceSession(
path_or_bytes=model.SerializeToString(), sess_options=session_options, providers=["CUDAExecutionProvider"]
)
optimized = onnx.load(temp_file_name)

self.assertTrue(all(node.op_type != "MemcpyToHost" for node in optimized.graph.node))
os.remove(temp_file_name)


if __name__ == "__main__":
unittest.main()

0 comments on commit f896fb8

Please sign in to comment.