diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md
index 987e063485846..725d11b9d54c5 100644
--- a/js/web/docs/webnn-operators.md
+++ b/js/web/docs/webnn-operators.md
@@ -61,16 +61,16 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
 | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | |
 | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) |
 | Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | |
-| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✗ | ✓ | Input 'axes' if present should be a constant |
-| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✗ | ✓ | Input 'axes' if present should be a constant |
-| ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum| ✗ | ✓ | Input 'axes' if present should be a constant |
-| ReduceLogSumExp | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSumExp | ✗ | ✓ | Input 'axes' if present should be a constant |
+| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant |
+| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant |
+| ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum| ✓ | ✓ | Input 'axes' if present should be a constant |
+| ReduceLogSumExp | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSumExp | ✓ | ✓ | Input 'axes' if present should be a constant |
 | ReduceMax | ai.onnx(7-10, 11, 12, 13-17, 18-19, 20+) | reduceMax | ✓ | ✓ | Input 'axes' if present should be a constant |
 | ReduceMean | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceMean | ✓ | ✓ | Input 'axes' if present should be a constant |
 | ReduceMin | ai.onnx(7-10, 11, 12, 13-17, 18-19, 20+) | reduceMin | ✓ | ✓ | Input 'axes' if present should be a constant |
 | ReduceProd | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceProduct | ✓ | ✓ | Input 'axes' if present should be a constant |
 | ReduceSum | ai.onnx(7-10, 11-12, 13+) | reduceSum | ✓ | ✓ | Input 'axes' if present should be a constant |
-| ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✗ | ✓ | Input 'axes' if present should be a constant |
+| ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant |
 | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | |
 | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
 | Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes |
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 53c7f39bdd7f1..4298551aec412 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -71,7 +71,6 @@
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rocm_blas_alt_impl.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
-#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/transpose_optimizer.h"
@@ -215,9 +214,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
         transformers.emplace_back(std::make_unique());
       }
 
-      // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create
-      // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output
-      // or consume different initializers with same value, by default, CSE will not merge them.
+      // Put ConstantSharing before CommonSubexpressionElimination intentionally, as it can create more opportunities
+      // for CSE. For example, if A and B nodes consume different initializers with the same value, by default,
+      // CSE will not merge them.
       InlinedHashSet<std::string> excluded_initializers;
       excluded_initializers.reserve(session_options.initializers_to_share_map.size());
       for (const auto& p : session_options.initializers_to_share_map) {
@@ -225,7 +224,6 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       }
       const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
       transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
-      transformers.emplace_back(std::make_unique<ShapeInputMerge>());
       transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
       transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
                                                                   session_options.config_options));
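The ordering rationale in the comment above is easiest to see on a toy graph. Below is a standalone C++ sketch (not ONNX Runtime source; the Node struct and all names are invented for illustration) of a minimal hash-based CSE pass: each node is keyed on its op plus the identities of its inputs, so two Add nodes reading equal-but-distinct constants become mergeable only after the constants themselves have been deduplicated, which is what ConstantSharing does ahead of CSE.

```cpp
// Standalone sketch: why constant deduplication must run before CSE.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Node {
  std::string op;           // "Input", "Const" or an operator name
  double value = 0.0;       // payload, used by "Const" only
  std::vector<int> inputs;  // ids of input nodes
};

// One CSE-style pass: nodes with identical keys are merged; returns old->new ids.
std::map<int, int> Eliminate(std::vector<Node>& nodes, bool share_constants) {
  std::map<std::string, int> seen;
  std::map<int, int> remap;
  for (int id = 0; id < static_cast<int>(nodes.size()); ++id) {
    Node& n = nodes[id];
    for (int& in : n.inputs) {
      if (remap.count(in)) in = remap[in];  // follow earlier merges
    }
    std::string key = n.op;
    if (n.op == "Const") {
      // Without constant sharing, a Const is only ever equal to itself.
      key += share_constants ? "/" + std::to_string(n.value) : "/#" + std::to_string(id);
    }
    if (n.op == "Input") key += "/#" + std::to_string(id);
    for (int in : n.inputs) key += "," + std::to_string(in);
    auto [it, inserted] = seen.emplace(key, id);
    if (!inserted) remap[id] = it->second;
  }
  return remap;
}

int main() {
  // x + c0 and x + c1, where c0 and c1 are distinct constants with value 2.0.
  const std::vector<Node> graph = {
      {"Input", 0.0, {}},    // 0: x
      {"Const", 2.0, {}},    // 1: c0
      {"Const", 2.0, {}},    // 2: c1
      {"Add", 0.0, {0, 1}},  // 3: x + c0
      {"Add", 0.0, {0, 2}},  // 4: x + c1
  };
  auto g1 = graph, g2 = graph;
  std::cout << Eliminate(g1, false).size() << "\n";  // 0: nothing merges
  std::cout << Eliminate(g2, true).size() << "\n";   // 2: c1 -> c0, then Add -> Add
}
```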
diff --git a/onnxruntime/core/optimizer/shape_input_merge.cc b/onnxruntime/core/optimizer/shape_input_merge.cc
deleted file mode 100644
index dec1382319f16..0000000000000
--- a/onnxruntime/core/optimizer/shape_input_merge.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/optimizer/shape_input_merge.h"
-
-#include "core/graph/graph_utils.h"
-
-namespace onnxruntime {
-
-namespace {
-std::string GetShapeString(const NodeArg* input_arg) {
-  auto shape = input_arg->Shape();
-  if (!shape) return "";
-  std::stringstream ss;
-  ss << "[";
-  for (int i = 0; i < shape->dim_size(); ++i) {
-    if (i != 0) ss << ",";
-    auto dim = shape->dim(i);
-    if (dim.has_dim_value()) {
-      ss << std::to_string(dim.dim_value());
-    } else if (dim.has_dim_param()) {
-      ss << "'" << dim.dim_param() << "'";
-    } else {
-      return "";
-    }
-  }
-  ss << "]";
-  return ss.str();
-}
-
-}  // namespace
-
-Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
-  GraphViewer graph_viewer(graph);
-  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
-  InlinedHashMap<std::string, InlinedVector<Node*>> input_hash_to_nodes;
-  for (auto node_index : node_topology_list) {
-    auto* p_node = graph.GetNode(node_index);
-    if (!p_node) continue;  // we removed the node as part of an earlier fusion
-    ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger));
-    if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) ||
-        !graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) {
-      continue;
-    }
-    std::string shape_str = GetShapeString(p_node->InputDefs()[0]);
-    if (shape_str.empty()) continue;
-    if (input_hash_to_nodes.find(shape_str) == input_hash_to_nodes.end()) {
-      input_hash_to_nodes[shape_str] = InlinedVector<Node*>();
-    }
-    input_hash_to_nodes[shape_str].emplace_back(p_node);
-  }
-
-  // All Shape nodes are processed in topological order, so we can safely merge the inputs to the first node's input.
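For reference, the deleted pass grouped Shape nodes by a canonical string built from the producing NodeArg's (possibly symbolic) shape, per GetShapeString above. Below is a standalone rendition of that canonicalization, with std::variant standing in for the ONNX dimension protos (a sketch that mirrors, but does not reproduce, the ORT types):

```cpp
// Standalone sketch of the shape canonicalization the deleted pass used: known
// dims become numbers, symbolic dims become quoted names, unknown dims make the
// shape unusable. Two NodeArgs producing the same key would have had their
// Shape consumers rewired onto a single input.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <variant>
#include <vector>

using Dim = std::variant<int64_t, std::string>;  // dim value or symbolic name

std::string GetShapeKey(const std::vector<Dim>& shape) {
  std::ostringstream ss;
  ss << "[";
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i != 0) ss << ",";
    if (std::holds_alternative<int64_t>(shape[i])) {
      ss << std::get<int64_t>(shape[i]);
    } else {
      ss << "'" << std::get<std::string>(shape[i]) << "'";
    }
  }
  ss << "]";
  return ss.str();
}

int main() {
  std::vector<Dim> a = {std::string("dim0"), int64_t{512}, int64_t{1},
                        int64_t{1536}, std::string("dim4")};
  std::vector<Dim> b = a;  // e.g. the output of an elementwise op on the same input
  std::cout << GetShapeKey(a) << "\n";                      // ['dim0',512,1,1536,'dim4']
  std::cout << (GetShapeKey(a) == GetShapeKey(b)) << "\n";  // 1: Shape nodes would merge
}
```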
-  for (auto& kv : input_hash_to_nodes) {
-    if (kv.second.size() < 2) continue;
-    NodeArg* first_input_arg = kv.second[0]->MutableInputDefs()[0];
-    bool is_first_input_arg_graph_input = graph.IsInputsIncludingInitializers(first_input_arg);
-    for (size_t i = 1; i < kv.second.size(); ++i) {
-      Node* p_node = kv.second[i];
-      const NodeArg* input_arg = p_node->InputDefs()[0];
-      if (input_arg->Name() == first_input_arg->Name()) continue;
-      if (!graph.IsInputsIncludingInitializers(input_arg) && p_node->GetInputEdgesCount()) {
-        const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin();
-        graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0);
-      }
-      graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg);
-      if (!is_first_input_arg_graph_input && kv.second[0]->GetInputEdgesCount()) {
-        const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin();
-        graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0);
-      }
-      modified = true;
-    }
-  }
-
-  return Status::OK();
-}
-
-}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/shape_input_merge.h b/onnxruntime/core/optimizer/shape_input_merge.h
deleted file mode 100644
index 5cb943998487b..0000000000000
--- a/onnxruntime/core/optimizer/shape_input_merge.h
+++ /dev/null
@@ -1,23 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/optimizer/graph_transformer.h"
-
-namespace onnxruntime {
-
-/**
-@Class ShapeInputMerge
-Merge all shape inputs having same shape value to a single shape input.
-This change will not affect the performance, but it open chances for CSE fusion to merge nodes.
-*/
-class ShapeInputMerge : public GraphTransformer {
- public:
-  ShapeInputMerge(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("ShapeInputMerge", compatible_execution_providers) {}
-
-  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
-};
-
-}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
index cbea969904ed5..2fcf9a1d7d9ba 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
@@ -140,30 +140,44 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span
 void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
-                       std::optional<gsl::span<const int64_t>> shape) {
+                       std::optional<gsl::span<const int64_t>> shape, bool convert_scalar = false) {
   tensor_type.set_datatype(data_type);
   if (shape) {
-    tensor_type.set_rank(shape->size());
-    for (const auto& dim : *shape) {
-      if (dim >= 0) {
-        tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim));
-      } else {
-        tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+    auto rank = shape->size();
+    if (convert_scalar && rank == 0) {
+      // CoreML scalar has shape {1}
+      tensor_type.set_rank(1);
+      tensor_type.add_dimensions()->mutable_constant()->set_size(1);
+    } else {
+      tensor_type.set_rank(rank);
+      for (const auto& dim : *shape) {
+        if (dim >= 0) {
+          tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim));
+        } else {
+          tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+        }
       }
     }
   }
 }
 
 void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
-                       const ONNX_NAMESPACE::TensorShapeProto* shape) {
+                       const ONNX_NAMESPACE::TensorShapeProto* shape, bool convert_scalar = false) {
   tensor_type.set_datatype(data_type);
   if (shape) {
-    tensor_type.set_rank(shape->dim_size());
-    for (const auto& dim : shape->dim()) {
-      if (dim.has_dim_value()) {
-        tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value()));
-      } else {
-        tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+    auto rank = shape->dim_size();
+    if (convert_scalar && rank == 0) {
+      // CoreML scalar has shape {1}
+      tensor_type.set_rank(1);
+      tensor_type.add_dimensions()->mutable_constant()->set_size(1);
+    } else {
+      tensor_type.set_rank(rank);
+      for (const auto& dim : shape->dim()) {
+        if (dim.has_dim_value()) {
+          tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value()));
+        } else {
+          tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+        }
       }
     }
   }
@@ -281,13 +295,13 @@ template MILSpec::Value CreateScalarTensorValue(const int32_t& data);
 template MILSpec::Value CreateScalarTensorValue(const std::string& data);
 template MILSpec::Value CreateScalarTensorValue(const bool& data);
 
-COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) {
+COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar) {
   MILSpec::NamedValueType nvt;
   nvt.set_name(node_arg.Name());
   MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype();
   SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()),
-                    node_arg.Shape());
+                    node_arg.Shape(), convert_scalar);
 
   return nvt;
 }
@@ -308,7 +322,7 @@ void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output)
   MILSpec::TensorType& tensor_type = *value.mutable_tensortype();
   SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()),
-                    output.Shape());
+                    output.Shape(), /*convert_scalar*/ true);
 }
 
 void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
index 2804589065631..3e6c43ab07867 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
+++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
@@ -114,8 +114,10 @@ template <typename T>
 COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data);
 
 /// <summary>Create a NamedValueType from an ONNX tensor NodeArg.</summary>
+/// <param name="node_arg">NodeArg to create NamedValueType from.</param>
+/// <param name="convert_scalar">If true, scalar shapes are converted to 1D.</param>
 /// <remarks>Used to create inputs for the 'main' function in an ML Program.</remarks>
-COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg);
+COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar = false);
 
 /// <summary>
 /// Add an input argument to a MILSpec::Operation
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index eb4723a3b9746..88b518ab2289c 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -838,13 +838,8 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) {
   if (create_ml_program_) {
     if (is_input) {
       // the model inputs need to be wired up as args to the 'main' function.
-      auto tensor_value_type = CreateNamedTensorValueType(node_arg);
+      auto tensor_value_type = CreateNamedTensorValueType(node_arg, /*convert_scalar*/ true);
       tensor_value_type.set_name(name);
-      if (node_arg.Shape()->dim_size() == 0) {
-        // update shape from {} to {1} (same change we made at the model input level above).
-        tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1);
-        tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1);
-      }
 
       mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type));
     } else {
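Taken together, the three CoreML changes above replace a hand-rolled rank-0 fixup with a single convert_scalar flag. A standalone sketch of the rule itself follows (plain vectors here; the real code writes CoreML protobuf dimensions, so this is an illustration, not the actual API):

```cpp
// Standalone sketch of the rank-0 rule the patch centralizes: an ONNX scalar
// has shape {}, but an ML Program tensor needs at least one dimension, so {}
// is emitted as {1} when convert_scalar is set.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> ToCoreMLDims(const std::vector<int64_t>& onnx_shape, bool convert_scalar) {
  if (convert_scalar && onnx_shape.empty()) {
    return {1};  // CoreML scalar has shape {1}
  }
  return onnx_shape;  // negative entries would map to "unknown" dimensions instead
}

int main() {
  std::cout << ToCoreMLDims({}, /*convert_scalar*/ true).size() << "\n";   // 1
  std::cout << ToCoreMLDims({}, /*convert_scalar*/ false).size() << "\n";  // 0
}
```

Folding the rule into SetTensorTypeInfo and CreateNamedTensorValueType is what lets the model_builder.cc hunk above delete its manual {} to {1} patch-up: the same conversion now happens wherever a tensor type is emitted.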
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 395a0b40e5bbb..4ee3f891f92ca 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -209,16 +209,16 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
     {"Pow", {"pow", true}},
     {"PRelu", {"prelu", true}},
     {"Reciprocal", {"reciprocal", true}},
-    {"ReduceL1", {"reduceL1", false}},
-    {"ReduceL2", {"reduceL2", false}},
-    {"ReduceLogSum", {"reduceLogSum", false}},
-    {"ReduceLogSumExp", {"reduceLogSumExp", false}},
+    {"ReduceL1", {"reduceL1", true}},
+    {"ReduceL2", {"reduceL2", true}},
+    {"ReduceLogSum", {"reduceLogSum", true}},
+    {"ReduceLogSumExp", {"reduceLogSumExp", true}},
     {"ReduceMax", {"reduceMax", true}},
     {"ReduceMean", {"reduceMean", true}},
     {"ReduceMin", {"reduceMin", true}},
     {"ReduceProd", {"reduceProduct", true}},
     {"ReduceSum", {"reduceSum", true}},
-    {"ReduceSumSquare", {"reduceSumSquare", false}},
+    {"ReduceSumSquare", {"reduceSumSquare", true}},
     {"Relu", {"relu", true}},
     {"Reshape", {"reshape", true}},
     {"Resize", {"resample2d", true}},
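Matching the docs table at the top of the patch, the boolean in each op_map entry appears to record whether the WebNN CPU backend supports the op, which is the flag these hunks flip to true. A standalone sketch of how such a table gates node support (the struct layout and the IsOpSupported helper are illustrative, not the EP's actual API):

```cpp
// Standalone sketch of an (ONNX op -> WebNN op, CPU-backend support) table and
// the lookup that decides whether the EP claims a node on a given backend.
#include <iostream>
#include <string>
#include <unordered_map>

struct WebnnOpInfo {
  std::string webnn_op;  // WebNN operator name
  bool cpu_supported;    // supported by the WebNN CPU backend?
};

const std::unordered_map<std::string, WebnnOpInfo> kOpMap = {
    {"ReduceL1", {"reduceL1", true}},  // was false before the patch
    {"ReduceSumSquare", {"reduceSumSquare", true}},
    {"Resize", {"resample2d", true}},
};

bool IsOpSupported(const std::string& onnx_op, bool is_cpu_backend) {
  auto it = kOpMap.find(onnx_op);
  if (it == kOpMap.end()) return false;  // op not mapped at all
  return !is_cpu_backend || it->second.cpu_supported;
}

int main() {
  std::cout << IsOpSupported("ReduceL1", /*is_cpu_backend*/ true) << "\n";  // 1
}
```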
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 9609ec57c8b26..f83fb8238ff61 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -70,7 +70,6 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
-#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
 #include "core/optimizer/utils.h"
@@ -7691,80 +7690,6 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
   }
 }
 
-TEST_F(GraphTransformationTests, ShapeInputMerge) {
-  auto build_test_case = [&](ModelTestBuilder& builder) {
-    std::vector<std::variant<int64_t, std::string>> input_shape;
-    input_shape.reserve(5);
-    input_shape.emplace_back("dim0");
-    input_shape.emplace_back(512);
-    input_shape.emplace_back(1);
-    input_shape.emplace_back(1536);
-    input_shape.emplace_back("dim4");
-    auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
-    auto* neg_out = builder.MakeIntermediate();
-    auto* axes_initializer = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
-    auto* squeeze_out = builder.MakeIntermediate();
-    auto* cast_out = builder.MakeIntermediate();
-    auto* unsqueeze_out = builder.MakeOutput();
-    auto* shape_1_out = builder.MakeOutput();
-    auto* shape_2_out = builder.MakeOutput();
-    auto* shape_3_out = builder.MakeOutput();
-    auto* shape_4_out = builder.MakeOutput();
-    auto* shape_5_out = builder.MakeOutput();
-    builder.AddNode("Neg", {input_arg}, {neg_out});
-    builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out});
-    builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast<int64_t>(10));
-    builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out});
-    builder.AddNode("Shape", {input_arg}, {shape_1_out});
-    builder.AddNode("Shape", {neg_out}, {shape_2_out});
-    builder.AddNode("Shape", {squeeze_out}, {shape_3_out});
-    builder.AddNode("Shape", {cast_out}, {shape_4_out});
-    builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out});
-  };
-
-  auto pre_graph_checker = [&](Graph& graph) {
-    InlinedHashMap<std::string, int> ref_count;
-    for (auto& node : graph.Nodes()) {
-      if (node.OpType() == "Shape") {
-        std::string name = node.InputDefs()[0]->Name();
-        if (ref_count.find(name) == ref_count.end()) {
-          ref_count[name] = 1;
-        } else {
-          ref_count[name]++;
-        }
-      }
-    }
-    TEST_RETURN_IF_NOT(ref_count.size() == 5);
-    return Status::OK();
-  };
-
-  auto post_graph_checker = [&](Graph& graph) {
-    InlinedHashMap<std::string, int> ref_count;
-    for (auto& node : graph.Nodes()) {
-      if (node.OpType() == "Shape") {
-        std::string name = node.InputDefs()[0]->Name();
-        if (ref_count.find(name) == ref_count.end()) {
-          ref_count[name] = 1;
-        } else {
-          ref_count[name]++;
-        }
-      }
-    }
-    TEST_RETURN_IF_NOT(ref_count.size() == 2);
-    int sum = 0, mul = 1;
-    for (auto& entry : ref_count) {
-      sum += entry.second;
-      mul *= entry.second;
-    }
-    TEST_RETURN_IF_NOT(sum == 5 && mul == 6);
-    return Status::OK();
-  };
-
-  std::unique_ptr<GraphTransformer> transformer = std::make_unique<ShapeInputMerge>();
-  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
-                                        1, pre_graph_checker, post_graph_checker));
-}
-
 #if !defined(DISABLE_CONTRIB_OPS)
 
 TEST_F(GraphTransformationTests, MatMulNBitsBiasFusion) {
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 436a24c34ddfd..589e7be455dbc 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -44,7 +44,6 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
-#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
@@ -117,11 +116,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique()));
 #endif
 
-      // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create
-      // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output
-      // or consume different initializers with same value, by default, CSE will not merge them.
+      // Put ConstantSharing before CommonSubexpressionElimination intentionally, as it can create more opportunities
+      // for CSE. For example, if A and B nodes consume different initializers with the same value, by default,
+      // CSE will not merge them.
       transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
-      transformers.emplace_back(std::make_unique<ShapeInputMerge>(compatible_eps));
       // LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input.
       transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
       // Remove duplicate nodes. Must be applied before any recompute transformations.
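A quick standalone check of the deleted test's post-condition arithmetic: once the five Shape nodes are left referencing two distinct inputs, reference counts {2, 3} are the only positive split with sum 5 and product 6, so the pair of assertions pins down the exact merge result.

```cpp
// Standalone verification that sum == 5 && mul == 6 uniquely identifies the
// {2, 3} split of five Shape nodes across two distinct inputs.
#include <iostream>

int main() {
  for (int a = 1; a <= 4; ++a) {
    int b = 5 - a;
    if (a * b == 6) std::cout << a << "+" << b << "\n";  // prints 2+3 and 3+2
  }
}
```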
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
index 88735ff18515e..35c5b736bd962 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
@@ -148,8 +148,8 @@ def test_onnx_ops(self):
 
     @unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
     def test_softmax_bf16_large(self):
-        if not torch.cuda.is_available():
-            # only test bf16 on cuda
+        if torch.version.cuda is None:
+            # Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen.
             return
 
         class Model(torch.nn.Module):
@@ -175,7 +175,7 @@ def forward(self, input):
         data_ort.requires_grad = True
         ort_res = ort_model(input=data_ort)
         ort_res.backward(gradient=init_grad)
 
-        # compara result
+        # compare result
         torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)