From 05ef1f8f6c7d352cd2a1272917ccb3caaef47f34 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 30 Nov 2024 14:18:28 -0800 Subject: [PATCH 1/7] add workgroup_size_overrides --- naga/src/back/pipeline_constants.rs | 49 ++++++++++++++ naga/src/front/glsl/functions.rs | 1 + naga/src/front/spv/function.rs | 1 + naga/src/front/wgsl/lower/mod.rs | 66 +++++++++++++++++-- naga/src/lib.rs | 2 + naga/tests/out/ir/access.compact.ron | 4 ++ naga/tests/out/ir/access.ron | 4 ++ .../out/ir/atomic_i_increment.compact.ron | 1 + naga/tests/out/ir/atomic_i_increment.ron | 1 + naga/tests/out/ir/collatz.compact.ron | 1 + naga/tests/out/ir/collatz.ron | 1 + naga/tests/out/ir/fetch_depth.compact.ron | 1 + naga/tests/out/ir/fetch_depth.ron | 1 + naga/tests/out/ir/index-by-value.compact.ron | 1 + naga/tests/out/ir/index-by-value.ron | 1 + ...ides-atomicCompareExchangeWeak.compact.ron | 1 + .../overrides-atomicCompareExchangeWeak.ron | 1 + .../out/ir/overrides-ray-query.compact.ron | 1 + naga/tests/out/ir/overrides-ray-query.ron | 1 + naga/tests/out/ir/overrides.compact.ron | 1 + naga/tests/out/ir/overrides.ron | 1 + naga/tests/out/ir/shadow.compact.ron | 1 + naga/tests/out/ir/shadow.ron | 1 + naga/tests/out/ir/spec-constants.compact.ron | 1 + naga/tests/out/ir/spec-constants.ron | 1 + 25 files changed, 141 insertions(+), 4 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 0005cbcb0e..cf7b84a342 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -25,6 +25,8 @@ pub enum PipelineConstantError { ConstantEvaluatorError(#[from] ConstantEvaluatorError), #[error(transparent)] ValidationError(#[from] WithSpan), + #[error("workgroup_size was overridden to a negative value")] + NegativeWorkgroupSize, } /// Replace all overrides in `module` with constants. @@ -190,6 +192,7 @@ pub fn process_overrides<'a>( let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut ep.function)?; + process_workgroup_size_override(&mut module, &override_map, ep)?; } module.entry_points = entry_points; @@ -202,6 +205,52 @@ pub fn process_overrides<'a>( Ok((Cow::Owned(module), Cow::Owned(module_info))) } +fn process_workgroup_size_override( + module: &mut Module, + override_map: &HandleVec>, + ep: &mut crate::EntryPoint +) -> Result<(), PipelineConstantError> { + match ep.workgroup_size_overrides { + None => {} + Some(overrides) => { + overrides.iter().enumerate().try_for_each( + |(i, overridden)| -> Result<(), PipelineConstantError> { + match overridden { + None => Ok(()), + Some(h) => { + let c = module.constants[override_map[*h]].init; + let n = &module.global_expressions[c]; + match n { + crate::Expression::Literal(literal) => { + ep.workgroup_size[i] = match literal { + crate::Literal::U32(m) => (*m).into(), + crate::Literal::I32(m) => { + if *m < 0 { + Err(PipelineConstantError::NegativeWorkgroupSize)?; + unreachable!(); + } else { + *m as u32 + } + } + _ => { + unreachable!(); + } + }; + } + _ => { + unreachable!(); + } + } + Ok(()) + } + } + } + )?; + } + } + Ok(()) +} + /// Add a [`Constant`] to `module` for the override `old_h`. /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 658632e872..394be22eaa 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1366,6 +1366,7 @@ impl Frontend { early_depth_test: Some(crate::EarlyDepthTest { conservative: None }) .filter(|_| self.meta.early_fragment_tests), workgroup_size: self.meta.workgroup_size, + workgroup_size_overrides: None, function: Function { arguments, expressions, diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 7122e44837..271b96926b 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -569,6 +569,7 @@ impl> super::Frontend { stage: ep.stage, early_depth_test: ep.early_depth_test, workgroup_size: ep.workgroup_size, + workgroup_size_overrides: None, function, }); diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 5173e73d79..8401df4ecd 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1311,24 +1311,55 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { + let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { // TODO: replace with try_map once stabilized let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; for (i, size) in workgroup_size.into_iter().enumerate() { if let Some(size_expr) = size { - workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + err => { + if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err { + match **ty { + crate::proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = Some( + self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + i, + )? + ); + } + _ => { + err?; + } + } + } else { + err?; + } + } + } } } - workgroup_size_out + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - [0; 3] + ([0; 3], None) }; + let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(crate::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, early_depth_test: entry.early_depth_test, workgroup_size, + workgroup_size_overrides, function, }); Ok(LoweredGlobalDecl::EntryPoint) @@ -1338,6 +1369,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } + fn workgroup_size_override( + &mut self, + size_expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + i: usize, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(size_expr); + let expr = self.expression(size_expr, ctx)?; + let ty = ctx.register_type(expr)?; + match ctx.module.types[ty].inner.scalar_kind().ok_or(0) { + Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({ + ctx.module.overrides.append( + crate::Override { + name: Some(format!("__workgroup_size_{}", i)), + id: None, + ty, + init: Some(expr), + }, + span, + ) + }), + _ => { + Err(Error::ExpectedConstExprConcreteIntegerScalar(span)) + } + } + } + fn block( &mut self, b: &ast::Block<'source>, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 4afbfff9d7..1c1929efa2 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2186,6 +2186,8 @@ pub struct EntryPoint { pub early_depth_test: Option, /// Workgroup size for compute stages pub workgroup_size: [u32; 3], + /// Override expressions for workgroup size + pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, } diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 974080e998..e314078c2b 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -1854,6 +1854,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_vert"), arguments: [ @@ -2156,6 +2157,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_frag"), arguments: [], @@ -2348,6 +2350,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_through_ptr"), arguments: [], @@ -2430,6 +2433,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_to_ptr_components"), arguments: [], diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 974080e998..e314078c2b 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -1854,6 +1854,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_vert"), arguments: [ @@ -2156,6 +2157,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_frag"), arguments: [], @@ -2348,6 +2350,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_through_ptr"), arguments: [], @@ -2430,6 +2433,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_to_ptr_components"), arguments: [], diff --git a/naga/tests/out/ir/atomic_i_increment.compact.ron b/naga/tests/out/ir/atomic_i_increment.compact.ron index 7d024f4e81..5bb6820258 100644 --- a/naga/tests/out/ir/atomic_i_increment.compact.ron +++ b/naga/tests/out/ir/atomic_i_increment.compact.ron @@ -263,6 +263,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("stage::test_atomic_i_increment_wrap"), arguments: [], diff --git a/naga/tests/out/ir/atomic_i_increment.ron b/naga/tests/out/ir/atomic_i_increment.ron index aab4c07206..ae14821330 100644 --- a/naga/tests/out/ir/atomic_i_increment.ron +++ b/naga/tests/out/ir/atomic_i_increment.ron @@ -288,6 +288,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("stage::test_atomic_i_increment_wrap"), arguments: [], diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index 48ce8e76bc..6a7aebe544 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -257,6 +257,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [ diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index 48ce8e76bc..6a7aebe544 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -257,6 +257,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [ diff --git a/naga/tests/out/ir/fetch_depth.compact.ron b/naga/tests/out/ir/fetch_depth.compact.ron index 0d998e205c..f10ccb94f7 100644 --- a/naga/tests/out/ir/fetch_depth.compact.ron +++ b/naga/tests/out/ir/fetch_depth.compact.ron @@ -176,6 +176,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("cull::fetch_depth_wrap"), arguments: [], diff --git a/naga/tests/out/ir/fetch_depth.ron b/naga/tests/out/ir/fetch_depth.ron index c66b7eb065..d25e046d57 100644 --- a/naga/tests/out/ir/fetch_depth.ron +++ b/naga/tests/out/ir/fetch_depth.ron @@ -246,6 +246,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("cull::fetch_depth_wrap"), arguments: [], diff --git a/naga/tests/out/ir/index-by-value.compact.ron b/naga/tests/out/ir/index-by-value.compact.ron index f0ea76f496..93a9821426 100644 --- a/naga/tests/out/ir/index-by-value.compact.ron +++ b/naga/tests/out/ir/index-by-value.compact.ron @@ -300,6 +300,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("index_let_array_1d"), arguments: [ diff --git a/naga/tests/out/ir/index-by-value.ron b/naga/tests/out/ir/index-by-value.ron index f0ea76f496..93a9821426 100644 --- a/naga/tests/out/ir/index-by-value.ron +++ b/naga/tests/out/ir/index-by-value.ron @@ -300,6 +300,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("index_let_array_1d"), arguments: [ diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron index e762de0385..56be2f8ab6 100644 --- a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron @@ -85,6 +85,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("f"), arguments: [], diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron index e762de0385..56be2f8ab6 100644 --- a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron @@ -85,6 +85,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("f"), arguments: [], diff --git a/naga/tests/out/ir/overrides-ray-query.compact.ron b/naga/tests/out/ir/overrides-ray-query.compact.ron index f7d05aa92f..10cad83538 100644 --- a/naga/tests/out/ir/overrides-ray-query.compact.ron +++ b/naga/tests/out/ir/overrides-ray-query.compact.ron @@ -111,6 +111,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides-ray-query.ron b/naga/tests/out/ir/overrides-ray-query.ron index f7d05aa92f..10cad83538 100644 --- a/naga/tests/out/ir/overrides-ray-query.ron +++ b/naga/tests/out/ir/overrides-ray-query.ron @@ -111,6 +111,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index d2df01c0db..d99beb19c6 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -108,6 +108,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index d2df01c0db..d99beb19c6 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -108,6 +108,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 39a25fd10b..24b4674515 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -958,6 +958,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("fs_main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 196536d56b..386b9d36b0 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -1236,6 +1236,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("fs_main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/spec-constants.compact.ron b/naga/tests/out/ir/spec-constants.compact.ron index 9ea75cd468..cde3117225 100644 --- a/naga/tests/out/ir/spec-constants.compact.ron +++ b/naga/tests/out/ir/spec-constants.compact.ron @@ -495,6 +495,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/spec-constants.ron b/naga/tests/out/ir/spec-constants.ron index 5d48e94efc..fa4139a1da 100644 --- a/naga/tests/out/ir/spec-constants.ron +++ b/naga/tests/out/ir/spec-constants.ron @@ -601,6 +601,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("main_wrap"), arguments: [ From 6474295a5542e0da6415bb8fc88ebdd6c5e6c94a Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 30 Nov 2024 14:35:43 -0800 Subject: [PATCH 2/7] linting --- naga/src/back/pipeline_constants.rs | 16 ++++++++-------- naga/src/front/wgsl/lower/mod.rs | 11 ++++------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index cf7b84a342..1a2dc17023 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -208,28 +208,28 @@ pub fn process_overrides<'a>( fn process_workgroup_size_override( module: &mut Module, override_map: &HandleVec>, - ep: &mut crate::EntryPoint + ep: &mut crate::EntryPoint, ) -> Result<(), PipelineConstantError> { match ep.workgroup_size_overrides { None => {} Some(overrides) => { overrides.iter().enumerate().try_for_each( |(i, overridden)| -> Result<(), PipelineConstantError> { - match overridden { + match *overridden { None => Ok(()), Some(h) => { - let c = module.constants[override_map[*h]].init; + let c = module.constants[override_map[h]].init; let n = &module.global_expressions[c]; - match n { + match *n { crate::Expression::Literal(literal) => { ep.workgroup_size[i] = match literal { - crate::Literal::U32(m) => (*m).into(), + crate::Literal::U32(m) => m, crate::Literal::I32(m) => { - if *m < 0 { + if m < 0 { Err(PipelineConstantError::NegativeWorkgroupSize)?; unreachable!(); } else { - *m as u32 + m as u32 } } _ => { @@ -244,7 +244,7 @@ fn process_workgroup_size_override( Ok(()) } } - } + }, )?; } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 8401df4ecd..7de6fc945b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1325,13 +1325,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err { match **ty { crate::proc::ConstantEvaluatorError::OverrideExpr => { - workgroup_size_overrides_out[i] = Some( - self.workgroup_size_override( + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( size_expr, &mut ctx.as_override(), i, - )? - ); + )?); } _ => { err?; @@ -1390,9 +1389,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, ) }), - _ => { - Err(Error::ExpectedConstExprConcreteIntegerScalar(span)) - } + _ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)), } } From 13f98184f1298c3d4967aba5b374b09b13abf3f1 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 30 Nov 2024 14:45:22 -0800 Subject: [PATCH 3/7] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f60dff7eb1..ca7131118f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519). - Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513). - Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429) +- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635). #### General From 59861d47d8e61a90b689338d9088b732cadb1f06 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 13:06:22 -0800 Subject: [PATCH 4/7] integration test --- tests/tests/shader/mod.rs | 1 + .../tests/shader/workgroup_size_overrides.rs | 103 ++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 tests/tests/shader/workgroup_size_overrides.rs diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 7d6ed7aaaa..f05fbac25c 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -19,6 +19,7 @@ pub mod compilation_messages; pub mod data_builtins; pub mod numeric_builtins; pub mod struct_layout; +pub mod workgroup_size_overrides; pub mod zero_init_workgroup_mem; #[derive(Clone, Copy, PartialEq)] diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs new file mode 100644 index 0000000000..90f319847d --- /dev/null +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -0,0 +1,103 @@ +use std::mem::size_of_val; +use wgpu::util::DeviceExt; +use wgpu::{BufferDescriptor, BufferUsages, Maintain, MapMode}; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +const SHADER: &str = r#" + override n = 3; + + @group(0) @binding(0) + var output: array; + + @compute @workgroup_size(n - 2) + fn main(@builtin(local_invocation_index) lii: u32) { + output[lii] = lii + 2; + } +"#; + +#[gpu_test] +static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().limits(wgpu::Limits::default())) + .run_async(move |ctx| async move { + workgroup_size_overrides(&ctx, 0, &[2, 0, 0]).await; + workgroup_size_overrides(&ctx, 4, &[2, 3, 0]).await; + // Expected to fail during pipeline creation: + //workgroup_size_overrides(&ctx, 1, &[0, 0, 0]).await; + }); + +async fn workgroup_size_overrides(ctx: &TestingContext, n: u32, out: &[u32]) { + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(SHADER)), + }); + let pipeline_options = wgpu::PipelineCompilationOptions { + constants: &[("n".to_owned(), n.into())].into(), + ..Default::default() + }; + let compute_pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if n == 0 { + wgpu::PipelineCompilationOptions::default() + } else { + pipeline_options + }, + cache: None, + }); + let init: &[u32] = &[0, 0, 0]; + let init_size: u64 = size_of_val(init).try_into().unwrap(); + let buffer = DeviceExt::create_buffer_init( + &ctx.device, + &wgpu::util::BufferInitDescriptor { + label: None, + contents: bytemuck::cast_slice(init), + usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::COPY_SRC, + }, + ); + let mapping_buffer = ctx.device.create_buffer(&BufferDescriptor { + label: Some("mapping buffer"), + size: init_size, + usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group_entries = [wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }]; + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &bind_group_entries, + }); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups(1, 1, 1); + } + encoder.copy_buffer_to_buffer(&buffer, 0, &mapping_buffer, 0, init_size); + ctx.queue.submit(Some(encoder.finish())); + + mapping_buffer.slice(..).map_async(MapMode::Read, |_| ()); + ctx.async_poll(Maintain::wait()).await.panic_on_timeout(); + + let mapped = mapping_buffer.slice(..).get_mapped_range(); + + let typed: &[u32] = bytemuck::cast_slice(&mapped); + assert_eq!(typed, out); +} From 5fd355fc900faa9b6362c64de828e2201eace48c Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 14:49:10 -0800 Subject: [PATCH 5/7] separate use_override from value and add should_fail --- .../tests/shader/workgroup_size_overrides.rs | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 90f319847d..424e1201e0 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -1,7 +1,7 @@ use std::mem::size_of_val; use wgpu::util::DeviceExt; use wgpu::{BufferDescriptor, BufferUsages, Maintain, MapMode}; -use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; +use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; const SHADER: &str = r#" override n = 3; @@ -19,13 +19,18 @@ const SHADER: &str = r#" static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() .parameters(TestParameters::default().limits(wgpu::Limits::default())) .run_async(move |ctx| async move { - workgroup_size_overrides(&ctx, 0, &[2, 0, 0]).await; - workgroup_size_overrides(&ctx, 4, &[2, 3, 0]).await; - // Expected to fail during pipeline creation: - //workgroup_size_overrides(&ctx, 1, &[0, 0, 0]).await; + workgroup_size_overrides(&ctx, false, 0, &[2, 0, 0], false).await; + workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; + workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; }); -async fn workgroup_size_overrides(ctx: &TestingContext, n: u32, out: &[u32]) { +async fn workgroup_size_overrides( + ctx: &TestingContext, + use_override: bool, + n: u32, + out: &[u32], + should_fail: bool, +) { let module = ctx .device .create_shader_module(wgpu::ShaderModuleDescriptor { @@ -36,20 +41,28 @@ async fn workgroup_size_overrides(ctx: &TestingContext, n: u32, out: &[u32]) { constants: &[("n".to_owned(), n.into())].into(), ..Default::default() }; - let compute_pipeline = ctx - .device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: Some("main"), - compilation_options: if n == 0 { - wgpu::PipelineCompilationOptions::default() - } else { - pipeline_options - }, - cache: None, - }); + let compute_pipeline = fail_if( + &ctx.device, + should_fail, + || { + ctx.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if use_override { + pipeline_options + } else { + wgpu::PipelineCompilationOptions::default() + }, + cache: None, + }) + }, + None + ); + if should_fail { + return; + } let init: &[u32] = &[0, 0, 0]; let init_size: u64 = size_of_val(init).try_into().unwrap(); let buffer = DeviceExt::create_buffer_init( From cfbda3197a2abff1d94ba20594f1b57effb0e95b Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 14:50:25 -0800 Subject: [PATCH 6/7] linting --- .../tests/shader/workgroup_size_overrides.rs | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 424e1201e0..038d95e02d 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -20,8 +20,8 @@ static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::ne .parameters(TestParameters::default().limits(wgpu::Limits::default())) .run_async(move |ctx| async move { workgroup_size_overrides(&ctx, false, 0, &[2, 0, 0], false).await; - workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; - workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; + workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; + workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; }); async fn workgroup_size_overrides( @@ -45,20 +45,21 @@ async fn workgroup_size_overrides( &ctx.device, should_fail, || { - ctx.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: Some("main"), - compilation_options: if use_override { - pipeline_options - } else { - wgpu::PipelineCompilationOptions::default() - }, - cache: None, - }) + ctx.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if use_override { + pipeline_options + } else { + wgpu::PipelineCompilationOptions::default() + }, + cache: None, + }) }, - None + None, ); if should_fail { return; From 79c1a4917310ed84cecd28308033e1dd0ab56c16 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 15:27:22 -0800 Subject: [PATCH 7/7] n as an option instead of an option bool to use n --- tests/tests/shader/workgroup_size_overrides.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 038d95e02d..458011079a 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -19,15 +19,14 @@ const SHADER: &str = r#" static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() .parameters(TestParameters::default().limits(wgpu::Limits::default())) .run_async(move |ctx| async move { - workgroup_size_overrides(&ctx, false, 0, &[2, 0, 0], false).await; - workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; - workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; + workgroup_size_overrides(&ctx, None, &[2, 0, 0], false).await; + workgroup_size_overrides(&ctx, Some(4), &[2, 3, 0], false).await; + workgroup_size_overrides(&ctx, Some(1), &[0, 0, 0], true).await; }); async fn workgroup_size_overrides( ctx: &TestingContext, - use_override: bool, - n: u32, + n: Option, out: &[u32], should_fail: bool, ) { @@ -38,7 +37,7 @@ async fn workgroup_size_overrides( source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(SHADER)), }); let pipeline_options = wgpu::PipelineCompilationOptions { - constants: &[("n".to_owned(), n.into())].into(), + constants: &[("n".to_owned(), n.unwrap_or(0).into())].into(), ..Default::default() }; let compute_pipeline = fail_if( @@ -51,7 +50,7 @@ async fn workgroup_size_overrides( layout: None, module: &module, entry_point: Some("main"), - compilation_options: if use_override { + compilation_options: if n.is_some() { pipeline_options } else { wgpu::PipelineCompilationOptions::default()