Skip to content

Commit

Permalink
feat: Add 32-bit float atomics support for Vulkan (SPIR-V shaders)
Browse files Browse the repository at this point in the history
* atomicSub for f32 in the previous commits is removed.
  • Loading branch information
AsherJingkongChen authored and ArthurBrussee committed Sep 16, 2024
1 parent 3bba7d8 commit 144a910
Show file tree
Hide file tree
Showing 13 changed files with 269 additions and 76 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Bottom level categories:

#### Metal

- Support 32-bit floating-point atomic operations in shaders. It requires Metal 3.0 or later with Apple 7, 8, 9 or Mac 2. By @AsherJingkongChen in [#6234](https://github.com/gfx-rs/wgpu/pull/6234).
- Support some 32-bit floating-point atomic operations in shaders. It requires Metal 3.0 or later with Apple 7, 8, 9 or Mac 2. By @AsherJingkongChen in [#6234](https://github.com/gfx-rs/wgpu/pull/6234)

### Bug Fixes

Expand Down
139 changes: 94 additions & 45 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2455,51 +2455,100 @@ impl<'w> BlockContext<'w> {
let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);

let instruction = match *fun {
crate::AtomicFunction::Add => Instruction::atomic_binary(
spirv::Op::AtomicIAdd,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::Subtract => Instruction::atomic_binary(
spirv::Op::AtomicISub,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::And => Instruction::atomic_binary(
spirv::Op::AtomicAnd,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
spirv::Op::AtomicOr,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
spirv::Op::AtomicXor,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::Add => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicIAdd,
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Float,
width: _,
}) => spirv::Op::AtomicFAddEXT,
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::Subtract => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicISub,
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::And => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicAnd,
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::InclusiveOr => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicOr,
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::ExclusiveOr => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicXor,
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::Min => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,16 @@ impl Writer {
crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
}
crate::TypeInner::Atomic(crate::Scalar {
width: 4,
kind: crate::ScalarKind::Float,
}) => {
self.require_any(
"32 bit floating-point atomics",
&[spirv::Capability::AtomicFloat32AddEXT],
)?;
self.use_extension("SPV_EXT_shader_atomic_float_add");
}
_ => {}
}
Ok(())
Expand Down
3 changes: 1 addition & 2 deletions naga/src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,11 @@ impl super::Validator {
// that that `Atomic` type has a permitted scalar width.
if let crate::ScalarKind::Float = pointer_scalar.kind {
// `Capabilities::SHADER_FLT32_ATOMIC` enables 32-bit floating-point
// atomic operations including `Add`, `Subtract`, and `Exchange`
// atomic operations including `Add` and `Exchange`
// in storage address space.
if !matches!(
*fun,
crate::AtomicFunction::Add
| crate::AtomicFunction::Subtract
| crate::AtomicFunction::Exchange { compare: _ }
) {
log::error!("Float32 atomic operation {:?} is not supported", fun);
Expand Down
2 changes: 1 addition & 1 deletion naga/src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ impl super::Validator {
if width == 4 {
if !self
.capabilities
.intersects(Capabilities::SHADER_FLT32_ATOMIC)
.contains(Capabilities::SHADER_FLT32_ATOMIC)
{
return Err(TypeError::MissingCapability(
Capabilities::SHADER_FLT32_ATOMIC,
Expand Down
4 changes: 4 additions & 0 deletions naga/tests/in/atomicOps-flt32.param.ron
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
(
god_mode: true,
spv: (
version: (1, 1),
capabilities: [ AtomicFloat32AddEXT ],
),
msl: (
lang_version: (3, 0),
per_entry_point_map: {},
Expand Down
7 changes: 0 additions & 7 deletions naga/tests/in/atomicOps-flt32.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {

workgroupBarrier();

atomicSub(&storage_atomic_scalar, 1.0);
atomicSub(&storage_atomic_arr[1], 1.0);
atomicSub(&storage_struct.atomic_scalar, 1.0);
atomicSub(&storage_struct.atomic_arr[1], 1.0);

workgroupBarrier();

atomicExchange(&storage_atomic_scalar, 1.0);
atomicExchange(&storage_atomic_arr[1], 1.0);
atomicExchange(&storage_struct.atomic_scalar, 1.0);
Expand Down
13 changes: 4 additions & 9 deletions naga/tests/out/msl/atomicOps-flt32.msl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,9 @@ kernel void cs_main(
float _e35 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e40 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
float _e43 = metal::atomic_fetch_sub_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e47 = metal::atomic_fetch_sub_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e51 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e56 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
float _e59 = metal::atomic_exchange_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e63 = metal::atomic_exchange_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e67 = metal::atomic_exchange_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e72 = metal::atomic_exchange_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e43 = metal::atomic_exchange_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e47 = metal::atomic_exchange_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e51 = metal::atomic_exchange_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e56 = metal::atomic_exchange_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
return;
}
98 changes: 98 additions & 0 deletions naga/tests/out/spv/atomicOps-flt32.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 62
OpCapability Shader
OpCapability AtomicFloat32AddEXT
OpExtension "SPV_KHR_storage_buffer_storage_class"
OpExtension "SPV_EXT_shader_atomic_float_add"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %22 "cs_main" %19
OpExecutionMode %22 LocalSize 2 1 1
OpDecorate %4 ArrayStride 4
OpMemberDecorate %7 0 Offset 0
OpMemberDecorate %7 1 Offset 4
OpDecorate %9 DescriptorSet 0
OpDecorate %9 Binding 0
OpDecorate %10 Block
OpMemberDecorate %10 0 Offset 0
OpDecorate %12 DescriptorSet 0
OpDecorate %12 Binding 1
OpDecorate %13 Block
OpMemberDecorate %13 0 Offset 0
OpDecorate %15 DescriptorSet 0
OpDecorate %15 Binding 2
OpDecorate %16 Block
OpMemberDecorate %16 0 Offset 0
OpDecorate %19 BuiltIn LocalInvocationId
%2 = OpTypeVoid
%3 = OpTypeFloat 32
%6 = OpTypeInt 32 0
%5 = OpConstant %6 2
%4 = OpTypeArray %3 %5
%7 = OpTypeStruct %3 %4
%8 = OpTypeVector %6 3
%10 = OpTypeStruct %3
%11 = OpTypePointer StorageBuffer %10
%9 = OpVariable %11 StorageBuffer
%13 = OpTypeStruct %4
%14 = OpTypePointer StorageBuffer %13
%12 = OpVariable %14 StorageBuffer
%16 = OpTypeStruct %7
%17 = OpTypePointer StorageBuffer %16
%15 = OpVariable %17 StorageBuffer
%20 = OpTypePointer Input %8
%19 = OpVariable %20 Input
%23 = OpTypeFunction %2
%24 = OpTypePointer StorageBuffer %3
%25 = OpConstant %6 0
%27 = OpTypePointer StorageBuffer %4
%29 = OpTypePointer StorageBuffer %7
%31 = OpConstant %3 1.0
%34 = OpTypeInt 32 1
%33 = OpConstant %34 1
%35 = OpConstant %6 64
%36 = OpConstant %6 1
%40 = OpConstant %6 264
%22 = OpFunction %2 None %23
%18 = OpLabel
%21 = OpLoad %8 %19
%26 = OpAccessChain %24 %9 %25
%28 = OpAccessChain %27 %12 %25
%30 = OpAccessChain %29 %15 %25
OpBranch %32
%32 = OpLabel
OpAtomicStore %26 %33 %35 %31
%37 = OpAccessChain %24 %28 %36
OpAtomicStore %37 %33 %35 %31
%38 = OpAccessChain %24 %30 %25
OpAtomicStore %38 %33 %35 %31
%39 = OpAccessChain %24 %30 %36 %36
OpAtomicStore %39 %33 %35 %31
OpControlBarrier %5 %5 %40
%41 = OpAtomicLoad %3 %26 %33 %35
%42 = OpAccessChain %24 %28 %36
%43 = OpAtomicLoad %3 %42 %33 %35
%44 = OpAccessChain %24 %30 %25
%45 = OpAtomicLoad %3 %44 %33 %35
%46 = OpAccessChain %24 %30 %36 %36
%47 = OpAtomicLoad %3 %46 %33 %35
OpControlBarrier %5 %5 %40
%48 = OpAtomicFAddEXT %3 %26 %33 %35 %31
%50 = OpAccessChain %24 %28 %36
%49 = OpAtomicFAddEXT %3 %50 %33 %35 %31
%52 = OpAccessChain %24 %30 %25
%51 = OpAtomicFAddEXT %3 %52 %33 %35 %31
%54 = OpAccessChain %24 %30 %36 %36
%53 = OpAtomicFAddEXT %3 %54 %33 %35 %31
OpControlBarrier %5 %5 %40
%55 = OpAtomicExchange %3 %26 %33 %35 %31
%57 = OpAccessChain %24 %28 %36
%56 = OpAtomicExchange %3 %57 %33 %35 %31
%59 = OpAccessChain %24 %30 %25
%58 = OpAtomicExchange %3 %59 %33 %35 %31
%61 = OpAccessChain %24 %30 %36 %36
%60 = OpAtomicExchange %3 %61 %33 %35 %31
OpReturn
OpFunctionEnd
13 changes: 4 additions & 9 deletions naga/tests/out/wgsl/atomicOps-flt32.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,9 @@ fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
let _e35 = atomicAdd((&storage_struct.atomic_scalar), 1f);
let _e40 = atomicAdd((&storage_struct.atomic_arr[1]), 1f);
workgroupBarrier();
let _e43 = atomicSub((&storage_atomic_scalar), 1f);
let _e47 = atomicSub((&storage_atomic_arr[1]), 1f);
let _e51 = atomicSub((&storage_struct.atomic_scalar), 1f);
let _e56 = atomicSub((&storage_struct.atomic_arr[1]), 1f);
workgroupBarrier();
let _e59 = atomicExchange((&storage_atomic_scalar), 1f);
let _e63 = atomicExchange((&storage_atomic_arr[1]), 1f);
let _e67 = atomicExchange((&storage_struct.atomic_scalar), 1f);
let _e72 = atomicExchange((&storage_struct.atomic_arr[1]), 1f);
let _e43 = atomicExchange((&storage_atomic_scalar), 1f);
let _e47 = atomicExchange((&storage_atomic_arr[1]), 1f);
let _e51 = atomicExchange((&storage_struct.atomic_scalar), 1f);
let _e56 = atomicExchange((&storage_struct.atomic_arr[1]), 1f);
return;
}
5 changes: 4 additions & 1 deletion naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,10 @@ fn convert_wgsl() {
"atomicOps-int64-min-max",
Targets::SPIRV | Targets::METAL | Targets::HLSL | Targets::WGSL,
),
("atomicOps-flt32", Targets::METAL | Targets::WGSL),
(
"atomicOps-flt32",
Targets::SPIRV | Targets::METAL | Targets::WGSL,
),
(
"atomicCompareExchange-int64",
Targets::SPIRV | Targets::WGSL,
Expand Down
Loading

0 comments on commit 144a910

Please sign in to comment.