diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp index cdc06b074b14a..51b19603f5122 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp @@ -51,7 +51,7 @@ // 3. Multiply elementwise every input tensor to compute the internal product. // 4. Sum reduce the product tensor to the final output shape, reducing along any missing dimensions. // So a product shape of [b,j,i,k] and output shape of [b,i,k] reduces along j. -// +// // ReduceSum( // Mul( // ExpandTransposeCollapseAsNeeded(A, aAxesToProductAxes), @@ -90,7 +90,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper uint32_t bindableInputCount = kernelCreationContext.GetInputCount(); if (IsMatMulOperatorType()) { - ++bindableInputCount; // Account for the optional C tensor. + ++bindableInputCount; // Account for the optional C tensor. } inputIndices.resize(bindableInputCount); @@ -231,8 +231,8 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper const DML_OPERATOR_DESC* operatorDescPointers[2] = { - &multiplyOperatorDescWithEnum, // NodeIndexMultiply - &reduceSumOperatorDescWithEnum, // NodeIndexReduceSum + &multiplyOperatorDescWithEnum, // NodeIndexMultiply + &reduceSumOperatorDescWithEnum, // NodeIndexReduceSum }; DML_INPUT_GRAPH_EDGE_DESC inputEdges[2]; @@ -280,9 +280,9 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper // Reproject all inputs and the output to the intermediate product tensor. // e.g. - // + // // Equation: i,j->ji - // + // // [1] [4,5,6,7] [4, 8,12] // [2] -> [5,10,15] // [3] [6,12,18] @@ -347,14 +347,14 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper // Set default sizes for shape compatibility with the product tensor, and // set strides to 0's initially to broadcast any missing dimensions. std::vector newSizes; - std::vector newStrides(newRank, 0u); // Default to 0 to broadcast missing entries. + std::vector newStrides(newRank, 0u); // Default to 0 to broadcast missing entries. if (isReduced) { - newSizes.resize(newRank, 1u); // Fill with 1's initially for any missing (reduced) dimensions. + newSizes.resize(newRank, 1u); // Fill with 1's initially for any missing (reduced) dimensions. } else { - newSizes = m_productDimensions; // Use the product tensor shape directly. Missing axes will be broadcasted. + newSizes = m_productDimensions; // Use the product tensor shape directly. Missing axes will be broadcasted. } // Scatter the original sizes and strides into the corresponding product tensor axis. @@ -364,7 +364,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper if (productAxis < newRank) { newSizes[productAxis] = originalSizes[i]; - newStrides[productAxis] += originalStrides[i]; // Add to combine diagonal cases like i,j,i->i,j + newStrides[productAxis] += originalStrides[i]; // Add to combine diagonal cases like i,j,i->i,j } } tensorDesc.SetDimensionsAndStrides(newSizes, newStrides); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index f2fba8f2da455..f5660f0b14b7f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -52,7 +52,7 @@ namespace Dml // Rearranges existing m_sizes and m_strides by gathering axes from dimensionMapping. // It IS legal to change the number of dimensions by adding filler, dropping entire dimensions for a new view, - // and even duplicate logical dimensions. Axes beyond the original rank will be filled by size 1 and stride 0. + // and even duplicating logical dimensions. Axes beyond the original rank will be filled by size 1 and stride 0. // e.g. Existing sizes [2,3,4] with [2,0] yields [4,2]. void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment);