Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium authored Apr 19, 2024
2 parents 2911a63 + 9001c69 commit 0e4ad28
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="python" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.14.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.14.1" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="pythonx86" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.14.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.14.1" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.14.0)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.14.1)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
Expand Down
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,7 @@ Do not modify directly.*
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(float), tensor(float16)|
|||7+|**T** = tensor(float), tensor(float16)|
|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float)|
|||1+|**T** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|||1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**<br> *out* Y:**T**|16+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Less|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
Expand Down Expand Up @@ -1224,6 +1224,7 @@ Do not modify directly.*
|||6+|**T** = tensor(float), tensor(float16)|
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
Expand Down Expand Up @@ -1306,6 +1307,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
|**Operator Domain:** *com.microsoft.dml*||||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Dml
class DmlOperatorLayerNormalization : public DmlOperator
{
public:
DmlOperatorLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
DmlOperatorLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext, bool simplified)
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2};
Expand Down Expand Up @@ -128,17 +128,18 @@ class DmlOperatorLayerNormalization : public DmlOperator
outputCastOpDesc.Desc = &outputCastDesc;
}

DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {};
DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputCastOpDesc.Desc ? &inputCastOutputDmlTensorDesc : &inputDesc;
operatorDesc.ScaleTensor = scaleCastOpDesc.Desc ? &scaleCastOutputDmlTensorDesc : &scaleDesc;
operatorDesc.BiasTensor = biasCastOpDesc.Desc ? &biasCastOutputDmlTensorDesc : (biasDesc.Desc ? &biasDesc : nullptr);
operatorDesc.OutputTensor = outputCastOpDesc.Desc ? &outputCastOutputDmlTensorDesc : &outputDesc;
operatorDesc.Axes = onnxAxes.data();
operatorDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
operatorDesc.NormalizeVariance = true;
operatorDesc.UseMean = !simplified;
operatorDesc.UseVariance = true;
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2, &operatorDesc };

// Construct the graph
std::vector<const DML_OPERATOR_DESC*> opDescs;
Expand Down Expand Up @@ -258,7 +259,19 @@ void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* con
*isSupported = context->GetOutputCount() == 1;
}

DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, DmlOperatorLayerNormalization);
DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization17, DmlOperatorLayerNormalization);
// A specific type of operation for registration.
template <bool simplified>
class LayerNormalizationTemplate : public DmlOperatorLayerNormalization
{
public:
LayerNormalizationTemplate(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperatorLayerNormalization(kernelCreationContext, simplified)
{
}
};

DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, LayerNormalizationTemplate<false>);
DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization17, LayerNormalizationTemplate<false>);
DML_OP_DEFINE_CREATION_FUNCTION(SimplifiedLayerNormalization, LayerNormalizationTemplate<true>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
namespace Dml
{

template <bool simplified>
class DmlOperatorSkipLayerNormalization : public DmlOperator
{
public:
Expand Down Expand Up @@ -83,17 +84,18 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator
inputSkipBiasAddDesc.OutputTensor = &inputDesc;
DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };

DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC mvnDesc = {};
mvnDesc.InputTensor = &inputDesc;
mvnDesc.ScaleTensor = &gammaDesc;
mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
mvnDesc.OutputTensor = &outputDesc;
mvnDesc.Axes = axes.data();
mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(axes.size());
mvnDesc.NormalizeVariance = true;
mvnDesc.UseMean = !simplified;
mvnDesc.UseVariance = true;
mvnDesc.Epsilon = epsilon;
mvnDesc.FusedActivation = nullptr;
DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };
DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2, &mvnDesc };

// Construct the graph
std::vector<const DML_OPERATOR_DESC*> opDescs;
Expand Down Expand Up @@ -223,6 +225,7 @@ void CALLBACK QuerySkipLayerNormalization(IMLOperatorSupportQueryContextPrivate*
*isSupported = true;
}

DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);
DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization<false>);
DML_OP_DEFINE_CREATION_FUNCTION(SkipSimplifiedLayerNormalization, DmlOperatorSkipLayerNormalization<true>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(BiasAdd);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(GroupNorm);
DML_OP_EXTERN_CREATION_FUNCTION(SimplifiedLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(SkipSimplifiedLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(RNN);
DML_OP_EXTERN_CREATION_FUNCTION(GRU);
Expand Down Expand Up @@ -548,7 +550,7 @@ constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListRotaryEmbedding = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
constexpr static std::array<const char*, 3> typeNameListLayerNormContrib = { "T", "U", "V" };
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
Expand Down Expand Up @@ -612,7 +614,7 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListIntege
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerNormalization = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
Expand Down Expand Up @@ -1110,7 +1112,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO( 7, SimplifiedLayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
{REG_INFO_MS( 1, SkipSimplifiedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, BiasSplitGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1613,7 +1613,9 @@ using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper;
using ShapeInferenceHelper_SkipSimplifiedLayerNormalization = SkipLayerNormHelper;
using ShapeInferenceHelper_EmbedLayerNormalization = EmbedLayerNormalizationHelper;
using ShapeInferenceHelper_SimplifiedLayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Upsample = 7;
static const int sc_sinceVer_Xor = 7;
static const int sc_sinceVer_LayerNormalization = 1;
static const int sc_sinceVer_SimplifiedLayerNormalization = 1;

// Special operators
static const int sc_sinceVer_MemcpyToHost = 1;
Expand Down Expand Up @@ -454,6 +455,7 @@ namespace OperatorHelper
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_SkipSimplifiedLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
static const int sc_sinceVer_BiasSplitGelu = 1;
static const int sc_sinceVer_NhwcConv = 1;
Expand Down
18 changes: 17 additions & 1 deletion onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ struct OrtVitisAIEpAPI {
const char* json_config);
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_with_options)(
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options);
uint32_t (*vaip_get_version)();
void Ensure() {
if (handle_)
return;
Expand All @@ -65,6 +66,8 @@ struct OrtVitisAIEpAPI {
::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
ORT_THROW(status1);
}
std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version",
(void**)&vaip_get_version);
}

private:
Expand Down Expand Up @@ -177,8 +180,17 @@ void initialize_vitisai_ep() {
create_kernel_registry(s_domains_vitisaiep);
}

static void set_version_info(vaip_core::OrtApiForVaip& api) {
const char* magic = "VAIP";
std::memcpy(reinterpret_cast<char*>(&api.magic), magic, sizeof(api.magic));
api.major = 1u;
api.minor = 0u;
api.patch = 0u;
}

vaip_core::OrtApiForVaip* create_org_api_hook() {
InitProviderOrtApi();
set_version_info(the_global_api);
the_global_api.host_ = Provider_GetHost();
assert(Ort::Global<void>::api_ != nullptr);
the_global_api.ort_api_ = Ort::Global<void>::api_;
Expand Down Expand Up @@ -359,5 +371,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
the_global_api.get_lib_id = []() -> vaip_core::DllSafe<std::string> {
return vaip_core::DllSafe(std::string(GIT_COMMIT_ID));
};
return &the_global_api;
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
return &the_global_api;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ struct OrtApi;
namespace vaip_core {

struct OrtApiForVaip {
uint32_t magic; // 'VAIP' or something else to make sure the following field
// are not garbage.
uint32_t major; // bump this field changes that are not backward compatible or
// that represent a change in direction for the project
uint32_t minor; // bump this field for adding new features without breaking
// existing behavior
uint32_t patch; // bump this field for fixing some bugs but not introducing
// new functionality
onnxruntime::ProviderHost* host_;
const OrtApi* ort_api_;
// model
Expand Down
10 changes: 6 additions & 4 deletions onnxruntime/test/contrib_ops/layer_norm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ static void TestLayerNorm(const std::vector<int64_t>& x_dims,
// TODO keep_dims is not implemented, default behavior is to keep ones for reduced dimensions
ASSERT_NE(keep_dims, 0);

const std::vector<int64_t>& stats_dims = keep_dims ? n_and_ones_dims : n_dims;

CompareOpTester test(op.c_str(), opset);
test.AddAttribute("axis", axis);
test.AddAttribute("keep_dims", keep_dims);
Expand All @@ -65,16 +63,20 @@ static void TestLayerNorm(const std::vector<int64_t>& x_dims,
}

std::vector<float> Y_data = FillZeros<float>(n_x_m_dims);
test.AddOutput<float>("output", n_x_m_dims, Y_data);

#ifndef USE_DML
// DML doesn't support more than one output for these ops yet
const std::vector<int64_t>& stats_dims = keep_dims ? n_and_ones_dims : n_dims;
std::vector<float> mean_data = FillZeros<float>(stats_dims);
std::vector<float> var_data = FillZeros<float>(stats_dims);

test.AddOutput<float>("output", n_x_m_dims, Y_data);

// the Main and InvStdDev outputs are training specific
if (op.compare(SIMPLIFIED_LAYER_NORM_OP) != 0) {
test.AddOutput<float>("mean", stats_dims, mean_data);
}
test.AddOutput<float>("var", stats_dims, var_data);
#endif

#ifdef USE_CUDA
test.CompareWithCPU(kCudaExecutionProvider);
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,6 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) {
true);
}

// SkipSimplifiedLayerNorm has not been enabled for DML yet
#if !defined(USE_DML)
TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
int batch_size = 1;
int sequence_length = 2;
Expand Down Expand Up @@ -768,9 +766,8 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
true,
true);
}
#endif

#if !defined(USE_ROCM) && !defined(USE_DML)
#if !defined(USE_ROCM)
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
int batch_size = 2;
int sequence_length = 2;
Expand Down
Loading

0 comments on commit 0e4ad28

Please sign in to comment.