Commit
Merge branch 'main' into tlwu/cudnn_flash_att
tianleiwu committed Aug 9, 2024
2 parents f47d500 + 702b2e2 commit ccbbce8
Showing 24 changed files with 279 additions and 57 deletions.
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Condition="('$(OutputType)'!='Library' OR '$(IsAppExtension)'=='True')">
<NativeReference Include="$(MSBuildThisFileDirectory)..\..\runtimes\ios\native\onnxruntime.xcframework">
<NativeReference Include="$(MSBuildThisFileDirectory)..\..\runtimes\ios\native\onnxruntime.xcframework.zip">
<Kind>Static</Kind>
<IsCxx>True</IsCxx>
<SmartLink>True</SmartLink>
@@ -10,4 +10,4 @@
<WeakFrameworks>CoreML</WeakFrameworks>
</NativeReference>
</ItemGroup>
</Project>
</Project>
2 changes: 1 addition & 1 deletion js/build_jsep.bat
@@ -17,7 +17,7 @@ set BUILD_DIR=%ROOT%build_jsep
:arg1
if ["%~1"]==["d"] (
set CONFIG=Debug
set CONFIG_EXTRA_FLAG=--enable_wasm_debug_info --enable_wasm_profiling
set CONFIG_EXTRA_FLAG=--enable_wasm_debug_info --enable_wasm_profiling --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1
goto :arg2
)
if ["%~1"]==["r"] (
13 changes: 13 additions & 0 deletions js/web/package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions js/web/package.json
@@ -36,6 +36,7 @@
],
"devDependencies": {
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2",
"@petamoriken/float16": "^3.8.7",
"@types/chai": "^4.3.4",
"@types/emscripten": "^1.39.6",
"@types/flatbuffers": "^1.10.0",
74 changes: 74 additions & 0 deletions js/web/test/data/ops/pad_f16.jsonc
@@ -0,0 +1,74 @@
[
{
"name": "constant 2D float16",
"operator": "Pad",
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "constant", "type": "string" },
{ "name": "value", "data": 1.2, "type": "float" },
{ "name": "pads", "data": [3, 2, 2, 3], "type": "ints" }
],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.5],
"dims": [2, 2],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.5, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
},
{
"name": "constant 2D float16",
"operator": "Pad",
"opset": { "domain": "", "version": 19 },
"attributes": [{ "name": "mode", "data": "constant", "type": "string" }],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.5],
"dims": [2, 2],
"type": "float16"
},
{
"data": [3, 2, 2, 3],
"dims": [4],
"type": "int64"
},
{
"data": [1.2],
"dims": [1],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.5, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
}
]
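
The expected 7×7 outputs above follow directly from ONNX Pad semantics: each output dimension is the input size plus its begin and end padding (2 + 3 + 2 = 7 rows, 2 + 2 + 3 = 7 columns), with the constant 1.2 everywhere outside the copied 2×2 block. A minimal TypeScript sketch of that computation (padConstant2d is a hypothetical helper used only for illustration, not part of the test suite):

```typescript
// Hypothetical helper mirroring ONNX Pad in "constant" mode for a 2D tensor.
// pads = [begin_dim0, begin_dim1, end_dim0, end_dim1].
function padConstant2d(
    data: number[], dims: [number, number], pads: [number, number, number, number],
    value: number): {data: number[], dims: [number, number]} {
  const [rows, cols] = dims;
  const [beginRow, beginCol, endRow, endCol] = pads;
  const outRows = rows + beginRow + endRow;  // 2 + 3 + 2 = 7
  const outCols = cols + beginCol + endCol;  // 2 + 2 + 3 = 7
  const out = new Array<number>(outRows * outCols).fill(value);
  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < cols; c++) {
      out[(r + beginRow) * outCols + (c + beginCol)] = data[r * cols + c];
    }
  }
  return {data: out, dims: [outRows, outCols]};
}

// Reproduces the expected tensor of the "constant 2D float16" case above:
// a 7x7 grid of 1.2 with [1.0, 2.0] and [3.0, 4.5] at rows 3-4, columns 2-3.
const result = padConstant2d([1.0, 2.0, 3.0, 4.5], [2, 2], [3, 2, 2, 3], 1.2);
```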
4 changes: 4 additions & 0 deletions js/web/test/op-test-schema.json
@@ -177,6 +177,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
@@ -213,6 +214,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
@@ -247,6 +249,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
@@ -283,6 +286,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
29 changes: 27 additions & 2 deletions js/web/test/test-runner.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Float16Array as Float16ArrayPolyfill} from '@petamoriken/float16';
import {expect} from 'chai';
import * as ort from 'onnxruntime-common';
import {extname} from 'path';
@@ -391,6 +392,24 @@ export class TensorResultValidator {
case 'string':
return this.strictEqual(actual.data, expected.data);

case 'float16': {
const actualData = actual.data as Uint16Array;
const actualDataBuffer = actualData.buffer;
const actualDataByteOffset = actualData.byteOffset;
const actualDataLength = actualData.length;
const actualDataFloat32Array =
new Float32Array(new Float16ArrayPolyfill(actualDataBuffer, actualDataByteOffset, actualDataLength));

const expectedData = expected.data as Uint16Array;
const expectedDataBuffer = expectedData.buffer;
const expectedDataByteOffset = expectedData.byteOffset;
const expectedDataLength = expectedData.length;
const expectedDataFloat32Array =
new Float32Array(new Float16ArrayPolyfill(expectedDataBuffer, expectedDataByteOffset, expectedDataLength));

return this.floatEqual(actualDataFloat32Array, expectedDataFloat32Array);
}

case 'float32':
case 'float64':
return this.floatEqual(
@@ -919,11 +938,14 @@ async function runProtoOpTestcase(
const fetches: Record<string, Pick<ort.Tensor, 'dims'|'type'>> = {};
testCase.inputs.forEach((input, i) => {
if (input.data) {
let data: number[]|BigUint64Array|BigInt64Array = input.data;
let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = input.data;
if (input.type === 'uint64') {
data = BigUint64Array.from(input.data.map(BigInt));
} else if (input.type === 'int64') {
data = BigInt64Array.from(input.data.map(BigInt));
} else if (input.type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(input.data);
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
}
feeds[`input_${i}`] = new ort.Tensor(input.type, data, input.dims);
}
@@ -933,11 +955,14 @@
const expectedOutputNames: string[] = [];
testCase.outputs.forEach((output, i) => {
if (output.data) {
let data: number[]|BigUint64Array|BigInt64Array = output.data;
let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = output.data;
if (output.type === 'uint64') {
data = BigUint64Array.from(output.data.map(BigInt));
} else if (output.type === 'int64') {
data = BigInt64Array.from(output.data.map(BigInt));
} else if (output.type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(output.data);
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
}
outputs.push(new ort.Tensor(output.type, data, output.dims));
expectedOutputNames.push(`output_${i}`);
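
The float16 path added above round-trips JSON test data through the @petamoriken/float16 polyfill: plain numbers are encoded into a Uint16Array of raw half-precision bits for ort.Tensor, and results are widened back to Float32Array so the existing floatEqual tolerance check can be reused. A standalone sketch of that round trip (assuming a Node environment with '@petamoriken/float16' installed, as added to js/web/package.json above):

```typescript
import {Float16Array as Float16ArrayPolyfill} from '@petamoriken/float16';

// JSON test data arrives as plain numbers; store them as raw float16 bits,
// which is what a 'float16' ort.Tensor expects.
const values = [1.0, 2.0, 3.0, 4.5];
const f16 = Float16ArrayPolyfill.from(values);
const bits = new Uint16Array(f16.buffer, f16.byteOffset, f16.byteLength / 2);

// For comparison, view the 16-bit storage as float16 again and widen to float32,
// then apply the usual floating-point tolerance (floatEqual in the runner).
const widened = new Float32Array(
    new Float16ArrayPolyfill(bits.buffer, bits.byteOffset, bits.length));
console.log(widened);  // Float32Array [1, 2, 3, 4.5] - all exactly representable in float16
```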
93 changes: 62 additions & 31 deletions onnxruntime/core/optimizer/pad_fusion.cc
@@ -8,25 +8,7 @@

namespace onnxruntime {

/*
* It matches following pattern:
* Pad
* |
* Conv/MaxPool/AveragePool
*/
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
// if Pad has input axis, don't fuse it.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
node.GetOutputEdgesCount() != 1 ||
node.InputDefs().size() > 3) {
return false;
}

if (graph.NodeProducesGraphOutput(node)) {
return false;
}

const Node& child_node = *node.OutputNodesBegin();
bool VerifyNotCastChild(const Node& child_node) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
@@ -54,6 +36,45 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
return false;
}

return true;
}

void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());

for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
}
}
/*
* Before:
* Pad
* |
* Cast (Optional)
* |
* Conv/MaxPool/AveragePool
*
* After:
* Cast (Optional)
* |
* Conv/MaxPool/AveragePool
*/
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
// if Pad has input axis, don't fuse it.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
node.GetOutputEdgesCount() != 1 ||
node.InputDefs().size() > 3) {
return false;
}

if (graph.NodeProducesGraphOutput(node)) {
return false;
}

const NodeAttributes& pad_attributes = node.GetAttributes();
if (pad_attributes.find("mode") != pad_attributes.end() &&
pad_attributes.at("mode").s() != "constant") {
@@ -83,7 +104,19 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
}
}

return true;
const Node& child_node = *node.OutputNodesBegin();
if (graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Cast", {1, 6, 9, 13})) {
if (child_node.GetOutputEdgesCount() != 1) {
return false;
}

if (graph.NodeProducesGraphOutput(child_node)) {
return false;
}
return VerifyNotCastChild(*child_node.OutputNodesBegin());
} else {
return VerifyNotCastChild(child_node);
}
}

/*
Expand All @@ -100,8 +133,6 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end());
}

assert(static_cast<uint32_t>(pads_values.size()) == (2 * static_cast<uint32_t>(pad_node.InputDefs()[0]->Shape()->dim_size())));

uint32_t pads_size = static_cast<uint32_t>(pads_values.size());
// check if padding is applied only on feature dims
if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 ||
@@ -115,18 +146,18 @@
}

Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index());
auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());

for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
}
// We don't need to cast the pad_constant_value because this fusion requires the constant_pad_value
// to be zero. See PadFusion::SatisfyCondition for details.
Node& target_padding_node = (child_node.OpType() == "Cast") ? *graph.GetNode(child_node.OutputNodesBegin()->Index()) : child_node;
UpdatePaddingAttribute(target_padding_node, pads_values, pads_size);

graph_utils::RemoveNodeOutputEdges(graph, pad_node);
graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]);
// Un-pad the output shape of Cast node
if (child_node.OpType() == "Cast") {
auto* cast_output_node_arg = child_node.MutableOutputDefs()[0];
cast_output_node_arg->SetShape(*pad_node.MutableInputDefs()[0]->Shape());
}
graph.RemoveNode(pad_node.Index());
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
return Status::OK();
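
This fusion folds the Pad node's spatial padding into the downstream Conv/MaxPool/AveragePool "pads" attribute, and with this change it also looks through a single optional Cast node between Pad and the pooling/convolution node. As an illustration only (the actual implementation is the C++ UpdatePaddingAttribute above), the index arithmetic written in TypeScript:

```typescript
// Illustration of UpdatePaddingAttribute's merge. For an NCHW Pad with
// pads = [n_begin, c_begin, h_begin, w_begin, n_end, c_end, h_end, w_end],
// only the spatial entries (index >= 2 in each half) are folded into the
// child node's "pads" attribute [h_begin, w_begin, h_end, w_end].
function mergePadsIntoChild(padValues: number[], childPads: number[]): number[] {
  const padsSize = padValues.length;       // 2 * input rank
  const childPadsSize = childPads.length;  // 2 * number of spatial dims
  const merged = childPads.slice();
  for (let padsIndex = 2, childIndex = 0; padsIndex < padsSize / 2; padsIndex++, childIndex++) {
    merged[childIndex] += padValues[padsIndex];
    merged[childIndex + childPadsSize / 2] += padValues[padsIndex + padsSize / 2];
  }
  return merged;
}

// Example: Pad adds 1 pixel on each side of H and W; Conv already pads W by 1.
// Pad begins = [0,0,1,1], ends = [0,0,1,1]  ->  child pads become [1, 2, 1, 2].
console.log(mergePadsIntoChild([0, 0, 1, 1, 0, 0, 1, 1], [0, 1, 0, 1]));
```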
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
@@ -172,7 +172,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
nodes_falsenodeids.size() == nodes_values_as_tensor.size());
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
@@ -17,6 +17,7 @@
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
#include "migraphx_inc.h"
#include <hip/hip_version.h>

#include "migraphx_stream_handle.h"

@@ -1299,7 +1300,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
if (!input_shape_match) {
if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) {
LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl;
#ifndef ENABLE_TRAINING_CORE
#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2)
cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string());
#endif
#endif
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);

// Read in the calibration data and map it to an migraphx paramater map for the calibration ops