Skip to content

Commit

Permalink
[SPIR-V] Emulate OpBitFieldInsert for type != i32 (microsoft#6491)
Browse files Browse the repository at this point in the history
SPIR-V supports OpBitFieldInsert on all integer types, but Vulkan
requires the operands to be 32-bit integers.
This means we need to emulate this instruction for types other than i32.

This PR only adds emulation for OpBitFieldInsert. The next PR will add
support for emulating OpBitFieldExtract.

Related to microsoft#6327

---------

Signed-off-by: Nathan Gauër <[email protected]>
  • Loading branch information
Keenuts authored Apr 9, 2024
1 parent 8b5b6c6 commit 0781ded
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 38 deletions.
17 changes: 13 additions & 4 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,11 @@ class SpirvBuilder {

/// \brief Creates an OpBitFieldInsert SPIR-V instruction for the given
/// arguments.
SpirvBitFieldInsert *
createBitFieldInsert(QualType resultType, SpirvInstruction *base,
SpirvInstruction *insert, SpirvInstruction *offset,
SpirvInstruction *count, SourceLocation);
SpirvInstruction *createBitFieldInsert(QualType resultType,
SpirvInstruction *base,
SpirvInstruction *insert,
unsigned bitOffset, unsigned bitCount,
SourceLocation, SourceRange);

/// \brief Creates an OpBitFieldUExtract or OpBitFieldSExtract SPIR-V
/// instruction for the given arguments.
Expand Down Expand Up @@ -831,6 +832,14 @@ class SpirvBuilder {
const SpirvType *spvType,
SpirvInstruction *var);

/// \brief Emulates OpBitFieldInsert SPIR-V instruction for the given
/// arguments.
SpirvInstruction *
createEmulatedBitFieldInsert(QualType resultType, uint32_t baseTypeBitwidth,
SpirvInstruction *base, SpirvInstruction *insert,
unsigned bitOffset, unsigned bitCount,
SourceLocation, SourceRange);

private:
ASTContext &astContext;
SpirvContext &context; ///< From which we allocate various SPIR-V object
Expand Down
9 changes: 2 additions & 7 deletions tools/clang/lib/SPIRV/InitListHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,9 @@ InitListHandler::createInitForStructType(QualType type, SourceLocation srcLoc,
// For the remaining bitfields, we need to insert them into the existing
// container, which is the last element in `fields`.
assert(fields.size() == fieldInfo.fieldIndex + 1);
SpirvInstruction *offset = spvBuilder.getConstantInt(
astContext.UnsignedIntTy,
llvm::APInt(32, fieldInfo.bitfield->offsetInBits));
SpirvInstruction *count = spvBuilder.getConstantInt(
astContext.UnsignedIntTy,
llvm::APInt(32, fieldInfo.bitfield->sizeInBits));
fields.back() = spvBuilder.createBitFieldInsert(
fieldType, fields.back(), init, offset, count, srcLoc);
fieldType, fields.back(), init, fieldInfo.bitfield->offsetInBits,
fieldInfo.bitfield->sizeInBits, srcLoc, range);
return true;
},
true);
Expand Down
117 changes: 102 additions & 15 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,9 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address,
context.addToInstructionsWithLoweredType(value);

auto *base = createLoad(value->getResultType(), address, loc, range);
auto *offset = getConstantInt(
astContext.UnsignedIntTy,
llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->offsetInBits),
false));
auto *count = getConstantInt(
astContext.UnsignedIntTy,
llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->sizeInBits),
false));
source =
createBitFieldInsert(/*QualType*/ {}, base, value, offset, count, loc);
source = createBitFieldInsert(/*QualType*/ {}, base, value,
bitfieldInfo->offsetInBits,
bitfieldInfo->sizeInBits, loc, range);
source->setResultType(value->getResultType());
}

Expand Down Expand Up @@ -882,12 +875,106 @@ void SpirvBuilder::createBarrier(spv::Scope memoryScope,
insertPoint->addInstruction(barrier);
}

SpirvBitFieldInsert *SpirvBuilder::createBitFieldInsert(
QualType resultType, SpirvInstruction *base, SpirvInstruction *insert,
SpirvInstruction *offset, SpirvInstruction *count, SourceLocation loc) {
/// \brief Emulates OpBitFieldInsert with plain bitwise instructions.
///
/// Emits, in this order: OpBitwiseAnd (clear the destination bits), OpBitcast
/// (reinterpret \p insert as the result type), OpShiftLeftLogical, OpBitwiseAnd
/// (keep only the field bits) and OpBitwiseOr (merge), and returns the final
/// OpBitwiseOr.
///
/// \param resultType       AST type given to every created instruction.
/// \param baseTypeBitwidth Bit width used for the two mask constants.
/// \param base             Current value of the raw container field.
/// \param insert           Value to write into the bitfield.
/// \param bitOffset        Position of the bitfield's lowest bit in \p base.
/// \param bitCount         Width of the bitfield, at most 64.
/// \param loc, range       Source information attached to the instructions.
SpirvInstruction *SpirvBuilder::createEmulatedBitFieldInsert(
    QualType resultType, uint32_t baseTypeBitwidth, SpirvInstruction *base,
    SpirvInstruction *insert, unsigned bitOffset, unsigned bitCount,
    SourceLocation loc, SourceRange range) {

  // The destination is a raw struct field, which can contain several bitfields:
  //    raw field: AAAABBBBCCCCCCCCDDDD
  // To insert a new value for the field BBBB, we need to clear the B bits in
  // the field, and insert the new values.

  // Create a mask selecting bitCount low bits.
  //    mask = (1 << bitCount) - 1
  //      raw field: AAAABBBBCCCCCCCCDDDD
  //      mask:      00000000000000001111
  // Move the mask to B's position in the raw type.
  //    mask = mask << bitOffset
  //      raw field: AAAABBBBCCCCCCCCDDDD
  //      mask:      00001111000000000000
  // Generate the inverted mask, used to clear B in the raw field.
  //    notMask = ~mask
  //      raw field: AAAABBBBCCCCCCCCDDDD
  //      notMask:   11110000111111111111
  // Both masks are emitted as unsigned constants of baseTypeBitwidth bits.
  assert(bitCount <= 64 &&
         "Bitfield insertion emulation can only insert at most 64 bits.");
  assert(bitOffset < 64 &&
         "Bitfield insertion emulation cannot shift by 64 bits or more.");
  auto maskTy =
      astContext.getIntTypeForBitwidth(baseTypeBitwidth, /* signed= */ 0);
  // Shifting a 64-bit value by 64 is undefined behaviour, so a field covering
  // the whole 64-bit container gets its all-ones mask spelled out explicitly.
  const uint64_t fieldBits =
      bitCount >= 64 ? ~0ull : ((1ull << bitCount) - 1ull);
  const uint64_t maskValue = fieldBits << bitOffset;
  const uint64_t notMaskValue = ~maskValue;

  auto *mask = getConstantInt(maskTy, llvm::APInt(baseTypeBitwidth, maskValue));
  auto *notMask =
      getConstantInt(maskTy, llvm::APInt(baseTypeBitwidth, notMaskValue));
  auto *shiftOffset =
      getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitOffset));

  // base = base & notMask  // Clear bits at B's position.
  //    input:  AAAABBBBCCCCCCCCDDDD
  //    output: AAAA----CCCCCCCCDDDD
  auto *clearedDst = createBinaryOp(spv::Op::OpBitwiseAnd, resultType, base,
                                    notMask, loc, range);

  //    input:  SSSSSSSSSSSSSSSSBBBB
  // tmp = (dstType)SRC     // Convert SRC to the base type.
  // tmp = tmp << bitOffset // Move the SRC value to the correct bit offset.
  //    output: SSSSBBBB------------
  // tmp = tmp & mask       // Clear any sign extension bits.
  //    output: ----BBBB------------
  auto *castedSrc =
      createUnaryOp(spv::Op::OpBitcast, resultType, insert, loc, range);
  auto *shiftedSrc = createBinaryOp(spv::Op::OpShiftLeftLogical, resultType,
                                    castedSrc, shiftOffset, loc, range);
  auto *maskedSrc = createBinaryOp(spv::Op::OpBitwiseAnd, resultType,
                                   shiftedSrc, mask, loc, range);

  // base = base | tmp;     // Insert B in the raw field.
  //    tmp:    ----BBBB------------
  //    base:   AAAA----CCCCCCCCDDDD
  //    output: AAAABBBBCCCCCCCCDDDD
  auto *result = createBinaryOp(spv::Op::OpBitwiseOr, resultType, clearedDst,
                                maskedSrc, loc, range);

  // When the base already carries a result type, give the same type to every
  // instruction created above (callers may pass an empty resultType).
  if (base->getResultType()) {
    auto *dstTy = dyn_cast<IntegerType>(base->getResultType());
    clearedDst->setResultType(dstTy);
    shiftedSrc->setResultType(dstTy);
    maskedSrc->setResultType(dstTy);
    castedSrc->setResultType(dstTy);
    result->setResultType(dstTy);
  }
  return result;
}

SpirvInstruction *
SpirvBuilder::createBitFieldInsert(QualType resultType, SpirvInstruction *base,
SpirvInstruction *insert, unsigned bitOffset,
unsigned bitCount, SourceLocation loc,
SourceRange range) {
assert(insertPoint && "null insert point");
auto *inst = new (context)
SpirvBitFieldInsert(resultType, loc, base, insert, offset, count);

uint32_t bitwidth = 0;
if (resultType == QualType({})) {
assert(base->hasResultType() && "No type information for bitfield.");
bitwidth = dyn_cast<IntegerType>(base->getResultType())->getBitwidth();
} else {
bitwidth = getElementSpirvBitwidth(astContext, resultType,
spirvOptions.enable16BitTypes);
}

if (bitwidth != 32)
return createEmulatedBitFieldInsert(resultType, bitwidth, base, insert,
bitOffset, bitCount, loc, range);

auto *insertOffset =
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitOffset));
auto *insertCount =
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitCount));
auto *inst = new (context) SpirvBitFieldInsert(resultType, loc, base, insert,
insertOffset, insertCount);
insertPoint->addInstruction(inst);
inst->setRValue(true);
return inst;
Expand Down
22 changes: 10 additions & 12 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9237,6 +9237,7 @@ SpirvEmitter::processIntrinsicNonUniformResourceIndex(const CallExpr *expr) {
SpirvInstruction *
SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
const auto loc = callExpr->getExprLoc();
const auto range = callExpr->getSourceRange();
if (!spirvOptions.noWarnEmulatedFeatures)
emitWarning("msad4 intrinsic function is emulated using many SPIR-V "
"instructions due to lack of direct SPIR-V equivalent",
Expand Down Expand Up @@ -9297,18 +9298,15 @@ SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
// Do bfi 3 times. DXIL bfi is equivalent to SPIR-V OpBitFieldInsert.
auto *v1y = spvBuilder.createCompositeExtract(uintType, source, {1}, loc);
// Note that t0.x = v1.x, nothing we need to do for that.
auto *t0y =
spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS8, /*insert*/ v1y,
/*offset*/ uint24,
/*width*/ uint8, loc);
auto *t0z =
spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS16, /*insert*/ v1y,
/*offset*/ uint16,
/*width*/ uint16, loc);
auto *t0w =
spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS24, /*insert*/ v1y,
/*offset*/ uint8,
/*width*/ uint24, loc);
auto *t0y = spvBuilder.createBitFieldInsert(
uintType, /*base*/ v1xS8, /*insert*/ v1y,
/* bitOffest */ 24, /* bitCount */ 8, loc, range);
auto *t0z = spvBuilder.createBitFieldInsert(
uintType, /*base*/ v1xS16, /*insert*/ v1y,
/* bitOffest */ 16, /* bitCount */ 16, loc, range);
auto *t0w = spvBuilder.createBitFieldInsert(
uintType, /*base*/ v1xS24, /*insert*/ v1y,
/* bitOffest */ 8, /* bitCount */ 24, loc, range);

// Step 3. MSAD (Masked Sum of Absolute Differences)

Expand Down
148 changes: 148 additions & 0 deletions tools/clang/test/CodeGenSPIRV/op.struct.access.bitfield.sized.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// RUN: %dxc -T cs_6_2 -E main -spirv -fcgl -enable-16bit-types %s | FileCheck %s

// Two bitfields declared on uint64_t (32 + 1 bits). The CHECK-DAG line that
// follows expects both to be packed into a single %ulong struct member.
struct S1 {
uint64_t f1 : 32;
uint64_t f2 : 1;
};
// CHECK-DAG: %S1 = OpTypeStruct %ulong

// Two bitfields declared on uint16_t (4 + 5 bits). The CHECK-DAG line that
// follows expects both to be packed into a single %ushort struct member.
struct S2 {
uint16_t f1 : 4;
uint16_t f2 : 5;
};
// CHECK-DAG: %S2 = OpTypeStruct %ushort

// Bitfields declared on three different base types. The CHECK-DAG line that
// follows expects f1 and f2 (45 + 10 bits) to share one %ulong member, while
// f3 and f4 each get their own member (%ushort and %uint respectively).
struct S3 {
uint64_t f1 : 45;
uint64_t f2 : 10;
uint16_t f3 : 7;
uint32_t f4 : 5;
};
// CHECK-DAG: %S3 = OpTypeStruct %ulong %ushort %uint

// Signed counterpart of S1: two bitfields declared on int64_t (32 + 1 bits).
// The CHECK-DAG line that follows expects a single %long struct member.
struct S4 {
int64_t f1 : 32;
int64_t f2 : 1;
};
// CHECK-DAG: %S4 = OpTypeStruct %long

// Entry point: assigns a constant to every bitfield of S1..S4 so the lowering
// of each bitfield store can be verified.
// Fields stored in a container that is not 32 bits wide (ulong, ushort, long)
// are expected to be written with the emulated sequence:
//   OpBitwiseAnd (clear) / OpBitcast / OpShiftLeftLogical / OpBitwiseAnd (mask)
//   / OpBitwiseOr (merge).
// The field stored in a 32-bit container (s3.f4) is still expected to use
// OpBitFieldInsert. The hexadecimal values in the comments are the mask
// constants of the following check line, written in hex.
[numthreads(1, 1, 1)]
void main() {
// S1: f1 occupies bits [0, 32) and f2 bit 32 of the ulong container.
S1 s1;
s1.f1 = 3;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s1 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
// 0xffffffff00000000
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18446744069414584320
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_3
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_0
// 0x00000000ffffffff
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_4294967295
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

s1.f2 = 1;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s1 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
// 0xfffffffeffffffff
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18446744069414584319
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_1
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_32
// 0x0000000100000000
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_4294967296
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

// S2: f1 occupies bits [0, 4) and f2 bits [4, 9) of the ushort container.
S2 s2;
s2.f1 = 2;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s2 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort [[ptr]]
// 0xfff0
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_65520
// CHECK: [[val:%[0-9]+]] = OpBitcast %ushort %ushort_2
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[val]] %uint_0
// 0x000f
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_15
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ushort [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

s2.f2 = 3;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s2 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort [[ptr]]
// 0xfe0f
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_65039
// CHECK: [[val:%[0-9]+]] = OpBitcast %ushort %ushort_3
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[val]] %uint_4
// 0x01f0
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_496
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ushort [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

// S3: f1 and f2 live in member 0 (ulong, bits [0, 45) and [45, 55)), f3 in
// member 1 (ushort, bits [0, 7)) and f4 in member 2 (uint, bits [0, 5)).
S3 s3;
s3.f1 = 5;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s3 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
// 0xffffe00000000000
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18446708889337462784
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_5
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_0
// 0x00001fffffffffff
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_35184372088831
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

s3.f2 = 6;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s3 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
// 0xff801fffffffffff
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18410750461062676479
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_6
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_45
// 0x007fe00000000000
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_35993612646875136
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

s3.f3 = 7;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s3 %int_1
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort [[ptr]]
// 0xff80
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_65408
// CHECK: [[val:%[0-9]+]] = OpBitcast %ushort %ushort_7
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[val]] %uint_0
// 0x007f
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_127
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ushort [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

// 32-bit container: the native OpBitFieldInsert is expected here, with the
// offset (%uint_0) and count (%uint_5) passed as operands.
s3.f4 = 8;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %s3 %int_2
// CHECK: [[val:%[0-9]+]] = OpLoad %uint [[ptr]]
// CHECK: [[tmp:%[0-9]+]] = OpBitFieldInsert %uint [[val]] %uint_8 %uint_0 %uint_5
// CHECK: OpStore [[ptr]] [[tmp]]

// S4: signed 64-bit container. The operations are typed %long while the mask
// constants are still expected to be unsigned (%ulong_...).
S4 s4;
s4.f1 = 3;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_long %s4 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %long [[ptr]]
// 0xffffffff00000000
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_18446744069414584320
// CHECK: [[val:%[0-9]+]] = OpBitcast %long %long_3
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %long [[val]] %uint_0
// 0x00000000ffffffff
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_4294967295
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %long [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]

s4.f2 = 1;
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_long %s4 %int_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %long [[ptr]]
// 0xfffffffeffffffff
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_18446744069414584319
// CHECK: [[val:%[0-9]+]] = OpBitcast %long %long_1
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %long [[val]] %uint_32
// 0x0000000100000000
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_4294967296
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %long [[dst]] [[src]]
// CHECK: OpStore [[ptr]] [[mix]]
}

0 comments on commit 0781ded

Please sign in to comment.