From 234b6dd613f137cf630d1ada89599aad763ab2ce Mon Sep 17 00:00:00 2001 From: Schell Carl Scivally Date: Tue, 10 Dec 2024 03:56:51 +1300 Subject: [PATCH] feature: [spv-front] Support for OpAtomicCompareExchange (#6590) Add support for parsing and executing OpAtomicCompareExchange in the SPIR-V frontend. This concludes the work to support atomics in the SPIR-V frontend, excluding test clean-up. Fixes #6296. Fixes #6590. Connections: - [naga spv-in] Support for OpAtomicCompareExchange #6296 - [spv-in] Atomics support #4489 Co-authored-by: Jim Blandy --- CHANGELOG.md | 1 + naga/src/front/atomic_upgrade.rs | 2 + naga/src/front/spv/mod.rs | 119 ++++++++++++++++-- naga/tests/in/spv/atomic_compare_exchange.spv | Bin 0 -> 1324 bytes .../in/spv/atomic_compare_exchange.spvasm | 89 +++++++++++++ 5 files changed, 201 insertions(+), 10 deletions(-) create mode 100644 naga/tests/in/spv/atomic_compare_exchange.spv create mode 100644 naga/tests/in/spv/atomic_compare_exchange.spvasm diff --git a/CHANGELOG.md b/CHANGELOG.md index 27043042a0..036e55ef09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -114,6 +114,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513). - Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429) - Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635). +- Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590). #### General diff --git a/naga/src/front/atomic_upgrade.rs b/naga/src/front/atomic_upgrade.rs index 3edb6acea4..171df33169 100644 --- a/naga/src/front/atomic_upgrade.rs +++ b/naga/src/front/atomic_upgrade.rs @@ -46,6 +46,8 @@ pub enum Error { GlobalInitUnsupported, #[error("expected to find a global variable")] GlobalVariableMissing, + #[error("atomic compare exchange requires a scalar base type")] + CompareExchangeNonScalarBaseType, } #[derive(Clone, Default)] diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 52d7ea9d15..63406d3220 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4273,6 +4273,102 @@ impl> Frontend { self.upgrade_atomics .insert(ctx.get_contained_global_variable(p_exp_h)?); } + Op::AtomicCompareExchange => { + inst.expect(9)?; + + let start = self.data_offset; + let span = self.span_from_with_op(start); + let result_type_id = self.next()?; + let result_id = self.next()?; + let pointer_id = self.next()?; + let _memory_scope_id = self.next()?; + let _equal_memory_semantics_id = self.next()?; + let _unequal_memory_semantics_id = self.next()?; + let value_id = self.next()?; + let comparator_id = self.next()?; + + let (p_exp_h, p_base_ty_h) = self.get_exp_and_base_ty_handles( + pointer_id, + ctx, + &mut emitter, + &mut block, + body_idx, + )?; + + log::trace!("\t\t\tlooking up value expr {:?}", value_id); + let v_lexp_handle = + get_expr_handle!(value_id, self.lookup_expression.lookup(value_id)?); + + log::trace!("\t\t\tlooking up comparator expr {:?}", value_id); + let c_lexp_handle = get_expr_handle!( + comparator_id, + self.lookup_expression.lookup(comparator_id)? + ); + + // We know from the SPIR-V spec that the result type must be an integer + // scalar, and we'll need the type itself to get a handle to the atomic + // result struct. + let crate::TypeInner::Scalar(scalar) = ctx.module.types[p_base_ty_h].inner + else { + return Err( + crate::front::atomic_upgrade::Error::CompareExchangeNonScalarBaseType + .into(), + ); + }; + + // Get a handle to the atomic result struct type. + let atomic_result_struct_ty_h = ctx.module.generate_predeclared_type( + crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar), + ); + + block.extend(emitter.finish(ctx.expressions)); + + // Create an expression for our atomic result + let atomic_lexp_handle = { + let expr = crate::Expression::AtomicResult { + ty: atomic_result_struct_ty_h, + comparison: true, + }; + ctx.expressions.append(expr, span) + }; + + // Create an dot accessor to extract the value from the + // result struct __atomic_compare_exchange_result and use that + // as the expression for the result_id + { + let expr = crate::Expression::AccessIndex { + base: atomic_lexp_handle, + index: 0, + }; + let handle = ctx.expressions.append(expr, span); + // Use this dot accessor as the result id's expression + let _ = self.lookup_expression.insert( + result_id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + + emitter.start(ctx.expressions); + + // Create a statement for the op itself + let stmt = crate::Statement::Atomic { + pointer: p_exp_h, + fun: crate::AtomicFunction::Exchange { + compare: Some(c_lexp_handle), + }, + value: v_lexp_handle, + result: Some(atomic_lexp_handle), + }; + block.push(stmt, span); + + // Store any associated global variables so we can upgrade their types later + self.upgrade_atomics + .insert(ctx.get_contained_global_variable(p_exp_h)?); + } Op::AtomicExchange | Op::AtomicIAdd | Op::AtomicISub @@ -5969,17 +6065,18 @@ mod test_atomic { let m = crate::front::spv::parse_u8_slice(bytes, &Default::default()).unwrap(); let mut wgsl = String::new(); - let mut should_panic = false; - for vflags in [ - crate::valid::ValidationFlags::all(), - crate::valid::ValidationFlags::empty(), + for (vflags, name) in [ + (crate::valid::ValidationFlags::empty(), "empty"), + (crate::valid::ValidationFlags::all(), "all"), ] { + log::info!("validating with flags - {name}"); let mut validator = crate::valid::Validator::new(vflags, Default::default()); match validator.validate(&m) { Err(e) => { log::error!("SPIR-V validation {}", e.emit_to_string("")); - should_panic = true; + log::info!("types: {:#?}", m.types); + panic!("validation error"); } Ok(i) => { wgsl = crate::back::wgsl::write_string( @@ -5989,15 +6086,10 @@ mod test_atomic { ) .unwrap(); log::info!("wgsl-out:\n{wgsl}"); - break; } }; } - if should_panic { - panic!("validation error"); - } - let m = match crate::front::wgsl::parse_str(&wgsl) { Ok(m) => m, Err(e) => { @@ -6032,6 +6124,13 @@ mod test_atomic { atomic_test(include_bytes!("../../../tests/in/spv/atomic_exchange.spv")); } + #[test] + fn atomic_compare_exchange() { + atomic_test(include_bytes!( + "../../../tests/in/spv/atomic_compare_exchange.spv" + )); + } + #[test] fn atomic_i_decrement() { atomic_test(include_bytes!( diff --git a/naga/tests/in/spv/atomic_compare_exchange.spv b/naga/tests/in/spv/atomic_compare_exchange.spv new file mode 100644 index 0000000000000000000000000000000000000000..e1eab2225187550b3231919010fa70af86b529ad GIT binary patch literal 1324 zcmYk5OHWf#5XYw!B!E&7@PS%ytuG*oq5{?!(G4!uFVHm24Z1)Kbm78gUc6V?0OJ_fC=U=-YJD+!QDnyYz6X(jB5m}jUjki>rKuXQN7-0Iv zM&WI$Od_pkj=Gd`wu#c$C{W}3_N{?kE#|m_tuhX%TF_Vccq@Gt7!E_Fna zdk1N*@;&wKcNaEaO4vPgu?l=locp^E8>{b>sh5z_U;5}nfOW#(P$KSkJmC_Yu1L@b0DIIOEe!+=G3b#XiFjIh z`lWnUukyLg6_{&64T>xC7ufi!0m_8j8QYrMApVcWQ$Y-kGk9QMul zzxd8%=kT2$m3PuYH%{!3-Fi=3u(!P=ZWBHOk-||n literal 0 HcmV?d00001 diff --git a/naga/tests/in/spv/atomic_compare_exchange.spvasm b/naga/tests/in/spv/atomic_compare_exchange.spvasm new file mode 100644 index 0000000000..70cf4f11c1 --- /dev/null +++ b/naga/tests/in/spv/atomic_compare_exchange.spvasm @@ -0,0 +1,89 @@ +; SPIR-V +; Version: 1.5 +; Generator: Google rspirv; 0 +; Bound: 65 +; Schema: 0 + OpCapability Shader + OpCapability VulkanMemoryModel + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %1 "stage::test_atomic_compare_exchange" %2 %3 + OpExecutionMode %1 LocalSize 32 1 1 + OpMemberDecorate %_struct_9 0 Offset 0 + OpMemberDecorate %_struct_9 1 Offset 4 + OpDecorate %_struct_10 Block + OpMemberDecorate %_struct_10 0 Offset 0 + OpDecorate %2 Binding 0 + OpDecorate %2 DescriptorSet 0 + OpDecorate %3 NonWritable + OpDecorate %3 Binding 1 + OpDecorate %3 DescriptorSet 0 + %uint = OpTypeInt 32 0 + %void = OpTypeVoid + %13 = OpTypeFunction %void + %bool = OpTypeBool + %uint_0 = OpConstant %uint 0 + %uint_2 = OpConstant %uint 2 + %false = OpConstantFalse %bool +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %uint_1 = OpConstant %uint 1 + %_struct_9 = OpTypeStruct %uint %uint + %20 = OpUndef %_struct_9 + %uint_3 = OpConstant %uint 3 + %int = OpTypeInt 32 1 + %23 = OpUndef %bool + %true = OpConstantTrue %bool + %_struct_10 = OpTypeStruct %uint +%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10 + %2 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer + %3 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer + %uint_256 = OpConstant %uint 256 + %1 = OpFunction %void None %13 + %27 = OpLabel + %28 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0 + %29 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %3 %uint_0 + %30 = OpLoad %uint %29 + %31 = OpCompositeConstruct %_struct_9 %uint_0 %30 + OpBranch %32 + %32 = OpLabel + %33 = OpPhi %_struct_9 %31 %27 %34 %35 + OpLoopMerge %36 %35 None + OpBranch %37 + %37 = OpLabel + %38 = OpCompositeExtract %uint %33 0 + %39 = OpCompositeExtract %uint %33 1 + %40 = OpULessThan %bool %38 %39 + OpSelectionMerge %41 None + OpBranchConditional %40 %42 %43 + %42 = OpLabel + %45 = OpIAdd %uint %38 %uint_1 + %46 = OpCompositeInsert %_struct_9 %45 %33 0 + %47 = OpCompositeConstruct %_struct_9 %uint_1 %38 + OpBranch %41 + %43 = OpLabel + %48 = OpCompositeInsert %_struct_9 %uint_0 %20 0 + OpBranch %41 + %41 = OpLabel + %34 = OpPhi %_struct_9 %46 %42 %33 %43 + %49 = OpPhi %_struct_9 %47 %42 %48 %43 + %50 = OpCompositeExtract %uint %49 0 + %51 = OpCompositeExtract %uint %49 1 + %52 = OpBitcast %int %50 + OpSelectionMerge %53 None + OpSwitch %52 %54 0 %55 1 %56 + %54 = OpLabel + OpBranch %53 + %55 = OpLabel + OpBranch %53 + %56 = OpLabel + %57 = OpAtomicCompareExchange %uint %28 %uint_2 %uint_256 %uint_256 %51 %uint_3 + %58 = OpIEqual %bool %57 %uint_3 + %64 = OpSelect %bool %58 %false %true + OpBranch %53 + %53 = OpLabel + %63 = OpPhi %bool %23 %54 %false %55 %64 %56 + OpBranch %35 + %35 = OpLabel + OpBranchConditional %63 %32 %36 + %36 = OpLabel + OpReturn + OpFunctionEnd