DML EP EinSum make more generic to avoid EP fallback #21114

Merged (6 commits) on Jun 21, 2024
@@ -32,7 +32,7 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
    };
}

bool IsSigned(DML_TENSOR_DATA_TYPE dataType)
bool IsSigned(DML_TENSOR_DATA_TYPE dataType) noexcept
{
    switch (dataType)
    {
@@ -140,7 +140,33 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
    return deviceTypeMask;
}

void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides)
uint32_t GetBitMaskFromIndices(gsl::span<const uint32_t> indices) noexcept
{
    uint32_t bitMask = 0;
    for (auto i : indices)
    {
        assert(i < 32);
        bitMask |= (1 << i);
    }
    return bitMask;
}

uint32_t CountLeastSignificantZeros(uint32_t value) noexcept
{
    // *Use std::countr_zero instead when codebase updated to C++20.
    // Use bit twiddling hack rather than for loop.
    uint32_t count = 32;
    value &= -int32_t(value);
    if (value) count--;
    if (value & 0x0000FFFF) count -= 16;
    if (value & 0x00FF00FF) count -= 8;
    if (value & 0x0F0F0F0F) count -= 4;
    if (value & 0x33333333) count -= 2;
    if (value & 0x55555555) count -= 1;
    return count;
}

void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides) noexcept
{
    assert(sizes.size() == strides.size());

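To make the intent of the two new helpers above concrete, here is a small worked example. It restates the same logic as the diff (using std::initializer_list instead of gsl::span so it compiles standalone); the names BitMaskFromIndices and CountTrailingZeros and the values in the asserts are illustrative, not part of the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

static uint32_t BitMaskFromIndices(std::initializer_list<uint32_t> indices)
{
    uint32_t mask = 0;
    for (uint32_t i : indices)
    {
        assert(i < 32);
        mask |= (1u << i);
    }
    return mask;
}

static uint32_t CountTrailingZeros(uint32_t value)
{
    // Same bit-twiddling approach as CountLeastSignificantZeros in the diff above.
    uint32_t count = 32;
    value &= ~value + 1u; // isolate the lowest set bit (well-defined unsigned arithmetic)
    if (value) count--;
    if (value & 0x0000FFFFu) count -= 16;
    if (value & 0x00FF00FFu) count -= 8;
    if (value & 0x0F0F0F0Fu) count -= 4;
    if (value & 0x33333333u) count -= 2;
    if (value & 0x55555555u) count -= 1;
    return count;
}

int main()
{
    assert(BitMaskFromIndices({0, 2, 5}) == 0b100101u); // bits 0, 2, and 5 set
    assert(CountTrailingZeros(0b100101u) == 0);          // bit 0 is already set
    assert(CountTrailingZeros(0b100100u) == 2);          // lowest set bit is bit 2
    assert(CountTrailingZeros(0u) == 32);                // no bits set at all
    return 0;
}
```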
@@ -23,9 +23,11 @@ namespace Dml
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType);
size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor);
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice);
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides);
uint32_t GetBitMaskFromIndices(gsl::span<const uint32_t> indices) noexcept;
uint32_t CountLeastSignificantZeros(uint32_t value) noexcept;
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides) noexcept;

bool IsSigned(DML_TENSOR_DATA_TYPE dataType);
bool IsSigned(DML_TENSOR_DATA_TYPE dataType) noexcept;

template <typename T>
void CastToClampedScalarUnion(DML_TENSOR_DATA_TYPE dataType, T value, DML_SCALAR_UNION* outputValue)

Large diffs are not rendered by default.

@@ -201,7 +201,7 @@ TensorDesc::TensorDesc(
    assert(m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType));
}

gsl::span<const uint32_t> TensorDesc::GetStrides() const
gsl::span<const uint32_t> TensorDesc::GetStrides() const noexcept
{
    if (m_bufferTensorDesc.Strides == nullptr)
    {
@@ -212,23 +212,23 @@ gsl::span<const uint32_t> TensorDesc::GetStrides

void TensorDesc::SetStrides(gsl::span<const uint32_t> strides)
{
    m_bufferTensorDesc.Strides = strides.empty() ? nullptr : strides.data();

    if (!strides.empty())
    {
        ML_CHECK_VALID_ARGUMENT(strides.size() <= std::size(m_strides));
        ML_CHECK_VALID_ARGUMENT(strides.size() == m_bufferTensorDesc.DimensionCount);
        std::copy(strides.begin(), strides.end(), m_strides);
    }

    m_bufferTensorDesc.Strides = strides.empty() ? nullptr : m_strides;

    m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(
        m_bufferTensorDesc.DataType,
        m_bufferTensorDesc.DimensionCount,
        m_sizes,
        strides.empty() ? nullptr : m_strides);
}

DML_TENSOR_DESC TensorDesc::GetDmlDesc()
DML_TENSOR_DESC TensorDesc::GetDmlDesc() noexcept
{
    if (m_tensorType == DML_TENSOR_TYPE_INVALID)
    {
@@ -289,6 +289,15 @@ void TensorDesc::ForceUnsignedDataType()
    }
}

// Add additional padding 1's to ensure the count is at least that large.
void TensorDesc::EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
{
    if (m_bufferTensorDesc.DimensionCount < newDimensionCount)
    {
        SetDimensionCount(newDimensionCount, alignment);
    }
}
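For illustration, here is a minimal standalone sketch of the padding behavior the comment above describes. It assumes a right-aligned axis convention (new 1's prepended), which is only one of the alignments SetDimensionCount supports; the helper name EnsurePaddedRank is hypothetical and not part of the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Pad a shape with leading 1's so it has at least newRank dimensions,
// e.g. [3,4] with newRank = 4 becomes [1,1,3,4]; shapes already at or
// above newRank are returned unchanged.
std::vector<uint32_t> EnsurePaddedRank(std::vector<uint32_t> sizes, size_t newRank)
{
    if (sizes.size() < newRank)
    {
        sizes.insert(sizes.begin(), newRank - sizes.size(), 1u);
    }
    return sizes;
}
```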

void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
{
    ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount);
@@ -321,38 +330,48 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
    m_bufferTensorDesc.DimensionCount = newDimensionCount;
}

// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout
void TensorDesc::SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides)
{
    static_assert(sizeof(m_sizes) == sizeof(m_strides));
    ML_CHECK_VALID_ARGUMENT(sizes.size() <= std::size(m_sizes));
    ML_CHECK_VALID_ARGUMENT(strides.empty() || strides.size() == sizes.size());

    std::copy(sizes.begin(), sizes.end(), m_sizes);
    m_bufferTensorDesc.DimensionCount = static_cast<uint32_t>(sizes.size());
    SetStrides(strides);
}

void TensorDesc::PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment)
{
    const uint32_t oldRank = m_bufferTensorDesc.DimensionCount;
    EnsureStridesExist();
    SetDimensionCount(static_cast<uint32_t>(dimensionMapping.size()), alignment);

    // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping
    std::vector<uint32_t> tempSizes{m_sizes, m_sizes + MaximumDimensionCount};
    std::vector<uint32_t> tempStrides{m_strides, m_strides + MaximumDimensionCount};
    // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping.
    // Note using MaximumDimensionCount instead of oldRank is intentional here, because the old rank could
    // be smaller or larger than the new rank, but it will never be larger than MaximumDimensionCount.
    std::vector<uint32_t> oldSizes{m_sizes, m_sizes + MaximumDimensionCount};
    std::vector<uint32_t> oldStrides{m_strides, m_strides + MaximumDimensionCount};

    for (size_t i = 0; i < dimensionMapping.size(); i++)
    {
        m_sizes[i] = tempSizes[dimensionMapping[i]];
        m_strides[i] = tempStrides[dimensionMapping[i]];
        uint32_t sourceAxis = dimensionMapping[i];
        m_sizes[i] = sourceAxis < oldRank ? oldSizes[sourceAxis] : 1;
        m_strides[i] = sourceAxis < oldRank ? oldStrides[sourceAxis] : 0;
    }

    m_bufferTensorDesc.Sizes = m_sizes;
    m_bufferTensorDesc.Strides = m_strides;
}
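A minimal sketch of the gather semantics implemented above, written as a free function over std::vector so it can be read in isolation (the struct and function names here are illustrative, not from the PR): axes past the original rank become size 1 / stride 0, and axes may be dropped or repeated.

```cpp
#include <cstdint>
#include <vector>

struct ShapeAndStrides
{
    std::vector<uint32_t> sizes;
    std::vector<uint32_t> strides;
};

// Gather sizes/strides by source axis; out-of-range axes become size 1, stride 0.
ShapeAndStrides Permute(const ShapeAndStrides& input, const std::vector<uint32_t>& dimensionMapping)
{
    ShapeAndStrides result;
    for (uint32_t sourceAxis : dimensionMapping)
    {
        const bool inRange = sourceAxis < input.sizes.size();
        result.sizes.push_back(inRange ? input.sizes[sourceAxis] : 1u);
        result.strides.push_back(inRange ? input.strides[sourceAxis] : 0u);
    }
    return result;
}

// e.g. sizes [2,3,4] with packed strides [12,4,1] and mapping [2,0]
// yield sizes [4,2] and strides [1,12], matching the header comment's example.
```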

void TensorDesc::EnsureStridesExist()
void TensorDesc::EnsureStridesExist() noexcept
{
    if (m_bufferTensorDesc.Strides != nullptr)
    {
        // Strides are populated
        // Strides are already populated
        return;
    }

    uint32_t stride = 1;
    for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
    {
        m_strides[i] = stride;
        stride *= m_sizes[i];
    }
    GetDescendingPackedStrides({m_sizes, m_bufferTensorDesc.DimensionCount}, {m_strides, m_bufferTensorDesc.DimensionCount});
    m_bufferTensorDesc.Strides = m_strides;
}
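For reference, the descending packed-stride computation that GetDescendingPackedStrides performs (and that the removed loop above used to do inline) amounts to the following sketch, shown over std::vector rather than gsl::span:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Innermost (last) dimension gets stride 1; each earlier dimension's stride is
// the product of all later sizes, e.g. sizes [2,3,4] -> strides [12,4,1].
std::vector<uint32_t> DescendingPackedStrides(const std::vector<uint32_t>& sizes)
{
    std::vector<uint32_t> strides(sizes.size());
    uint32_t stride = 1;
    for (size_t i = sizes.size(); i-- > 0;)
    {
        strides[i] = stride;
        stride *= sizes[i];
    }
    return strides;
}
```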
@@ -32,18 +32,28 @@ namespace Dml
uint32_t guaranteedBaseOffsetAlignment
);

DML_TENSOR_DESC GetDmlDesc();
DML_TENSOR_DESC GetDmlDesc() noexcept;

inline DML_TENSOR_DATA_TYPE GetDmlDataType() const { return m_bufferTensorDesc.DataType; }
inline MLOperatorTensorDataType GetMlOperatorDataType() const { return m_mlOperatorTensorDataType; }
inline DML_TENSOR_DATA_TYPE GetDmlDataType() const noexcept { return m_bufferTensorDesc.DataType; }
inline MLOperatorTensorDataType GetMlOperatorDataType() const noexcept { return m_mlOperatorTensorDataType; }
void ForceUnsignedDataType();

inline bool IsValid() const { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
inline bool IsValid() const noexcept { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span<const uint32_t> GetStrides() const;
void EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);

gsl::span<const uint32_t> GetSizes() const noexcept { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span<const uint32_t> GetStrides() const noexcept;
void SetStrides(gsl::span<const uint32_t> strides);
void EnsureStridesExist() noexcept;

void SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides);

// Rearranges existing m_sizes and m_strides by gathering axes from dimensionMapping.
// It IS legal to change the number of dimensions by adding filler, dropping entire dimensions for a new view,
// and even duplicating logical dimensions. Axes beyond the original rank will be filled by size 1 and stride 0.
// e.g. Existing sizes [2,3,4] with [2,0] yields [4,2].
void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);

inline uint64_t GetBufferSizeInBytes() const
@@ -91,8 +101,6 @@ namespace Dml
uint32_t m_sizes[MaximumDimensionCount] = {};
uint32_t m_strides[MaximumDimensionCount] = {};
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};

void EnsureStridesExist();
};

class TensorDescBuilder