
Commit

Merge branch 'main' of https://github.com/microsoft/onnxruntime into zhanyi/tsaupload
Yi Zhang committed Mar 28, 2024
2 parents 07272b8 + 55f63a4 commit 30a00ca
Showing 44 changed files with 390 additions and 256 deletions.
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -76,7 +76,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through
# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical
# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead.
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF)
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF)

option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF)
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_cuda.cmake
@@ -122,7 +122,7 @@
endif()
if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
# cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and
# add to the lib onnxruntime_providers_cuda separatedly.
# added to the lib onnxruntime_providers_cuda separately.
# onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc.
set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc)
list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src})
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -779,6 +779,13 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter
target_compile_options(onnxruntime_providers_cuda_ut PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4100>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4100>")
endif()

list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
endif()

2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/execution_provider.h
@@ -53,7 +53,7 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};

using RunOptions = OrtRunOptions;
using RunOptions = ::OrtRunOptions;

enum class DataLayout {
NCHW,
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/run_options.h
@@ -45,5 +45,5 @@ struct OrtRunOptions {
};

namespace onnxruntime {
using RunOptions = OrtRunOptions;
using RunOptions = ::OrtRunOptions;
} // namespace onnxruntime
12 changes: 11 additions & 1 deletion js/web/lib/onnxjs/model.ts
@@ -16,6 +16,7 @@ export class Model {
constructor() {}

load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void {
let onnxError: Error|undefined;
if (!isOrtFormat) {
// isOrtFormat === false || isOrtFormat === undefined
try {
@@ -25,10 +26,19 @@
if (isOrtFormat !== undefined) {
throw e;
}
onnxError = e;
}
}

this.loadFromOrtFormat(buf, graphInitializer);
try {
this.loadFromOrtFormat(buf, graphInitializer);
} catch (e) {
if (isOrtFormat !== undefined) {
throw e;
}
// Tried both formats and failed (when isOrtFormat === undefined)
throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
}
}

private loadFromOnnxFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
13 changes: 13 additions & 0 deletions js/web/test/e2e/browser-test-webgl.js
@@ -6,3 +6,16 @@
it('Browser E2E testing - WebGL backend', async function() {
await testFunction(ort, {executionProviders: ['webgl']});
});

it('Browser E2E testing - invalid buffer', async () => {
try {
await ort.InferenceSession.create(
new Uint8Array(Array.from({length: 100}, () => 42)), {executionProviders: ['webgl']});

// Should not reach here.
assert(false);
} catch (e) {
assert(e.message.includes('as ONNX format'));
assert(e.message.includes('as ORT format'));
}
});
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -656,7 +656,9 @@ inline __device__ float4 operator*(const float4 a, const float4 b) {
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
// TODO(wy): use cuda common header and investigate pipeline build issue.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
inline __device__ half operator*(const half a, const half b) {
return __float2half(__half2float(a) * __half2float(b));
}
@@ -666,8 +668,10 @@ inline __device__ half2 operator*(const half2 a, const half2 b) {
}
#endif

// TODO(wy): use cuda common header and investigate pipeline build issue.
inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
Half4 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
16 changes: 8 additions & 8 deletions onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -110,8 +110,8 @@ struct BlockwiseQuantization {
static void prepack_weights(
int rows,
int columns,
const gsl::span<uint8_t const>& weights, // <- int4 weights, column major
const gsl::span<uint8_t>& weights_prepacked // <- int4 prepacked weights tensor, same size buffer
gsl::span<uint8_t const> weights, // <- int4 weights, column major
gsl::span<uint8_t> weights_prepacked // <- int4 prepacked weights tensor, same size buffer
) {
ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 &&
(rows % QuantBlocking::kRow) == 0 &&
@@ -171,10 +171,10 @@ struct BlockwiseQuantization {
static void prepack_quant_scales(
size_t rows,
size_t columns,
const gsl::span<ElementT const>& scales, // <- quant scales, column major layout
const gsl::span<ElementT>& scales_prepacked // <- quant scales prepacked, same size buffer
gsl::span<ElementT const> scales, // <- quant scales, column major layout
gsl::span<ElementT> scales_prepacked // <- quant scales prepacked, same size buffer
) {
auto meta_shape = get_quant_meta_shape(rows, columns);
auto meta_shape = get_quant_meta_shape(static_cast<int>(rows), static_cast<int>(columns));
ORT_ENFORCE(scales.size() == size_t(meta_shape.product()),
"Quantization scale tensor shape mismatch!");
ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()),
@@ -241,10 +241,10 @@ struct BlockwiseQuantization {
static void prepack_quant_offsets(
size_t rows,
size_t columns,
const gsl::span<uint8_t const>& offsets, // <- quant offsets, int4, column major layout
const gsl::span<uint8_t>& offsets_prepacked // <- quant offsets prepacked, double size buffer
gsl::span<uint8_t const> offsets, // <- quant offsets, int4, column major layout
gsl::span<uint8_t> offsets_prepacked // <- quant offsets prepacked, double size buffer
) {
auto meta_shape = get_quant_meta_shape(rows, columns);
auto meta_shape = get_quant_meta_shape(static_cast<int>(rows), static_cast<int>(columns));

ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0,
"Does not support odd number of rows or columns!");
(additional changed file; name not shown)
@@ -132,7 +132,7 @@ struct DummyType{
}

CUTLASS_HOST_DEVICE
std::monostate& operator[](int idx) {
std::monostate& operator[](int /*idx */) {
return dummy_;
}
};
(additional changed file; name not shown)
@@ -437,7 +437,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,

CUTLASS_HOST_DEVICE
static void dequant(FragmentScale const &scales,
FragmentOffset const &offsets,
FragmentOffset const &fragment_offsets,
Array<uint8_t,kExpandedSize/2> const &weights,
Array<ElementScale, kExpandedSize>& dest){
static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm.");
@@ -453,19 +453,18 @@

uint32_t* dest_pair = reinterpret_cast<uint32_t*>(dest.data());
const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
const ElementOffset* offsets_ptr = nullptr;
if constexpr(kHasOffset) { offsets_ptr = offsets.data(); }
[[maybe_unused]] const ElementOffset* fragment_offsets_ptr = nullptr;
if constexpr(kHasOffset) { fragment_offsets_ptr = fragment_offsets.data(); }

CUTLASS_PRAGMA_UNROLL
for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){
// dequantize: d = scale * (weight - offset)
// to use FMA, d = scale * weight + (scale * (-offset))

b64 offsets;
if constexpr(kHasOffset){
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets_ptr);

[[maybe_unused]] b64 offsets{0};
if constexpr(kHasOffset) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
const uint32_t* p = reinterpret_cast<const uint32_t*>(fragment_offsets_ptr);
asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands
@@ -486,7 +485,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
assert(0);
#endif

offsets_ptr += 4;
fragment_offsets_ptr += 4;
} else {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
@@ -541,7 +540,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma;
ElementScale s = scales[idx];
if constexpr(kHasOffset){
offset = s * static_cast<ElementScale>(-16 - int(offsets[idx]));
offset = s * static_cast<ElementScale>(-16 - static_cast<int>(fragment_offsets[idx]));
} else {
offset = s * static_cast<ElementScale>(-16-8);
}
@@ -795,13 +794,13 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
}
}
} else if constexpr (kMmaIterationsB % 2 == 0) {
const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);

if constexpr (kHasOffset){
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);
// possible buffer over read 2 bytes here.
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1, rb2;\n"
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/gather_fusion.cc
@@ -273,7 +273,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
Node& split_node = graph.AddNode(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion/", "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
split_node.AddAttribute("axis", axis);
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -75,7 +75,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
nodes_to_remove.push_back(output_node);
}

Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_transformed"),
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmTransposeFusion/"),
gemm_node.OpType(),
"Fused Gemm with Transpose",
new_gemm_input_defs,
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -455,7 +455,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}

InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale, bias};
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"),
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/LayerNormFusion/"),
"LayerNormalization",
"fused LayerNorm subgraphs ",
layer_norm_input_defs,
@@ -705,7 +705,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr

InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale};
Node& layer_norm_node =
graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization",
graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/SimplifiedLayerNormFusion/"), "SimplifiedLayerNormalization",
"fused LayerNorm subgraphs ", layer_norm_input_defs, {}, {}, kOnnxDomain);

// Get constant "epsilon" from "Add" node if available. Else, default value will be used.
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/matmul_scale_fusion.cc
@@ -245,7 +245,7 @@ Status ProcessNode(
}

Node& matmul_scale_node = graph.AddNode(
graph.GenerateNodeName(node.Name() + "_FusedMatMulAndScale"),
graph.GenerateNodeName(node.Name() + "/MatMulScaleFusion/"),
"FusedMatMul",
"Fused MatMul and Scale",
fused_node_inputs,
6 changes: 3 additions & 3 deletions onnxruntime/core/optimizer/matmul_transpose_fusion.cc
@@ -154,14 +154,14 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
const ONNX_NAMESPACE::TensorProto_DataType element_type =
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_output->TypeAsProto()->tensor_type().elem_type());
new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "/MatmulTransposeFusion/", &new_cast_output_type_proto);

const std::array new_cast_input_defs{transpose_input};
const std::array new_cast_output_defs{&new_cast_output};
const std::array new_transpose_input_defs = {&new_cast_output};
const std::array new_transpose_output_defs = {cast_output};

Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "/MatmulTransposeFusion/"),
cast->OpType(),
"Created a new Cast node to interchange Cast and Transpose nodes",
new_cast_input_defs,
@@ -385,7 +385,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
const std::array input_defs{left_input, right_input};
const std::array output_defs{node.MutableOutputDefs()[0]};

Node& matmul_node = graph.AddNode(graph.GenerateNodeName("MatMul_With_Transpose"),
Node& matmul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "/MatmulTransposeFusion/"),
"FusedMatMul",
"fused MatMul and Transpose ",
input_defs,
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/quick_gelu_fusion.cc
@@ -88,7 +88,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,

NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0];
Node& quick_gelu_node =
graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/QuickGeluFusion/"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
std::array{quick_gelu_output_arg}, {}, kMSDomain);
quick_gelu_node.AddAttribute("alpha", alpha);
quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());
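Several of the fusion passes above now name the fused node after its source node plus the fusion pass (a suffix like "/QuickGeluFusion/") instead of a generic generated name. A minimal sketch of how the renamed nodes could be observed with the onnxruntime Python API, by dumping the optimized graph and printing node names; the file paths are placeholders, and whether a given fusion fires depends on the model and execution provider:

import onnx
import onnxruntime as ort

# Placeholder paths; any model containing the fusable patterns will do.
model_path = "model.onnx"
optimized_path = "model_optimized.onnx"

so = ort.SessionOptions()
# Run all graph optimizations and persist the optimized model to disk.
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
so.optimized_model_filepath = optimized_path
ort.InferenceSession(model_path, so, providers=["CPUExecutionProvider"])

# Fused nodes should carry names derived from their source node, e.g. "<node>/QuickGeluFusion/".
for node in onnx.load(optimized_path).graph.node:
    print(node.op_type, node.name)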
(additional changed file; name not shown)
@@ -394,14 +394,6 @@ struct ConfigOptions final {
PROVIDER_DISALLOW_ALL(ConfigOptions)
};

struct OrtRunOptions final {
const ConfigOptions& GetConfigOptions() const {
return g_host->RunOptions__GetConfigOptions(this);
}

PROVIDER_DISALLOW_ALL(OrtRunOptions)
};

struct ComputeCapability final {
static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
@@ -1283,3 +1275,10 @@ template <>
inline gsl::span<const int64_t> Tensor::DataAsSpan() const { return g_host->Tensor__DataAsSpan_int64(this); }

} // namespace onnxruntime

struct OrtRunOptions final {
const onnxruntime::ConfigOptions& GetConfigOptions() const {
return onnxruntime::g_host->RunOptions__GetConfigOptions(this);
}
PROVIDER_DISALLOW_ALL(OrtRunOptions)
};
2 changes: 1 addition & 1 deletion onnxruntime/core/util/matrix_layout.h
@@ -378,7 +378,7 @@ class MatrixRef {
MatrixRef(
NonConstMatrixRef const& ref, ///< MatrixRef to non-const data
/// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
_Magic magic = (typename std::enable_if<!IsNonConstRef, _Magic>::type)0
[[maybe_unused]] _Magic magic = (typename std::enable_if<!IsNonConstRef, _Magic>::type)0
) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {}

ORT_FORCEINLINE
(additional changed file; name not shown)
@@ -16,8 +16,8 @@


def qnn_preprocess_model(
model_input: Path,
model_output: Path,
model_input: str | Path | onnx.ModelProto,
model_output: str | Path,
fuse_layernorm: bool = False,
save_as_external_data: bool = False,
all_tensors_to_one_file: bool = False,
@@ -37,7 +37,7 @@ def qnn_preprocess_model(
- (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.
Args:
model_input: Path to the input model file.
model_input: Path to the input model file or ModelProto.
model_output: Path the output model file, which is only created if this method returns True.
fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
Defaults to False.
@@ -82,7 +82,7 @@ def qnn_preprocess_model(
to cancel out.
"""
modified = False
model = onnx.load_model(model_input)
model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
onnx_model = ONNXModel(model)

# Fuse Erf sequence into a single Gelu
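With model_input widened to accept str | Path | onnx.ModelProto, the QNN preprocessing step can be handed an in-memory model directly instead of a file path. A minimal usage sketch under that assumption; the file names are placeholders:

import onnx
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# Load (or construct) the model in memory and pass the ModelProto straight in.
model = onnx.load_model("model.onnx")
modified = qnn_preprocess_model(model, "model_preprocessed.onnx", fuse_layernorm=True)

# Per the docstring, the output file is only written when the graph was modified.
print("graph modified:", modified)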