DML EP EinSum extend more generically
fdwr committed Jun 19, 2024
1 parent fff68c3 commit f3a4bee
Showing 8 changed files with 410 additions and 260 deletions.
@@ -32,7 +32,7 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType
     };
 }
 
-bool IsSigned(DML_TENSOR_DATA_TYPE dataType)
+bool IsSigned(DML_TENSOR_DATA_TYPE dataType) noexcept
 {
     switch (dataType)
     {
@@ -140,6 +140,32 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
     return deviceTypeMask;
 }
 
+uint32_t GetBitMaskFromIndices(gsl::span<const uint32_t> indices) noexcept
+{
+    uint32_t bitMask = 0;
+    for (auto i : indices)
+    {
+        assert(i < 32);
+        bitMask |= (1 << i);
+    }
+    return bitMask;
+}
+
+uint32_t CountLeastSignificantZeros(uint32_t value) noexcept
+{
+    // Use std::countr_zero instead when the codebase is updated to C++20.
+    // Until then, use a bit-twiddling hack rather than a loop.
+    uint32_t count = 32;
+    value &= -int32_t(value);
+    if (value) count--;
+    if (value & 0x0000FFFF) count -= 16;
+    if (value & 0x00FF00FF) count -= 8;
+    if (value & 0x0F0F0F0F) count -= 4;
+    if (value & 0x33333333) count -= 2;
+    if (value & 0x55555555) count -= 1;
+    return count;
+}
+
 void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides)
 {
     assert(sizes.size() == strides.size());
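The two helpers added above are small bit utilities, presumably in service of the more generic EinSum handling: GetBitMaskFromIndices folds a set of axis indices into a 32-bit mask, and CountLeastSignificantZeros counts trailing zeros by isolating the lowest set bit (value &= -value) and then subtracting from 32 one mask test at a time. A minimal sanity-check sketch, assuming both helpers are visible at the call site (the test function itself is hypothetical, not part of the commit):

#include <cassert>
#include <cstdint>

void TestBitHelpers()
{
    // Axis indices {0, 2, 5} set bits 0, 2, and 5.
    const uint32_t axes[] = {0, 2, 5};
    assert(GetBitMaskFromIndices(axes) == 0b100101u);

    // Trailing-zero counts, matching std::countr_zero's convention for zero input.
    assert(CountLeastSignificantZeros(0) == 32);          // No set bits at all.
    assert(CountLeastSignificantZeros(1) == 0);           // Bit 0 already set.
    assert(CountLeastSignificantZeros(0b10100) == 2);     // Lowest set bit is bit 2.
    assert(CountLeastSignificantZeros(0x80000000) == 31); // Only the top bit is set.
}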
@@ -23,9 +23,11 @@ namespace Dml
     size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType);
     size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor);
     uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice);
+    uint32_t GetBitMaskFromIndices(gsl::span<const uint32_t> indices) noexcept;
+    uint32_t CountLeastSignificantZeros(uint32_t value) noexcept;
     void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides);
 
-    bool IsSigned(DML_TENSOR_DATA_TYPE dataType);
+    bool IsSigned(DML_TENSOR_DATA_TYPE dataType) noexcept;
 
     template <typename T>
     void CastToClampedScalarUnion(DML_TENSOR_DATA_TYPE dataType, T value, DML_SCALAR_UNION* outputValue)

Large diffs are not rendered by default.

@@ -212,15 +212,15 @@ gsl::span<const uint32_t> TensorDesc::GetStrides() const
 
 void TensorDesc::SetStrides(gsl::span<const uint32_t> strides)
 {
-    m_bufferTensorDesc.Strides = strides.empty() ? nullptr : strides.data();
-
     if (!strides.empty())
     {
         ML_CHECK_VALID_ARGUMENT(strides.size() <= std::size(m_strides));
         ML_CHECK_VALID_ARGUMENT(strides.size() == m_bufferTensorDesc.DimensionCount);
         std::copy(strides.begin(), strides.end(), m_strides);
     }
 
+    m_bufferTensorDesc.Strides = strides.empty() ? nullptr : m_strides;
+
     m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(
         m_bufferTensorDesc.DataType,
         m_bufferTensorDesc.DimensionCount,
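The reordering in SetStrides doubles as a lifetime fix: the old code pointed m_bufferTensorDesc.Strides at the caller's buffer via strides.data(), which could dangle once the caller's storage went away, whereas the new code copies into the member array first and points at m_strides. A minimal sketch of the hazard, with a hypothetical caller (it assumes tensorDesc already has DimensionCount == 3 so the argument validation passes):

#include <vector>

void DanglingStridesExample(TensorDesc& tensorDesc) // Hypothetical caller, not from the commit.
{
    {
        std::vector<uint32_t> temp = {12, 4, 1};
        tensorDesc.SetStrides(temp);
        // Old code: m_bufferTensorDesc.Strides == temp.data().
    } // temp is destroyed here, so the old Strides pointer would dangle.
    // New code: Strides points at the internal m_strides copy and stays valid.
}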
@@ -289,6 +289,15 @@ void TensorDesc::ForceUnsignedDataType()
     }
 }
 
+// Add padding 1's to the dimensions as needed to ensure the count is at least newDimensionCount.
+void TensorDesc::EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
+{
+    if (m_bufferTensorDesc.DimensionCount < newDimensionCount)
+    {
+        SetDimensionCount(newDimensionCount, alignment);
+    }
+}
+
 void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
 {
     ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount);
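EnsureDimensionCount only ever grows the rank, delegating to SetDimensionCount to insert the filler 1's according to the alignment argument. A hedged usage sketch (tensorDesc is a stand-in, and TensorAxis::W producing right-aligned leading 1's is an assumption about this codebase's convention):

// Hedged sketch: pad a rank-2 desc with sizes [3,4] up to rank 4. Assuming
// right-alignment (an assumption here), the sizes become [1,1,3,4];
// a desc already at rank >= 4 is left untouched.
tensorDesc.EnsureDimensionCount(4, TensorAxis::W);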
@@ -321,20 +330,32 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
     m_bufferTensorDesc.DimensionCount = newDimensionCount;
 }
 
-// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout
+void TensorDesc::SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides)
+{
+    static_assert(sizeof(m_sizes) == sizeof(m_strides));
+    ML_CHECK_VALID_ARGUMENT(sizes.size() <= std::size(m_sizes));
+    ML_CHECK_VALID_ARGUMENT(strides.empty() || strides.size() == sizes.size());
+
+    std::copy(sizes.begin(), sizes.end(), m_sizes);
+    m_bufferTensorDesc.DimensionCount = static_cast<uint32_t>(sizes.size());
+    SetStrides(strides);
+}
+
 void TensorDesc::PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment)
 {
+    const uint32_t oldRank = m_bufferTensorDesc.DimensionCount;
     EnsureStridesExist();
     SetDimensionCount(static_cast<uint32_t>(dimensionMapping.size()), alignment);
 
-    // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping
-    std::vector<uint32_t> tempSizes{m_sizes, m_sizes + MaximumDimensionCount};
-    std::vector<uint32_t> tempStrides{m_strides, m_strides + MaximumDimensionCount};
+    // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping.
+    std::vector<uint32_t> oldSizes{m_sizes, m_sizes + MaximumDimensionCount};
+    std::vector<uint32_t> oldStrides{m_strides, m_strides + MaximumDimensionCount};
 
     for (size_t i = 0; i < dimensionMapping.size(); i++)
     {
-        m_sizes[i] = tempSizes[dimensionMapping[i]];
-        m_strides[i] = tempStrides[dimensionMapping[i]];
+        uint32_t sourceAxis = dimensionMapping[i];
+        m_sizes[i] = sourceAxis < oldRank ? oldSizes[sourceAxis] : 1;
+        m_strides[i] = sourceAxis < oldRank ? oldStrides[sourceAxis] : 0;
     }
 
     m_bufferTensorDesc.Sizes = m_sizes;
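SetDimensionsAndStrides above resets the sizes and strides in a single call; an empty strides span falls through SetStrides to mean a packed layout. A hedged usage sketch (the tensorDesc variable and the broadcast use case are illustrative, not from the commit):

// Hedged sketch: a [2,3,4] view where stride 0 on axis 0 repeats one 3x4 plane.
const uint32_t sizes[]   = {2, 3, 4};
const uint32_t strides[] = {0, 4, 1};
tensorDesc.SetDimensionsAndStrides(sizes, strides);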
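The rewritten PermuteDimensions loop is a gather rather than a strict permutation: output axis i pulls from dimensionMapping[i], and any source axis at or beyond the old rank becomes a broadcast dimension of size 1 and stride 0. A standalone sketch of that gather rule (a hypothetical helper, not the commit's code):

#include <cstdint>
#include <vector>

// Gathers sizes (fillValue = 1) or strides (fillValue = 0) through an axis mapping.
std::vector<uint32_t> GatherAxes(
    const std::vector<uint32_t>& oldValues,
    const std::vector<uint32_t>& mapping,
    uint32_t oldRank,
    uint32_t fillValue)
{
    std::vector<uint32_t> result(mapping.size());
    for (size_t i = 0; i < mapping.size(); ++i)
    {
        result[i] = mapping[i] < oldRank ? oldValues[mapping[i]] : fillValue;
    }
    return result;
}

// e.g. sizes {2,3,4} with mapping {2,0,3} gather to {4,2,1},
// and strides {12,4,1} with the same mapping gather to {1,12,0}.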
@@ -345,14 +366,10 @@ void TensorDesc::EnsureStridesExist()
 {
     if (m_bufferTensorDesc.Strides != nullptr)
     {
-        // Strides are populated
+        // Strides are already populated
         return;
     }
 
-    uint32_t stride = 1;
-    for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
-    {
-        m_strides[i] = stride;
-        stride *= m_sizes[i];
-    }
+    GetDescendingPackedStrides({m_sizes, m_bufferTensorDesc.DimensionCount}, {m_strides, m_bufferTensorDesc.DimensionCount});
     m_bufferTensorDesc.Strides = m_strides;
 }
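The hand-rolled loop is replaced by the shared GetDescendingPackedStrides helper, which fills strides back to front: the last axis gets stride 1 and each earlier axis gets the product of all later sizes. A small worked example at a hypothetical call site:

// Packed row-major strides for sizes [2,3,4]:
// strides[2] = 1, strides[1] = 4, strides[0] = 3 * 4 = 12.
uint32_t sizes[3]   = {2, 3, 4};
uint32_t strides[3] = {};
GetDescendingPackedStrides(sizes, strides); // strides == {12, 4, 1}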
@@ -41,9 +41,19 @@ namespace Dml
         inline bool IsValid() const { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
         inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
         void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
+        void EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
 
         gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
         gsl::span<const uint32_t> GetStrides() const;
         void SetStrides(gsl::span<const uint32_t> strides);
+        void EnsureStridesExist();
+
+        void SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides);
+
+        // Rearranges existing m_sizes and m_strides by gathering axes from dimensionMapping.
+        // It IS legal to change the number of dimensions by adding filler, dropping entire dimensions for a new view,
+        // and even duplicating logical dimensions. Axes beyond the original rank will be filled with size 1 and stride 0.
+        // e.g. Existing sizes [2,3,4] with [2,0] yields [4,2].
         void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);
 
         inline uint64_t GetBufferSizeInBytes() const
@@ -91,8 +101,6 @@ namespace Dml
         uint32_t m_sizes[MaximumDimensionCount] = {};
         uint32_t m_strides[MaximumDimensionCount] = {};
         DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
-
-        void EnsureStridesExist();
     };
 
     class TensorDescBuilder