Skip to content

Commit

Permalink
De-duplicate 1D scale and zero point tensors to scalars in DML kernels (
Browse files Browse the repository at this point in the history
#18862)

### Description
Cleanup and rebase from [this
PR](#18629)



### Motivation and Context

---------

Co-authored-by: Christian Larson <[email protected]>
Co-authored-by: Christian Larson <[email protected]>
Co-authored-by: Jeff Bloomfield <[email protected]>
Co-authored-by: Anagha Rao <[email protected]>
  • Loading branch information
5 people committed Jan 4, 2024
1 parent bdaeebd commit 70d3f68
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ namespace Windows::AI::MachineLearning::Adapter
{
uint32_t nodeCount = 0;
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;

// TODO (jeffbloo): Remove this
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;

std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

namespace Dml
{

/*static*/ const uint32_t DmlOperator::zeroArray[8] = {};

DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
{
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
Expand Down Expand Up @@ -824,4 +827,84 @@ namespace Dml
graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}

/*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
const MLOperatorKernelCreationContext& kernelInfo,
const DML_TENSOR_DESC* tensor,
uint32_t kernelInputIndex)
{
if (!tensor)
{
return;
}

auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(kernelInputIndex);
if (!constExpTensor)
{
return;
}
else if (!constExpTensor->IsCpuData())
{
return;
}

uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
if (totalKernelInputElementCount <= 1)
{
return;
}

uint32_t elementSize = 0;

switch (constExpTensor->GetTensorDataType())
{
case MLOperatorTensorDataType::UInt8:
case MLOperatorTensorDataType::Int8:
elementSize = 1;
break;

case MLOperatorTensorDataType::Float16:
case MLOperatorTensorDataType::UInt16:
case MLOperatorTensorDataType::Int16:
elementSize = 2;
break;

case MLOperatorTensorDataType::/*Float32*/Float:
case MLOperatorTensorDataType::UInt32:
case MLOperatorTensorDataType::Int32:
elementSize = 4;
break;

case MLOperatorTensorDataType::/*Float64*/Double:
case MLOperatorTensorDataType::UInt64:
case MLOperatorTensorDataType::Int64:
elementSize = 8;
break;

default:
return;
}

const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());

assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
auto *bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));

for (size_t i = 1; i < totalKernelInputElementCount; ++i)
{
if (memcmp(byteData, byteData + i * elementSize, elementSize))
{
return;
}
}

if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
{
assert(false);
return;
}

bufferTensorDesc->Strides = zeroArray;
bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3;
}

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ namespace Dml
uint32_t minDimensionCount = NchwDimensionCount
) const;

static void TryConvertTensorToBroadcastScalar(
const MLOperatorKernelCreationContext& kernelInfo,
const DML_TENSOR_DESC* tensor,
uint32_t kernelInputIndex);

private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
Expand All @@ -164,6 +169,7 @@ namespace Dml
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);

static const uint32_t zeroArray[8];
};

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

// TODO (jeffbloo): Port this to a graph description to enable DML graph optimization

dml::Graph graph(m_dmlDevice.Get());
dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,9 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);

SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ namespace Dml

class DmlOperatorQLinearAdd : public DmlOperator
{
enum InputTensors {
IN_A,
enum InputTensors {
IN_A,
IN_A_SCALE,
IN_A_ZERO_POINT,
IN_B,
IN_A_ZERO_POINT,
IN_B,
IN_B_SCALE,
IN_B_ZERO_POINT,
IN_C_SCALE,
IN_C_ZERO_POINT
IN_C_SCALE,
IN_C_ZERO_POINT
};

public:
Expand Down Expand Up @@ -56,9 +56,18 @@ class DmlOperatorQLinearAdd : public DmlOperator
AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE];
AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr;
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
Expand All @@ -129,6 +129,12 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);

TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);

TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);

DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];

TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1);
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2);

dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
opDescs.push_back(&dmlOpDesc[inputIndex]);
Expand Down Expand Up @@ -154,6 +157,10 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint);

const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
opDescs.push_back(&opQuantizeDesc);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ class DmlOperatorQLinearConv : public DmlOperator, public ConvolutionHelperBase
convDesc.EndPadding = kernelArgs.endPadding;
convDesc.GroupCount = m_groupCount;

TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class DmlOperatorQLinearMatMul : public DmlOperator
matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr;
matMulDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale];
dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point];
dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc;

TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point);

const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc};

Expand All @@ -101,6 +104,10 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point);

const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};

MLOperatorGraphDesc operatorGraphDesc = {};
Expand Down

0 comments on commit 70d3f68

Please sign in to comment.