Skip to content

Commit

Permalink
Polish and lint appeasement
Browse files Browse the repository at this point in the history
  • Loading branch information
fdwr committed Jun 20, 2024
1 parent 17b9461 commit b084e3e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
// 3. Multiply elementwise every input tensor to compute the internal product.
// 4. Sum reduce the product tensor to the final output shape, reducing along any missing dimensions.
// So a product shape of [b,j,i,k] and output shape of [b,i,k] reduces along j.
//
//
// ReduceSum(
// Mul(
// ExpandTransposeCollapseAsNeeded(A, aAxesToProductAxes),
Expand Down Expand Up @@ -90,7 +90,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
uint32_t bindableInputCount = kernelCreationContext.GetInputCount();
if (IsMatMulOperatorType())
{
++bindableInputCount; // Account for the optional C tensor.
++bindableInputCount; // Account for the optional C tensor.
}
inputIndices.resize(bindableInputCount);

Expand Down Expand Up @@ -231,8 +231,8 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper

const DML_OPERATOR_DESC* operatorDescPointers[2] =
{
&multiplyOperatorDescWithEnum, // NodeIndexMultiply
&reduceSumOperatorDescWithEnum, // NodeIndexReduceSum
&multiplyOperatorDescWithEnum, // NodeIndexMultiply
&reduceSumOperatorDescWithEnum, // NodeIndexReduceSum
};

DML_INPUT_GRAPH_EDGE_DESC inputEdges[2];
Expand Down Expand Up @@ -280,9 +280,9 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper

// Reproject all inputs and the output to the intermediate product tensor.
// e.g.
//
//
// Equation: i,j->ji
//
//
// [1] [4,5,6,7] [4, 8,12]
// [2] -> [5,10,15]
// [3] [6,12,18]
Expand Down Expand Up @@ -347,14 +347,14 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
// Set default sizes for shape compatibility with the product tensor, and
// set strides to 0's initially to broadcast any missing dimensions.
std::vector<uint32_t> newSizes;
std::vector<uint32_t> newStrides(newRank, 0u); // Default to 0 to broadcast missing entries.
std::vector<uint32_t> newStrides(newRank, 0u); // Default to 0 to broadcast missing entries.
if (isReduced)
{
newSizes.resize(newRank, 1u); // Fill with 1's initially for any missing (reduced) dimensions.
newSizes.resize(newRank, 1u); // Fill with 1's initially for any missing (reduced) dimensions.
}
else
{
newSizes = m_productDimensions; // Use the product tensor shape directly. Missing axes will be broadcasted.
newSizes = m_productDimensions; // Use the product tensor shape directly. Missing axes will be broadcasted.
}

// Scatter the original sizes and strides into the corresponding product tensor axis.
Expand All @@ -364,7 +364,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
if (productAxis < newRank)
{
newSizes[productAxis] = originalSizes[i];
newStrides[productAxis] += originalStrides[i]; // Add to combine diagonal cases like i,j,i->i,j
newStrides[productAxis] += originalStrides[i]; // Add to combine diagonal cases like i,j,i->i,j
}
}
tensorDesc.SetDimensionsAndStrides(newSizes, newStrides);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace Dml

// Rearranges existing m_sizes and m_strides by gathering axes from dimensionMapping.
// It IS legal to change the number of dimensions by adding filler, dropping entire dimensions for a new view,
// and even duplicate logical dimensions. Axes beyond the original rank will be filled by size 1 and stride 0.
// and even duplicating logical dimensions. Axes beyond the original rank will be filled by size 1 and stride 0.
// e.g. Existing sizes [2,3,4] with [2,0] yields [4,2].
void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);

Expand Down

0 comments on commit b084e3e

Please sign in to comment.